diff --git a/README.md b/README.md index 96274ab723f7ad7d1457a9748e4df26dbd13767d..7a816c6addc73a58865d369701c7a76dac66c1a1 100644 --- a/README.md +++ b/README.md @@ -45,3 +45,5 @@ TorchScript exectution is performed differently than on the C++ and two extra fu 1. a 1d, (n) sized tensor containing the model parameters 2. a 1d, (n) sized tensor with omega values to calculate the impance at * a 1d complex tensor is returned with the impedance values +* `eisgenerator.getModelParameters` + * returns the model paramters at the given index as a torch.Tensor diff --git a/example/exampleTorch.py b/example/exampleTorch.py index 9ba9528d60f247a5381a00ea8c6851d53c74e550..6cc9db6e11dfc10437a81d054a13bc5b27d940cb 100755 --- a/example/exampleTorch.py +++ b/example/exampleTorch.py @@ -16,9 +16,8 @@ eis.compileModel(model) modelFn = eis.getModelFunction(model) -parameters = torch.empty((2)) -parameters[0] = 1e-6 -parameters[1] = 100 +parameters = eis.getModelParameters(model, 0) omegas = torch.logspace(0, 5, 10) +print(f'Parameters:\n{parameters}\nOmegas:\n{omegas}\n') print(modelFn(parameters, omegas)) diff --git a/src/eisgenerator/__init__.py b/src/eisgenerator/__init__.py index 1140aca0a38badedfc74b56d7d30fb48fa421128..60a43a8c478257766fc3a226e7f25133c9c39b50 100644 --- a/src/eisgenerator/__init__.py +++ b/src/eisgenerator/__init__.py @@ -1,2 +1,2 @@ from ._core import __doc__, __version__, Model, DataPoint, Range, Log, ParseError, FileError, eisDistance, EisSpectra, eisNyquistDistance, ostream_redirect -from .execute import compileModel, getModelFunction +from .execute import compileModel, getModelFunction, getModelParameters diff --git a/src/eisgenerator/execute.py b/src/eisgenerator/execute.py index 7c76e64903ab00893de7489ab99d8444a42ed7e2..0176f9075f3ed9a986f40ceed8a31bf72ffc000e 100644 --- a/src/eisgenerator/execute.py +++ b/src/eisgenerator/execute.py @@ -35,3 +35,15 @@ def getModelFunction(model: Model): return models[fn_name] +def getModelParameters(model: Model, index: int): + try: + import torch + except ModuleNotFoundError: + Log.print("Could not import torch, torch must be availble create torch paramter tensors.", Log.level.ERROR) + return False + + model.resolveSteps(index) + eisparameters = model.getFlatParameters() + + parameters = torch.Tensor(eisparameters) + return parameters diff --git a/src/main.cpp b/src/main.cpp index 1108418f1c0788aacc66948994f4d0f7c5f07ed1..c297239ddcae8c82f369df875111e277e1fafefa 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -54,6 +54,7 @@ PYBIND11_MODULE(_core, m) .def("getCppCode", &Model::getCode) .def("getTorchScript", &Model::getTorchScript) .def("getCompiledFunctionName", &Model::getCompiledFunctionName) + .def("getFlatParameters", &Model::getFlatParameters) .def("__repr__", &Model::getModelStr); py::class_<DataPoint>(m, "DataPoint") .def(py::init<std::complex<fvalue>, fvalue>(), py::arg("im") = std::complex<fvalue>(0, 0), py::arg("omega") = 100)