From f0f4b2d76ca17bdf58bc9018e6d989433e1ee63c Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm <philipp@uvos.xyz>
Date: Thu, 14 Sep 2023 18:14:13 +0200
Subject: [PATCH] Futher work towards compleate torchscript support

---
 cap.cpp                       |  9 +++++
 componant.cpp                 |  7 ++++
 constantphase.cpp             | 12 +++++++
 eisgenerator/cap.h            |  1 +
 eisgenerator/componant.h      |  1 +
 eisgenerator/constantphase.h  |  1 +
 eisgenerator/inductor.h       |  1 +
 eisgenerator/model.h          |  2 ++
 eisgenerator/paralellseriel.h |  2 ++
 eisgenerator/resistor.h       |  1 +
 eisgenerator/warburg.h        |  1 +
 inductor.cpp                  |  7 ++++
 log.cpp                       | 65 -----------------------------------
 main.cpp                      |  4 +++
 model.cpp                     | 29 ++++++++++++++--
 options.h                     |  7 ++--
 paralellseriel.cpp            | 28 +++++++++++++++
 resistor.cpp                  |  7 ++++
 warburg.cpp                   |  9 +++++
 19 files changed, 124 insertions(+), 70 deletions(-)
 delete mode 100644 log.cpp

diff --git a/cap.cpp b/cap.cpp
index eb80781..8fadbf7 100644
--- a/cap.cpp
+++ b/cap.cpp
@@ -55,3 +55,12 @@ std::string Cap::getCode(std::vector<std::string>& parameters)
 	std::string out = "std::complex<fvalue>(" + real + ", " + imag + ")";
 	return out;
 }
+
+std::string Cap::getTorchScript(std::vector<std::string>& parameters)
+{
+	parameters.push_back(getUniqueName() + "_0");
+
+	std::string N = "1/(" + parameters.back() + "*omegas)";
+	std::string out =  "0-" + N + "*1j";
+	return out;
+}
diff --git a/componant.cpp b/componant.cpp
index 72ec08a..1d099f0 100644
--- a/componant.cpp
+++ b/componant.cpp
@@ -76,6 +76,13 @@ bool Componant::compileable()
 
 std::string Componant::getCode(std::vector<std::string>& parameters)
 {
+	(void)parameters;
+	return std::string();
+}
+
+std::string Componant::getTorchScript(std::vector<std::string>& parameters)
+{
+	(void)parameters;
 	return std::string();
 }
 
diff --git a/constantphase.cpp b/constantphase.cpp
index 4a9f6ec..5814e9b 100644
--- a/constantphase.cpp
+++ b/constantphase.cpp
@@ -82,3 +82,15 @@ std::string Cpe::getCode(std::vector<std::string>& parameters)
 	std::string out = "std::complex<fvalue>(" + real +", " + imag + ")";
 	return out;
 }
+
+std::string Cpe::getTorchScript(std::vector<std::string>& parameters)
+{
+	std::string firstParameter = getUniqueName() + "_0";
+	std::string secondParameter = getUniqueName() + "_1";
+	parameters.push_back(firstParameter);
+	parameters.push_back(secondParameter);
+	std::string real = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.cos((torch.pi/2)*" + secondParameter + ")";
+	std::string imag = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.sin((torch.pi/2)*" + secondParameter + ")";
+	std::string out = real + '-' + imag + "*1j";
+	return out;
+}
diff --git a/eisgenerator/cap.h b/eisgenerator/cap.h
index 40a1080..7807059 100644
--- a/eisgenerator/cap.h
+++ b/eisgenerator/cap.h
@@ -18,6 +18,7 @@ public:
 	static constexpr char staticGetComponantChar(){return 'c';}
 	virtual std::string componantName() const override {return "Capacitor";}
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 	virtual ~Cap() = default;
 };
 
diff --git a/eisgenerator/componant.h b/eisgenerator/componant.h
index 0034baf..5b8432d 100644
--- a/eisgenerator/componant.h
+++ b/eisgenerator/componant.h
@@ -32,6 +32,7 @@ class Componant
 		virtual std::string getComponantString(bool currentValue = true) const;
 		virtual std::string componantName() const = 0;
 		virtual std::string getCode(std::vector<std::string>& parameters);
+		virtual std::string getTorchScript(std::vector<std::string>& parameters);
 		virtual bool compileable();
 
 		std::string getUniqueName();
diff --git a/eisgenerator/constantphase.h b/eisgenerator/constantphase.h
index 8bb0c58..15e3acf 100644
--- a/eisgenerator/constantphase.h
+++ b/eisgenerator/constantphase.h
@@ -22,6 +22,7 @@ public:
 	virtual std::string componantName() const override {return "ConstantPhase";}
 	virtual ~Cpe() = default;
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 };
 
 }
diff --git a/eisgenerator/inductor.h b/eisgenerator/inductor.h
index a410ecc..a6e7129 100644
--- a/eisgenerator/inductor.h
+++ b/eisgenerator/inductor.h
@@ -18,6 +18,7 @@ public:
 	static constexpr char staticGetComponantChar(){return 'l';}
 	virtual std::string componantName() const override {return "Inductor";}
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 	virtual ~Inductor() = default;
 };
 
diff --git a/eisgenerator/model.h b/eisgenerator/model.h
index 68a3e97..4713f38 100644
--- a/eisgenerator/model.h
+++ b/eisgenerator/model.h
@@ -58,6 +58,8 @@ public:
 	size_t getRequiredStepsForSweeps();
 	bool isParamSweep();
 	std::string getCode();
+	std::string getTorchScript();
+	std::string getCompiledFunctionName();
 	std::vector<size_t> getRecommendedParamIndices(eis::Range omegaRange, double distance, bool threaded = false);
 };
 
diff --git a/eisgenerator/paralellseriel.h b/eisgenerator/paralellseriel.h
index 902bbe4..a0c1be6 100644
--- a/eisgenerator/paralellseriel.h
+++ b/eisgenerator/paralellseriel.h
@@ -22,6 +22,7 @@ public:
 	virtual std::string componantName() const override {return "Parallel";}
 	virtual bool compileable() override;
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 };
 
 class Serial: public Componant
@@ -40,6 +41,7 @@ public:
 	virtual std::string componantName() const override {return "Serial";}
 	virtual bool compileable() override;
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 };
 
 }
diff --git a/eisgenerator/resistor.h b/eisgenerator/resistor.h
index 53501b1..154c01b 100644
--- a/eisgenerator/resistor.h
+++ b/eisgenerator/resistor.h
@@ -16,6 +16,7 @@ public:
 	static constexpr char staticGetComponantChar(){return 'r';}
 	virtual std::string componantName() const override {return "Resistor";}
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 	virtual ~Resistor() = default;
 };
 
diff --git a/eisgenerator/warburg.h b/eisgenerator/warburg.h
index 4ace89f..4dfa02d 100644
--- a/eisgenerator/warburg.h
+++ b/eisgenerator/warburg.h
@@ -18,6 +18,7 @@ public:
 	static constexpr char staticGetComponantChar(){return 'w';}
 	virtual std::string componantName() const override {return "Warburg";}
 	virtual std::string getCode(std::vector<std::string>& parameters) override;
+	virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
 	virtual ~Warburg() = default;
 };
 
diff --git a/inductor.cpp b/inductor.cpp
index 32428a3..39bf77b 100644
--- a/inductor.cpp
+++ b/inductor.cpp
@@ -53,3 +53,10 @@ std::string Inductor::getCode(std::vector<std::string>& parameters)
 	std::string out = "std::complex<fvalue>(0, " + N + ")";
 	return out;
 }
+
+std::string Inductor::getTorchScript(std::vector<std::string>& parameters)
+{
+	parameters.push_back(getUniqueName() + "_0");
+	std::string out = parameters.back() + "*omegas*1j";
+	return out;
+}
diff --git a/log.cpp b/log.cpp
deleted file mode 100644
index fc50c58..0000000
--- a/log.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
-* Lubricant Detecter
-* Copyright (C) 2021 Carl Klemm
-*
-* This program is free software; you can redistribute it and/or
-* modify it under the terms of the GNU General Public License
-* version 3 as published by the Free Software Foundation.
-*
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-* GNU General Public License for more details.
-*
-* You should have received a copy of the GNU General Public License
-* along with this program; if not, write to the
-* Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
-* Boston, MA  02110-1301, USA.
-*/
-
-#include "log.h"
-
-using namespace eis;
-
-Log::Log(Level type, bool endlineI): endline(endlineI)
-{
-	msglevel = type;
-	if(headers) 
-	{
-		operator << ("["+getLabel(type)+"] ");
-	}
-}
-
-Log::~Log() 
-{
-	if(opened && endline)
-	{
-		std::cout<<'\n';
-	}
-	opened = false;
-}
-	
-
-std::string Log::getLabel(Level level) 
-{
-	std::string label;
-	switch(level) 
-	{
-		case DEBUG: 
-			label = "DEBUG"; 
-			break;
-		case INFO:  
-			label = "INFO "; 
-			break;
-		case WARN:  
-			label = "WARN "; 
-			break;
-		case ERROR: 
-			label = "ERROR"; 
-			break;
-	}
-	return label;
-}
-	
-bool Log::headers = false;
-Log::Level Log::level = WARN;
diff --git a/main.cpp b/main.cpp
index 0a95067..7c25bcc 100644
--- a/main.cpp
+++ b/main.cpp
@@ -380,6 +380,10 @@ int main(int argc, char** argv)
 		{
 			std::cout<<model.getCode();
 		}
+		else if(config.mode == MODE_TORCH_SCRIPT)
+		{
+			std::cout<<model.getTorchScript();
+		}
 		else
 		{
 			if(model.isParamSweep())
diff --git a/model.cpp b/model.cpp
index 52dbb22..de58e09 100644
--- a/model.cpp
+++ b/model.cpp
@@ -2,6 +2,7 @@
 #include <model.h>
 #include <iostream>
 #include <assert.h>
+#include <sstream>
 #include <string>
 #include <vector>
 #include <array>
@@ -558,7 +559,7 @@ bool Model::compile()
 		if(!object.objectCode)
 			throw std::runtime_error("Unable to dlopen compiled model " + std::string(dlerror()));
 
-		std::string symbolName = "model_" + std::to_string(getUuid());
+		std::string symbolName = getCompiledFunctionName();
 		object.symbol =
 			reinterpret_cast<std::vector<std::complex<fvalue>>(*)(const std::vector<fvalue>&, const std::vector<fvalue>&)>
 				(dlsym(object.objectCode, symbolName.c_str()));
@@ -588,8 +589,8 @@ std::string Model::getCode()
 	"#include <complex>\n\n"
 	"typedef float fvalue;\n\n"
 	"extern \"C\"\n{\n\n"
-	"std::vector<std::complex<fvalue>> model_";
-	out.append(std::to_string(getUuid()));
+	"std::vector<std::complex<fvalue>> ";
+	out.append(getCompiledFunctionName());
 	out.append("(const std::vector<fvalue>& parameters, const std::vector<fvalue> omegas)\n{\n\tassert(parameters.size() == ");
 	out.append(std::to_string(parameters.size()));
 	out.append(");\n\n");
@@ -605,3 +606,25 @@ std::string Model::getCode()
 	out.append(";\n\t}\n\treturn out;\n}\n\n}\n");
 	return out;
 }
+
+std::string Model::getTorchScript()
+{
+	if(!_model || !_model->compileable())
+		return "";
+
+	std::vector<std::string> parameters;
+	std::string formular = _model->getTorchScript(parameters);
+
+	std::stringstream out;
+	out<<"def "<<getCompiledFunctionName()<<"(parameters: torch.Tensor, omegas: torch.Tensor) -> torch.Tensor:\n";
+	out<<"    assert parameters.size(0) is "<<parameters.size()<<"\n\n";
+	for(size_t i = 0; i < parameters.size(); ++i)
+		out<<"    "<<parameters[i]<<" = parameters["<<i<<"]\n";
+	out<<"\n    return "<<formular<<'\n';
+	return out.str();
+}
+
+std::string Model::getCompiledFunctionName()
+{
+	return "model_"+std::to_string(getUuid());
+}
diff --git a/options.h b/options.h
index d8e7969..f866547 100644
--- a/options.h
+++ b/options.h
@@ -26,7 +26,7 @@ static struct argp_option options[] =
   {"invert",        'i', 0,      0,  "inverts the imaginary axis"},
   {"noise",        'x', "[AMPLITUDE]",      0,  "add noise to output"},
   {"input-type",   't', "[STRING]",      0,  "set input string type, possible values: eis, boukamp, relaxis, madap"},
-  {"mode",         'f', "[STRING]",      0,  "mode, possible values: export, code, find-range, export-ranges"},
+  {"mode",         'f', "[STRING]",      0,  "mode, possible values: export, code, script, find-range, export-ranges"},
   {"range-distance",   'd', "[DISTANCE]",      0,  "distance from a previous point where a range is considered \"new\""},
   {"parallel",   'p', 0,      0,  "run on multiple threads"},
   {"skip-linear",   'e', 0,      0,  "dont output param sweeps that create linear nyquist plots"},
@@ -51,7 +51,8 @@ enum
 	MODE_FIND_RANGE,
 	MODE_OUTPUT_RANGE_DATAPOINTS,
 	MODE_INVALID,
-	MODE_CODE
+	MODE_CODE,
+	MODE_TORCH_SCRIPT
 };
 
 struct Config
@@ -100,6 +101,8 @@ static int parseMode(const std::string& str)
 		return MODE_OUTPUT_RANGE_DATAPOINTS;
 	else if(str == "code")
 		return MODE_CODE;
+	else if(str == "script")
+		return MODE_TORCH_SCRIPT;
 	return MODE_INVALID;
 }
 
diff --git a/paralellseriel.cpp b/paralellseriel.cpp
index e661372..d40433b 100644
--- a/paralellseriel.cpp
+++ b/paralellseriel.cpp
@@ -74,6 +74,20 @@ std::string Parallel::getCode(std::vector<std::string>& parameters)
 	return out;
 }
 
+std::string Parallel::getTorchScript(std::vector<std::string>& parameters)
+{
+	std::string out = "1/(";
+	for(Componant* componant : componants)
+	{
+		out += "1/(" + componant->getTorchScript(parameters) + ") + ";
+	}
+	out.pop_back();
+	out.pop_back();
+	out.pop_back();
+	out.push_back(')');
+	return out;
+}
+
 Serial::Serial(std::vector<Componant*> componantsIn): componants(componantsIn)
 {
 }
@@ -147,3 +161,17 @@ std::string Serial::getCode(std::vector<std::string>& parameters)
 	out.push_back(')');
 	return out;
 }
+
+std::string Serial::getTorchScript(std::vector<std::string>& parameters)
+{
+	std::string out = "(";
+	for(Componant* componant : componants)
+	{
+		out += "(" + componant->getTorchScript(parameters) + ") + ";
+	}
+	out.pop_back();
+	out.pop_back();
+	out.pop_back();
+	out.push_back(')');
+	return out;
+}
diff --git a/resistor.cpp b/resistor.cpp
index 406702b..2318555 100644
--- a/resistor.cpp
+++ b/resistor.cpp
@@ -53,3 +53,10 @@ std::string Resistor::getCode(std::vector<std::string>& parameters)
 	std::string out = "std::complex<fvalue>(" + parameters.back() + ", 0)";
 	return out;
 }
+
+std::string Resistor::getTorchScript(std::vector<std::string>& parameters)
+{
+	parameters.push_back(getUniqueName() + "_0");
+
+	return parameters.back();
+}
diff --git a/warburg.cpp b/warburg.cpp
index a6c314d..f0fc0f2 100644
--- a/warburg.cpp
+++ b/warburg.cpp
@@ -54,3 +54,12 @@ std::string Warburg::getCode(std::vector<std::string>& parameters)
 	std::string out = "std::complex<fvalue>(" + N + ", 0-" + N + ")";
 	return out;
 }
+
+std::string Warburg::getTorchScript(std::vector<std::string>& parameters)
+{
+	parameters.push_back(getUniqueName() + "_0");
+
+	std::string N = "(" + parameters.back() + "/torch.sqrt(omegas))";
+	std::string out =  N + "-" + N + "*1j";
+	return out;
+}
-- 
GitLab