Skip to content
Snippets Groups Projects
Select Git revision
  • cf0ef7856b1016e52f0967f2d70dee887d73b8d6
  • master default
2 results

main.cpp

Blame
  • main.cpp 4.58 KiB
    #include <pybind11/pybind11.h>
    #include <pybind11/complex.h>
    #include <pybind11/stl.h>
    #include <pybind11/stl_bind.h>
    #include <pybind11/iostream.h>
    
    #include <eisgenerator/model.h>
    #include <eisgenerator/eistype.h>
    #include <eisgenerator/log.h>
    #include <eisgenerator/basicmath.h>
    #include <vector>
    #include <sstream>
    
    namespace py = pybind11;
    
    using namespace eis;
    
    /**
     * Render an EisSpectra as text.
     *
     * The spectra is serialized into an in-memory buffer via
     * EisSpectra::saveToStream, and the buffer contents are returned.
     *
     * @param spectra spectra to serialize (not modified)
     * @return        exactly the text saveToStream wrote
     */
    std::string reprEisSpectra(const EisSpectra& spectra)
    {
    	std::stringstream buffer;
    	spectra.saveToStream(buffer);
    	std::string text = buffer.str();
    	return text;
    }
    
    /**
     * Format a DataPoint's complex value like a Python complex literal in
     * scientific notation, e.g. "(1.000000e+00-2.000000e+00j)".
     *
     * The imaginary part is written with std::showpos so that its own sign
     * ('+' or '-') acts as the separator. A hard-coded '+' would produce
     * "+-" for negative imaginary parts, which Python's complex() cannot
     * parse.
     *
     * @param dp data point whose `im` member (real + imaginary) is formatted
     * @return   the formatted string; dp.omega is not included
     */
    std::string reprDataPoint(const DataPoint& dp)
    {
    	std::stringstream ss;
    	ss<<std::scientific;
    	ss<<'('<<dp.im.real()
    	  <<std::showpos<<dp.im.imag()<<std::noshowpos
    	  <<"j)";
    	return ss.str();
    }
    
    void logPrint(const std::string& str, Log::Level level)
    {
    	Log(level)<<str;
    }
    
    PYBIND11_MODULE(_core, m)
    {
    	// Python binding for eis::Model. Keyword-argument names given via
    	// py::arg become the Python-visible parameter names, so they must be
    	// unique per method and should stay consistent across methods
    	// (execute/executeSweep both take "omega" and "index").
    	py::class_<Model>(m, "Model")
    		.def(py::init<const std::string&, size_t, bool>(),
    			 py::arg("str"), py::arg("paramSweepCount") = 100, py::arg("defaultToRange") = false)
    		// Previously bound as ("omaga", "omaga"): a typo plus a duplicated
    		// name, which made the second (index) parameter unreachable by keyword.
    		.def("execute", &Model::execute, py::arg("omega"), py::arg("index") = 0)
    		// executeSweep is overloaded in C++; the cast selects the
    		// (omega vector, index) overload.
    		.def("executeSweep", static_cast<std::vector<DataPoint> (Model::*)(const std::vector<fvalue>&, size_t)>(&Model::executeSweep),
    			 py::arg("omega"), py::arg("index") = 0)
    		.def("executeAllSweeps", &Model::executeAllSweeps)
    		.def("getModelStr", &Model::getModelStr)
    		.def("setParamSweepCountClosestTotal", &Model::setParamSweepCountClosestTotal, py::arg("total"))
    		// getModelStrWithParam is overloaded; select the size_t-index overload.
    		.def("getModelStrWithParam", static_cast<std::string (Model::*)(size_t)>(&Model::getModelStrWithParam),
    			 py::arg("index") = 0)
    		.def("getUuid", &Model::getUuid)
    		.def("compile", &Model::compile)
    		.def("isReady", &Model::isReady)
    		.def("resolveSteps", &Model::resolveSteps)
    		.def("getRequiredStepsForSweeps", &Model::getRequiredStepsForSweeps)
    		// Exposed to Python as getCppCode although the C++ method is getCode.
    		.def("getCppCode", &Model::getCode)
    		.def("getTorchScript", &Model::getTorchScript)
    		.def("getCompiledFunctionName", &Model::getCompiledFunctionName)
    		.def("getFlatParameters", &Model::getFlatParameters)
    		.def("getParameterNames", &Model::getParameterNames)
    		.def("getParameterCount", &Model::getParameterCount)
    		.def("getRecommendedParamIndices", &Model::getRecommendedParamIndices)
    		.def("__repr__", &Model::getModelStr);
    	py::class_<DataPoint>(m, "DataPoint")
    		.def(py::init<std::complex<fvalue>, fvalue>(), py::arg("im") = std::complex<fvalue>(0, 0), py::arg("omega") = 100)
    		.def_readwrite("omega", &DataPoint::omega)
    		.def_readwrite("im", &DataPoint::im)
    		.def("__gt__", &DataPoint::operator>)
    		.def("__lt__", &DataPoint::operator<)
    		.def("__eq__", &DataPoint::operator==)