diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a671faaf9cd81a87a9aba0b0035f387b5bf5505..869a7f51d221cb58f6a4e9fe1ea8186f166a488e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,23 +3,20 @@ cmake_minimum_required(VERSION 3.22) project(eisgenerator VERSION "1.0") if(SKBUILD) - # Scikit-Build does not add your site-packages to the search path - # automatically, so we need to add it _or_ the pybind11 specific directory - # here. - execute_process( - COMMAND "${PYTHON_EXECUTABLE}" -c - "import pybind11; print(pybind11.get_cmake_dir())" - OUTPUT_VARIABLE _tmp_dir - OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT) - list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}") + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c + "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE _tmp_dir + OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT) + list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}") endif() -# Now we can find pybind11 find_package(pybind11 CONFIG REQUIRED) pybind11_add_module(_core MODULE src/main.cpp) - target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION}) target_link_libraries(_core PUBLIC -leisgenerator) - install(TARGETS _core DESTINATION .) + + + diff --git a/example/exampleTorch.py b/example/exampleTorch.py new file mode 100755 index 0000000000000000000000000000000000000000..9ba9528d60f247a5381a00ea8c6851d53c74e550 --- /dev/null +++ b/example/exampleTorch.py @@ -0,0 +1,24 @@ +#!/bin/python + +import eisgenerator as eis +import torch + +eis.Log.level = eis.Log.level.ERROR +print(f'set log level: {eis.Log.level}') + +model = eis.Model("cr") +print(f"model: {model}") + +script = model.getTorchScript() +print(f'TorchScript:\n{script}') + +eis.compileModel(model) + +modelFn = eis.getModelFunction(model) + +parameters = torch.empty((2)) +parameters[0] = 1e-6 +parameters[1] = 100 +omegas = torch.logspace(0, 5, 10) + +print(modelFn(parameters, omegas)) diff --git a/src/eisgenerator/__init__.py b/src/eisgenerator/__init__.py index 2186174b276d5bfae753aee4f1f8683fd241aabd..1140aca0a38badedfc74b56d7d30fb48fa421128 100644 --- a/src/eisgenerator/__init__.py +++ b/src/eisgenerator/__init__.py @@ -1 +1,2 @@ from ._core import __doc__, __version__, Model, DataPoint, Range, Log, ParseError, FileError, eisDistance, EisSpectra, eisNyquistDistance, ostream_redirect +from .execute import compileModel, getModelFunction diff --git a/src/eisgenerator/execute.py b/src/eisgenerator/execute.py new file mode 100644 index 0000000000000000000000000000000000000000..7c76e64903ab00893de7489ab99d8444a42ed7e2 --- /dev/null +++ b/src/eisgenerator/execute.py @@ -0,0 +1,37 @@ +from ._core import Model, Log + +models = dict() + +def compileModel(model: Model) -> bool: + try: + import torch + except ModuleNotFoundError: + Log.print("Could not import torch, torch must be availble to compile a torch model.", Log.level.ERROR) + return False + + script = model.getTorchScript() + if len(script) == 0: + Log.print("eisgenerator reports that this model can not be executed as a torch script.", Log.level.ERROR) + return False + + try: + exec(script) + fn_name = model.getCompiledFunctionName() + models[fn_name] = locals()[fn_name] + except SyntaxError as err: + Log.print(f'could not compile model: {err}', Log.level.ERROR) + return False + + return True + + +def getModelFunction(model: Model): + fn_name = model.getCompiledFunctionName() + + if fn_name not in models: + Log.print("You must first compile a model before getting its function", Log.level.ERROR) + return None + + return models[fn_name] + + diff --git a/src/main.cpp b/src/main.cpp index 1edfbc1c7df244f70ec20093927364a57bc45ded..1108418f1c0788aacc66948994f4d0f7c5f07ed1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -29,6 +29,11 @@ std::string reprDataPoint(const DataPoint& dp) return ss.str(); } +void logPrint(const std::string& str, Log::Level level) +{ + Log(level)<<str; +} + PYBIND11_MODULE(_core, m) { py::class_<Model>(m, "Model") @@ -94,13 +99,14 @@ PYBIND11_MODULE(_core, m) .def("getFvalueLabels", &EisSpectra::getFvalueLabels) .def("saveToDisk", &EisSpectra::saveToDisk) .def("__repr__", &reprEisSpectra); - py::class_<Log>(m, "Log") - .def_readwrite_static("level", &Log::level); py::enum_<Log::Level>(m, "Level") .value("DEBUG", Log::DEBUG) .value("INFO", Log::INFO) .value("WARN", Log::WARN) .value("ERROR", Log::ERROR); + py::class_<Log>(m, "Log") + .def_readwrite_static("level", &Log::level) + .def_static("print", &logPrint, py::arg("string"), py::arg("level")); py::register_exception<parse_errror>(m, "ParseError"); py::register_exception<file_error>(m, "FileError"); py::add_ostream_redirect(m, "ostream_redirect");