From 2bde349787f6d612f8ae3ef7b2113f727dd41c0a Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm <philipp@uvos.xyz>
Date: Fri, 26 Jan 2024 15:27:38 +0100
Subject: [PATCH] Finnaly manage to support torch script exection Note: for
 varous reasons torchScript compilation va torch.jit was not possible, on the
 python side the system insits on file resident code for traceing, with no
 recourse on the c++ the object is subltly different from the code used
 internally in the compile performed after traceing, breaking pytorchs usual
 bindings. By shipping own code based on the traceing path in an c++ torch
 extension the c++ path to torch::jit compilation would be possible, but this
 is a lot of effort.

instead the usual cpython and pytoch -> libtorch is used instead this negatively affects performance
when compeared to the usual case of torch::jit::compile used when using eisgenerator directly from c++
---
 CMakeLists.txt               | 21 +++++++++-----------
 example/exampleTorch.py      | 24 +++++++++++++++++++++++
 src/eisgenerator/__init__.py |  1 +
 src/eisgenerator/execute.py  | 37 ++++++++++++++++++++++++++++++++++++
 src/main.cpp                 | 10 ++++++++--
 5 files changed, 79 insertions(+), 14 deletions(-)
 create mode 100755 example/exampleTorch.py
 create mode 100644 src/eisgenerator/execute.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0a671fa..869a7f5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,23 +3,20 @@ cmake_minimum_required(VERSION 3.22)
 project(eisgenerator VERSION "1.0")
 
 if(SKBUILD)
-  # Scikit-Build does not add your site-packages to the search path
-  # automatically, so we need to add it _or_ the pybind11 specific directory
-  # here.
-  execute_process(
-    COMMAND "${PYTHON_EXECUTABLE}" -c
-            "import pybind11; print(pybind11.get_cmake_dir())"
-    OUTPUT_VARIABLE _tmp_dir
-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
-  list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
+	execute_process(
+		COMMAND "${PYTHON_EXECUTABLE}" -c
+		"import pybind11; print(pybind11.get_cmake_dir())"
+		OUTPUT_VARIABLE _tmp_dir
+		OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
+	list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
 endif()
 
-# Now we can find pybind11
 find_package(pybind11 CONFIG REQUIRED)
 
 pybind11_add_module(_core MODULE src/main.cpp)
-
 target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})
 target_link_libraries(_core PUBLIC -leisgenerator)
-
 install(TARGETS _core DESTINATION .)
+
+
+
diff --git a/example/exampleTorch.py b/example/exampleTorch.py
new file mode 100755
index 0000000..9ba9528
--- /dev/null
+++ b/example/exampleTorch.py
@@ -0,0 +1,24 @@
+#!/bin/python
+
+import eisgenerator as eis
+import torch
+
+eis.Log.level = eis.Log.level.ERROR
+print(f'set log level: {eis.Log.level}')
+
+model = eis.Model("cr")
+print(f"model: {model}")
+
+script = model.getTorchScript()
+print(f'TorchScript:\n{script}')
+
+eis.compileModel(model)
+
+modelFn = eis.getModelFunction(model)
+
+parameters = torch.empty((2))
+parameters[0] = 1e-6
+parameters[1] = 100
+omegas = torch.logspace(0, 5, 10)
+
+print(modelFn(parameters, omegas))
diff --git a/src/eisgenerator/__init__.py b/src/eisgenerator/__init__.py
index 2186174..1140aca 100644
--- a/src/eisgenerator/__init__.py
+++ b/src/eisgenerator/__init__.py
@@ -1 +1,2 @@
 from ._core import __doc__, __version__, Model, DataPoint, Range, Log, ParseError, FileError, eisDistance, EisSpectra, eisNyquistDistance, ostream_redirect
+from .execute import compileModel, getModelFunction
diff --git a/src/eisgenerator/execute.py b/src/eisgenerator/execute.py
new file mode 100644
index 0000000..7c76e64
--- /dev/null
+++ b/src/eisgenerator/execute.py
@@ -0,0 +1,37 @@
+from ._core import Model, Log
+
+models = dict()
+
+def compileModel(model: Model) -> bool:
+	try:
+		import torch
+	except  ModuleNotFoundError:
+		Log.print("Could not import torch, torch must be availble to compile a torch model.", Log.level.ERROR)
+		return False
+
+	script = model.getTorchScript()
+	if len(script) == 0:
+		Log.print("eisgenerator reports that this model can not be executed as a torch script.", Log.level.ERROR)
+		return False
+
+	try:
+		exec(script)
+		fn_name = model.getCompiledFunctionName()
+		models[fn_name] = locals()[fn_name]
+	except SyntaxError as err:
+		Log.print(f'could not compile model: {err}', Log.level.ERROR)
+		return False
+
+	return True
+
+
+def getModelFunction(model: Model):
+	fn_name = model.getCompiledFunctionName()
+
+	if fn_name not in models:
+		Log.print("You must first compile a model before getting its function", Log.level.ERROR)
+		return None
+
+	return models[fn_name]
+
+
diff --git a/src/main.cpp b/src/main.cpp
index 1edfbc1..1108418 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -29,6 +29,11 @@ std::string reprDataPoint(const DataPoint& dp)
 	return ss.str();
 }
 
+void logPrint(const std::string& str, Log::Level level)
+{
+	Log(level)<<str;
+}
+
 PYBIND11_MODULE(_core, m)
 {
 	py::class_<Model>(m, "Model")
@@ -94,13 +99,14 @@ PYBIND11_MODULE(_core, m)
 		.def("getFvalueLabels", &EisSpectra::getFvalueLabels)
 		.def("saveToDisk", &EisSpectra::saveToDisk)
 		.def("__repr__", &reprEisSpectra);
-	py::class_<Log>(m, "Log")
-		.def_readwrite_static("level", &Log::level);
 	py::enum_<Log::Level>(m, "Level")
 		.value("DEBUG", Log::DEBUG)
 		.value("INFO", Log::INFO)
 		.value("WARN", Log::WARN)
 		.value("ERROR", Log::ERROR);
+	py::class_<Log>(m, "Log")
+		.def_readwrite_static("level", &Log::level)
+		.def_static("print", &logPrint, py::arg("string"), py::arg("level"));
 	py::register_exception<parse_errror>(m, "ParseError");
 	py::register_exception<file_error>(m, "FileError");
 	py::add_ostream_redirect(m, "ostream_redirect");
-- 
GitLab