Skip to content
Snippets Groups Projects
Commit fce66d16 authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

make using torch execution easier for sweeps by adding a method of easily...

make using torch execution easier for sweeps by adding a method of easily creating the tensors taken by the torch model function
parent af577c66
No related branches found
No related tags found
No related merge requests found
......@@ -45,3 +45,5 @@ TorchScript exectution is performed differently than on the C++ and two extra fu
1. a 1d, (n) sized tensor containing the model parameters
2. a 1d, (n) sized tensor with omega values to calculate the impance at
* a 1d complex tensor is returned with the impedance values
* `eisgenerator.getModelParameters`
* returns the model paramters at the given index as a torch.Tensor
......@@ -16,9 +16,8 @@ eis.compileModel(model)
modelFn = eis.getModelFunction(model)
parameters = torch.empty((2))
parameters[0] = 1e-6
parameters[1] = 100
parameters = eis.getModelParameters(model, 0)
omegas = torch.logspace(0, 5, 10)
print(f'Parameters:\n{parameters}\nOmegas:\n{omegas}\n')
print(modelFn(parameters, omegas))
from ._core import __doc__, __version__, Model, DataPoint, Range, Log, ParseError, FileError, eisDistance, EisSpectra, eisNyquistDistance, ostream_redirect
from .execute import compileModel, getModelFunction
from .execute import compileModel, getModelFunction, getModelParameters
......@@ -35,3 +35,15 @@ def getModelFunction(model: Model):
return models[fn_name]
def getModelParameters(model: Model, index: int):
try:
import torch
except ModuleNotFoundError:
Log.print("Could not import torch, torch must be availble create torch paramter tensors.", Log.level.ERROR)
return False
model.resolveSteps(index)
eisparameters = model.getFlatParameters()
parameters = torch.Tensor(eisparameters)
return parameters
......@@ -54,6 +54,7 @@ PYBIND11_MODULE(_core, m)
.def("getCppCode", &Model::getCode)
.def("getTorchScript", &Model::getTorchScript)
.def("getCompiledFunctionName", &Model::getCompiledFunctionName)
.def("getFlatParameters", &Model::getFlatParameters)
.def("__repr__", &Model::getModelStr);
py::class_<DataPoint>(m, "DataPoint")
.def(py::init<std::complex<fvalue>, fvalue>(), py::arg("im") = std::complex<fvalue>(0, 0), py::arg("omega") = 100)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment