// main.cpp
#include <pybind11/pybind11.h>
#include <pybind11/complex.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <pybind11/iostream.h>
#include <eisgenerator/model.h>
#include <eisgenerator/eistype.h>
#include <eisgenerator/log.h>
#include <eisgenerator/basicmath.h>
#include <vector>
#include <sstream>
namespace py = pybind11;
using namespace eis;
// Produces a textual representation of an EisSpectra by letting the spectra
// write itself into an in-memory stream via saveToStream() and returning the
// accumulated text.
std::string reprEisSpectra(const EisSpectra& spectra)
{
	std::stringstream buffer;

	spectra.saveToStream(buffer);

	return buffer.str();
}
// Formats the complex member `im` of a DataPoint in the style of a Python
// complex literal, using scientific notation for both parts,
// e.g. "(1.000000e+00-2.000000e+00j)".
//
// dp: the data point whose `im` value is rendered; `omega` is not included.
// Returns the formatted string.
std::string reprDataPoint(const DataPoint& dp)
{
	std::stringstream ss;
	ss<<std::scientific;
	// std::showpos makes the imaginary part carry its own explicit sign, so a
	// negative value is printed as "-x" instead of the malformed "+-x" that a
	// hard-coded '+' separator would produce. It is applied after the real
	// part so the real part keeps its usual formatting.
	ss<<'('<<dp.im.real()<<std::showpos<<dp.im.imag()<<"j)";
	return ss.str();
}
// Forwards a message to the library's Log facility at the requested level.
//
// str:   text to be logged.
// level: Log::Level the message is emitted with.
void logPrint(const std::string& str, Log::Level level)
{
	// The temporary Log object lives only for this single statement.
	Log(level) << str;
}
PYBIND11_MODULE(_core, m)
{
// Expose eis::Model to Python under the name "Model".
py::class_<Model>(m, "Model")
	// Constructor: only "str" is required, the other two arguments have defaults.
	.def(py::init<const std::string&, size_t, bool>(),
		py::arg("str"), py::arg("paramSweepCount") = 100, py::arg("defaultToRange") = false)
	// The second keyword was previously also named "omaga", which left the
	// defaulted parameter without a usable keyword of its own. It is named
	// "index" here, matching the executeSweep binding below. The spelling of
	// the first keyword is kept so existing keyword callers keep working.
	.def("execute", &Model::execute, py::arg("omaga"), py::arg("index") = 0)
	// The cast selects the executeSweep overload that takes a vector of values.
	.def("executeSweep", static_cast<std::vector<DataPoint> (Model::*)(const std::vector<fvalue>&, size_t)>(&Model::executeSweep),
		py::arg("omega"), py::arg("index") = 0)
	.def("executeAllSweeps", &Model::executeAllSweeps)
	.def("getModelStr", &Model::getModelStr)
	.def("setParamSweepCountClosestTotal", &Model::setParamSweepCountClosestTotal, py::arg("total"))
	// The cast selects the getModelStrWithParam overload that takes an index.
	.def("getModelStrWithParam", static_cast<std::string (Model::*)(size_t)>(&Model::getModelStrWithParam),
		py::arg("index") = 0)
	.def("getUuid", &Model::getUuid)
	.def("compile", &Model::compile)
	.def("isReady", &Model::isReady)
	.def("resolveSteps", &Model::resolveSteps)
	.def("getRequiredStepsForSweeps", &Model::getRequiredStepsForSweeps)
	// Note: the Python name "getCppCode" maps to the C++ method Model::getCode.
	.def("getCppCode", &Model::getCode)
	.def("getTorchScript", &Model::getTorchScript)
	.def("getCompiledFunctionName", &Model::getCompiledFunctionName)
	.def("getFlatParameters", &Model::getFlatParameters)
	.def("getParameterNames", &Model::getParameterNames)
	.def("getParameterCount", &Model::getParameterCount)
	.def("getRecommendedParamIndices", &Model::getRecommendedParamIndices)
	// repr() of a Model in Python yields the same string as getModelStr().
	.def("__repr__", &Model::getModelStr);
py::class_<DataPoint>(m, "DataPoint")
.def(py::init<std::complex<fvalue>, fvalue>(), py::arg("im") = std::complex<fvalue>(0, 0), py::arg("omega") = 100)
.def_readwrite("omega", &DataPoint::omega)
.def_readwrite("im", &DataPoint::im)
.def("__gt__", &DataPoint::operator>)
.def("__lt__", &DataPoint::operator<)
.def("__eq__", &DataPoint::operator==)