diff --git a/cap.cpp b/cap.cpp index eb8078185e6a4054bee66742cacf111cf7bfa055..8fadbf759991e506523ea53600bc62b863b83f79 100644 --- a/cap.cpp +++ b/cap.cpp @@ -55,3 +55,12 @@ std::string Cap::getCode(std::vector<std::string>& parameters) std::string out = "std::complex<fvalue>(" + real + ", " + imag + ")"; return out; } + +std::string Cap::getTorchScript(std::vector<std::string>& parameters) +{ + parameters.push_back(getUniqueName() + "_0"); + + std::string N = "1/(" + parameters.back() + "*omegas)"; + std::string out = "0-" + N + "*1j"; + return out; +} diff --git a/componant.cpp b/componant.cpp index 72ec08ab4f4ed4d50f34a9bfe5c33bae31bc7916..1d099f045a6fde7d16e8258df1093f9a5eefbec1 100644 --- a/componant.cpp +++ b/componant.cpp @@ -76,6 +76,13 @@ bool Componant::compileable() std::string Componant::getCode(std::vector<std::string>& parameters) { + (void)parameters; + return std::string(); +} + +std::string Componant::getTorchScript(std::vector<std::string>& parameters) +{ + (void)parameters; return std::string(); } diff --git a/constantphase.cpp b/constantphase.cpp index 4a9f6eca12b6412262f60e23c5dfc192e9a5c3a9..5814e9bc9817f6acb18ea8e15cb46335ca331ad7 100644 --- a/constantphase.cpp +++ b/constantphase.cpp @@ -82,3 +82,15 @@ std::string Cpe::getCode(std::vector<std::string>& parameters) std::string out = "std::complex<fvalue>(" + real +", " + imag + ")"; return out; } + +std::string Cpe::getTorchScript(std::vector<std::string>& parameters) +{ + std::string firstParameter = getUniqueName() + "_0"; + std::string secondParameter = getUniqueName() + "_1"; + parameters.push_back(firstParameter); + parameters.push_back(secondParameter); + std::string real = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.cos((torch.pi/2)*" + secondParameter + ")"; + std::string imag = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.sin((torch.pi/2)*" + secondParameter + ")"; + std::string out = real + '-' + imag + "*1j"; + return out; +} diff --git a/eisgenerator/cap.h b/eisgenerator/cap.h index 40a1080ae687cd8e05c2de7662b244eb52390115..78070597fb4a8dc02f31a8f162748b405b08e65e 100644 --- a/eisgenerator/cap.h +++ b/eisgenerator/cap.h @@ -18,6 +18,7 @@ public: static constexpr char staticGetComponantChar(){return 'c';} virtual std::string componantName() const override {return "Capacitor";} virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; virtual ~Cap() = default; }; diff --git a/eisgenerator/componant.h b/eisgenerator/componant.h index 0034baf47f6a1d316b95590fb0c21d9d6513c484..5b8432da27750083d1ab8fb84597539487c0825e 100644 --- a/eisgenerator/componant.h +++ b/eisgenerator/componant.h @@ -32,6 +32,7 @@ class Componant virtual std::string getComponantString(bool currentValue = true) const; virtual std::string componantName() const = 0; virtual std::string getCode(std::vector<std::string>& parameters); + virtual std::string getTorchScript(std::vector<std::string>& parameters); virtual bool compileable(); std::string getUniqueName(); diff --git a/eisgenerator/constantphase.h b/eisgenerator/constantphase.h index 8bb0c58d0dae330285acd52aa7858a3362283dab..15e3acf2a21636fddebc60c505944e54b942873a 100644 --- a/eisgenerator/constantphase.h +++ b/eisgenerator/constantphase.h @@ -22,6 +22,7 @@ public: virtual std::string componantName() const override {return "ConstantPhase";} virtual ~Cpe() = default; virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; }; } diff --git a/eisgenerator/inductor.h b/eisgenerator/inductor.h index a410eccb6c25687d6f926b62002fd3499c3d8eac..a6e7129b8b2771c60f1053b4a0de249ce5e290e8 100644 --- a/eisgenerator/inductor.h +++ b/eisgenerator/inductor.h @@ -18,6 +18,7 @@ public: static constexpr char staticGetComponantChar(){return 'l';} virtual std::string componantName() const override {return "Inductor";} virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; virtual ~Inductor() = default; }; diff --git a/eisgenerator/model.h b/eisgenerator/model.h index 68a3e97e4b27ff0f0e16c48c15d838c61fdd9e32..4713f38596a4bd6658a080b7782e724cb46509bb 100644 --- a/eisgenerator/model.h +++ b/eisgenerator/model.h @@ -58,6 +58,8 @@ public: size_t getRequiredStepsForSweeps(); bool isParamSweep(); std::string getCode(); + std::string getTorchScript(); + std::string getCompiledFunctionName(); std::vector<size_t> getRecommendedParamIndices(eis::Range omegaRange, double distance, bool threaded = false); }; diff --git a/eisgenerator/paralellseriel.h b/eisgenerator/paralellseriel.h index 902bbe4874a214e7467271b4a5760c816242232a..a0c1be65b471ca1898867d0f09725857a6721c07 100644 --- a/eisgenerator/paralellseriel.h +++ b/eisgenerator/paralellseriel.h @@ -22,6 +22,7 @@ public: virtual std::string componantName() const override {return "Parallel";} virtual bool compileable() override; virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; }; class Serial: public Componant @@ -40,6 +41,7 @@ public: virtual std::string componantName() const override {return "Serial";} virtual bool compileable() override; virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; }; } diff --git a/eisgenerator/resistor.h b/eisgenerator/resistor.h index 53501b1c2c451b76d41d0f7894525c596481b218..154c01b2de6d822f1d9ca4212e2092869ea8b4df 100644 --- a/eisgenerator/resistor.h +++ b/eisgenerator/resistor.h @@ -16,6 +16,7 @@ public: static constexpr char staticGetComponantChar(){return 'r';} virtual std::string componantName() const override {return "Resistor";} virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; virtual ~Resistor() = default; }; diff --git a/eisgenerator/warburg.h b/eisgenerator/warburg.h index 4ace89f365abae96a59b3e796b59d82c311d6c24..4dfa02da9ca6d8a58cc460e061d2c34f4860cc19 100644 --- a/eisgenerator/warburg.h +++ b/eisgenerator/warburg.h @@ -18,6 +18,7 @@ public: static constexpr char staticGetComponantChar(){return 'w';} virtual std::string componantName() const override {return "Warburg";} virtual std::string getCode(std::vector<std::string>& parameters) override; + virtual std::string getTorchScript(std::vector<std::string>& parameters) override; virtual ~Warburg() = default; }; diff --git a/inductor.cpp b/inductor.cpp index 32428a342e73de317e5a162191ae085bce0eeaec..39bf77bf742661c0cc1cd0f300fb7e88fa9a0ef1 100644 --- a/inductor.cpp +++ b/inductor.cpp @@ -53,3 +53,10 @@ std::string Inductor::getCode(std::vector<std::string>& parameters) std::string out = "std::complex<fvalue>(0, " + N + ")"; return out; } + +std::string Inductor::getTorchScript(std::vector<std::string>& parameters) +{ + parameters.push_back(getUniqueName() + "_0"); + std::string out = parameters.back() + "*omegas*1j"; + return out; +} diff --git a/log.cpp b/log.cpp deleted file mode 100644 index fc50c5889344ebd1f594778e957d98c9d4012ca5..0000000000000000000000000000000000000000 --- a/log.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/** -* Lubricant Detecter -* Copyright (C) 2021 Carl Klemm -* -* This program is free software; you can redistribute it and/or -* modify it under the terms of the GNU General Public License -* version 3 as published by the Free Software Foundation. -* -* This program is distributed in the hope that it will be useful, -* but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -* GNU General Public License for more details. -* -* You should have received a copy of the GNU General Public License -* along with this program; if not, write to the -* Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, -* Boston, MA 02110-1301, USA. -*/ - -#include "log.h" - -using namespace eis; - -Log::Log(Level type, bool endlineI): endline(endlineI) -{ - msglevel = type; - if(headers) - { - operator << ("["+getLabel(type)+"] "); - } -} - -Log::~Log() -{ - if(opened && endline) - { - std::cout<<'\n'; - } - opened = false; -} - - -std::string Log::getLabel(Level level) -{ - std::string label; - switch(level) - { - case DEBUG: - label = "DEBUG"; - break; - case INFO: - label = "INFO "; - break; - case WARN: - label = "WARN "; - break; - case ERROR: - label = "ERROR"; - break; - } - return label; -} - -bool Log::headers = false; -Log::Level Log::level = WARN; diff --git a/main.cpp b/main.cpp index 0a95067c93158395e329ae5f9e1c1fe07749ab4a..7c25bcc520476754a9f43d398932d77c676aa1e9 100644 --- a/main.cpp +++ b/main.cpp @@ -380,6 +380,10 @@ int main(int argc, char** argv) { std::cout<<model.getCode(); } + else if(config.mode == MODE_TORCH_SCRIPT) + { + std::cout<<model.getTorchScript(); + } else { if(model.isParamSweep()) diff --git a/model.cpp b/model.cpp index 52dbb222cd3d1c27d2a8f46ff5841be700d7467c..de58e0955e90e75882b16e099d8bb056a984ee2e 100644 --- a/model.cpp +++ b/model.cpp @@ -2,6 +2,7 @@ #include <model.h> #include <iostream> #include <assert.h> +#include <sstream> #include <string> #include <vector> #include <array> @@ -558,7 +559,7 @@ bool Model::compile() if(!object.objectCode) throw std::runtime_error("Unable to dlopen compiled model " + std::string(dlerror())); - std::string symbolName = "model_" + std::to_string(getUuid()); + std::string symbolName = getCompiledFunctionName(); object.symbol = reinterpret_cast<std::vector<std::complex<fvalue>>(*)(const std::vector<fvalue>&, const std::vector<fvalue>&)> (dlsym(object.objectCode, symbolName.c_str())); @@ -588,8 +589,8 @@ std::string Model::getCode() "#include <complex>\n\n" "typedef float fvalue;\n\n" "extern \"C\"\n{\n\n" - "std::vector<std::complex<fvalue>> model_"; - out.append(std::to_string(getUuid())); + "std::vector<std::complex<fvalue>> "; + out.append(getCompiledFunctionName()); out.append("(const std::vector<fvalue>& parameters, const std::vector<fvalue> omegas)\n{\n\tassert(parameters.size() == "); out.append(std::to_string(parameters.size())); out.append(");\n\n"); @@ -605,3 +606,25 @@ std::string Model::getCode() out.append(";\n\t}\n\treturn out;\n}\n\n}\n"); return out; } + +std::string Model::getTorchScript() +{ + if(!_model || !_model->compileable()) + return ""; + + std::vector<std::string> parameters; + std::string formular = _model->getTorchScript(parameters); + + std::stringstream out; + out<<"def "<<getCompiledFunctionName()<<"(parameters: torch.Tensor, omegas: torch.Tensor) -> torch.Tensor:\n"; + out<<" assert parameters.size(0) is "<<parameters.size()<<"\n\n"; + for(size_t i = 0; i < parameters.size(); ++i) + out<<" "<<parameters[i]<<" = parameters["<<i<<"]\n"; + out<<"\n return "<<formular<<'\n'; + return out.str(); +} + +std::string Model::getCompiledFunctionName() +{ + return "model_"+std::to_string(getUuid()); +} diff --git a/options.h b/options.h index d8e796928ab673f8138cfdf02abb56b909b18e92..f8665472b57261778f0e16696e82d4dfc5fadb45 100644 --- a/options.h +++ b/options.h @@ -26,7 +26,7 @@ static struct argp_option options[] = {"invert", 'i', 0, 0, "inverts the imaginary axis"}, {"noise", 'x', "[AMPLITUDE]", 0, "add noise to output"}, {"input-type", 't', "[STRING]", 0, "set input string type, possible values: eis, boukamp, relaxis, madap"}, - {"mode", 'f', "[STRING]", 0, "mode, possible values: export, code, find-range, export-ranges"}, + {"mode", 'f', "[STRING]", 0, "mode, possible values: export, code, script, find-range, export-ranges"}, {"range-distance", 'd', "[DISTANCE]", 0, "distance from a previous point where a range is considered \"new\""}, {"parallel", 'p', 0, 0, "run on multiple threads"}, {"skip-linear", 'e', 0, 0, "dont output param sweeps that create linear nyquist plots"}, @@ -51,7 +51,8 @@ enum MODE_FIND_RANGE, MODE_OUTPUT_RANGE_DATAPOINTS, MODE_INVALID, - MODE_CODE + MODE_CODE, + MODE_TORCH_SCRIPT }; struct Config @@ -100,6 +101,8 @@ static int parseMode(const std::string& str) return MODE_OUTPUT_RANGE_DATAPOINTS; else if(str == "code") return MODE_CODE; + else if(str == "script") + return MODE_TORCH_SCRIPT; return MODE_INVALID; } diff --git a/paralellseriel.cpp b/paralellseriel.cpp index e661372f14a30abb42337e6bcaeb747e1cd1c48d..d40433b95269a5e8de1e6695625acd9958d84d57 100644 --- a/paralellseriel.cpp +++ b/paralellseriel.cpp @@ -74,6 +74,20 @@ std::string Parallel::getCode(std::vector<std::string>& parameters) return out; } +std::string Parallel::getTorchScript(std::vector<std::string>& parameters) +{ + std::string out = "1/("; + for(Componant* componant : componants) + { + out += "1/(" + componant->getTorchScript(parameters) + ") + "; + } + out.pop_back(); + out.pop_back(); + out.pop_back(); + out.push_back(')'); + return out; +} + Serial::Serial(std::vector<Componant*> componantsIn): componants(componantsIn) { } @@ -147,3 +161,17 @@ std::string Serial::getCode(std::vector<std::string>& parameters) out.push_back(')'); return out; } + +std::string Serial::getTorchScript(std::vector<std::string>& parameters) +{ + std::string out = "("; + for(Componant* componant : componants) + { + out += "(" + componant->getTorchScript(parameters) + ") + "; + } + out.pop_back(); + out.pop_back(); + out.pop_back(); + out.push_back(')'); + return out; +} diff --git a/resistor.cpp b/resistor.cpp index 406702b1f7b37fc2929a2ffcc564626dc088785f..2318555e32bd1ebf5dfe40fe27a2e1a62ab3bb49 100644 --- a/resistor.cpp +++ b/resistor.cpp @@ -53,3 +53,10 @@ std::string Resistor::getCode(std::vector<std::string>& parameters) std::string out = "std::complex<fvalue>(" + parameters.back() + ", 0)"; return out; } + +std::string Resistor::getTorchScript(std::vector<std::string>& parameters) +{ + parameters.push_back(getUniqueName() + "_0"); + + return parameters.back(); +} diff --git a/warburg.cpp b/warburg.cpp index a6c314d12fd4c6573761eb0962a64d989fdc3663..f0fc0f2d8c838aea3c7c73b6ff927e636c26d9ba 100644 --- a/warburg.cpp +++ b/warburg.cpp @@ -54,3 +54,12 @@ std::string Warburg::getCode(std::vector<std::string>& parameters) std::string out = "std::complex<fvalue>(" + N + ", 0-" + N + ")"; return out; } + +std::string Warburg::getTorchScript(std::vector<std::string>& parameters) +{ + parameters.push_back(getUniqueName() + "_0"); + + std::string N = "(" + parameters.back() + "/torch.sqrt(omegas))"; + std::string out = N + "-" + N + "*1j"; + return out; +}