Skip to content
Snippets Groups Projects
Commit 2bde3497 authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

Finnaly manage to support torch script exection

Note: for varous reasons torchScript compilation va torch.jit was not possible,
on the python side the system insits on file resident code for traceing, with no recourse
on the c++ the object is subltly different from the code used internally in the compile
performed after traceing, breaking pytorchs usual bindings.
By shipping own code based on the traceing path in an c++ torch extension the c++ path to
torch::jit compilation would be possible, but this is a lot of effort.

instead the usual cpython and pytoch -> libtorch is used instead this negatively affects performance
when compeared to the usual case of torch::jit::compile used when using eisgenerator directly from c++
parent 8ce1484a
No related branches found
No related tags found
No related merge requests found
......@@ -3,9 +3,6 @@ cmake_minimum_required(VERSION 3.22)
project(eisgenerator VERSION "1.0")
if(SKBUILD)
# Scikit-Build does not add your site-packages to the search path
# automatically, so we need to add it _or_ the pybind11 specific directory
# here.
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c
"import pybind11; print(pybind11.get_cmake_dir())"
......@@ -14,12 +11,12 @@ if(SKBUILD)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
endif()
# Now we can find pybind11
find_package(pybind11 CONFIG REQUIRED)
pybind11_add_module(_core MODULE src/main.cpp)
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})
target_link_libraries(_core PUBLIC -leisgenerator)
install(TARGETS _core DESTINATION .)
#!/bin/python
import eisgenerator as eis
import torch
eis.Log.level = eis.Log.level.ERROR
print(f'set log level: {eis.Log.level}')
model = eis.Model("cr")
print(f"model: {model}")
script = model.getTorchScript()
print(f'TorchScript:\n{script}')
eis.compileModel(model)
modelFn = eis.getModelFunction(model)
parameters = torch.empty((2))
parameters[0] = 1e-6
parameters[1] = 100
omegas = torch.logspace(0, 5, 10)
print(modelFn(parameters, omegas))
from ._core import __doc__, __version__, Model, DataPoint, Range, Log, ParseError, FileError, eisDistance, EisSpectra, eisNyquistDistance, ostream_redirect
from .execute import compileModel, getModelFunction
from ._core import Model, Log
models = dict()
def compileModel(model: Model) -> bool:
try:
import torch
except ModuleNotFoundError:
Log.print("Could not import torch, torch must be availble to compile a torch model.", Log.level.ERROR)
return False
script = model.getTorchScript()
if len(script) == 0:
Log.print("eisgenerator reports that this model can not be executed as a torch script.", Log.level.ERROR)
return False
try:
exec(script)
fn_name = model.getCompiledFunctionName()
models[fn_name] = locals()[fn_name]
except SyntaxError as err:
Log.print(f'could not compile model: {err}', Log.level.ERROR)
return False
return True
def getModelFunction(model: Model):
fn_name = model.getCompiledFunctionName()
if fn_name not in models:
Log.print("You must first compile a model before getting its function", Log.level.ERROR)
return None
return models[fn_name]
......@@ -29,6 +29,11 @@ std::string reprDataPoint(const DataPoint& dp)
return ss.str();
}
void logPrint(const std::string& str, Log::Level level)
{
Log(level)<<str;
}
PYBIND11_MODULE(_core, m)
{
py::class_<Model>(m, "Model")
......@@ -94,13 +99,14 @@ PYBIND11_MODULE(_core, m)
.def("getFvalueLabels", &EisSpectra::getFvalueLabels)
.def("saveToDisk", &EisSpectra::saveToDisk)
.def("__repr__", &reprEisSpectra);
py::class_<Log>(m, "Log")
.def_readwrite_static("level", &Log::level);
py::enum_<Log::Level>(m, "Level")
.value("DEBUG", Log::DEBUG)
.value("INFO", Log::INFO)
.value("WARN", Log::WARN)
.value("ERROR", Log::ERROR);
py::class_<Log>(m, "Log")
.def_readwrite_static("level", &Log::level)
.def_static("print", &logPrint, py::arg("string"), py::arg("level"));
py::register_exception<parse_errror>(m, "ParseError");
py::register_exception<file_error>(m, "FileError");
py::add_ostream_redirect(m, "ostream_redirect");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment