Skip to content
Snippets Groups Projects
Commit f0f4b2d7 authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

Futher work towards compleate torchscript support

parent 4ab4d58b
No related branches found
No related tags found
No related merge requests found
......@@ -55,3 +55,12 @@ std::string Cap::getCode(std::vector<std::string>& parameters)
std::string out = "std::complex<fvalue>(" + real + ", " + imag + ")";
return out;
}
std::string Cap::getTorchScript(std::vector<std::string>& parameters)
{
parameters.push_back(getUniqueName() + "_0");
std::string N = "1/(" + parameters.back() + "*omegas)";
std::string out = "0-" + N + "*1j";
return out;
}
......@@ -76,6 +76,13 @@ bool Componant::compileable()
std::string Componant::getCode(std::vector<std::string>& parameters)
{
(void)parameters;
return std::string();
}
std::string Componant::getTorchScript(std::vector<std::string>& parameters)
{
(void)parameters;
return std::string();
}
......
......@@ -82,3 +82,15 @@ std::string Cpe::getCode(std::vector<std::string>& parameters)
std::string out = "std::complex<fvalue>(" + real +", " + imag + ")";
return out;
}
std::string Cpe::getTorchScript(std::vector<std::string>& parameters)
{
std::string firstParameter = getUniqueName() + "_0";
std::string secondParameter = getUniqueName() + "_1";
parameters.push_back(firstParameter);
parameters.push_back(secondParameter);
std::string real = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.cos((torch.pi/2)*" + secondParameter + ")";
std::string imag = "(1/(" + firstParameter + "*torch.pow(omegas,"+ secondParameter +")))*torch.sin((torch.pi/2)*" + secondParameter + ")";
std::string out = real + '-' + imag + "*1j";
return out;
}
......@@ -18,6 +18,7 @@ public:
static constexpr char staticGetComponantChar(){return 'c';}
virtual std::string componantName() const override {return "Capacitor";}
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
virtual ~Cap() = default;
};
......
......@@ -32,6 +32,7 @@ class Componant
virtual std::string getComponantString(bool currentValue = true) const;
virtual std::string componantName() const = 0;
virtual std::string getCode(std::vector<std::string>& parameters);
virtual std::string getTorchScript(std::vector<std::string>& parameters);
virtual bool compileable();
std::string getUniqueName();
......
......@@ -22,6 +22,7 @@ public:
virtual std::string componantName() const override {return "ConstantPhase";}
virtual ~Cpe() = default;
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
};
}
......@@ -18,6 +18,7 @@ public:
static constexpr char staticGetComponantChar(){return 'l';}
virtual std::string componantName() const override {return "Inductor";}
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
virtual ~Inductor() = default;
};
......
......@@ -58,6 +58,8 @@ public:
size_t getRequiredStepsForSweeps();
bool isParamSweep();
std::string getCode();
std::string getTorchScript();
std::string getCompiledFunctionName();
std::vector<size_t> getRecommendedParamIndices(eis::Range omegaRange, double distance, bool threaded = false);
};
......
......@@ -22,6 +22,7 @@ public:
virtual std::string componantName() const override {return "Parallel";}
virtual bool compileable() override;
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
};
class Serial: public Componant
......@@ -40,6 +41,7 @@ public:
virtual std::string componantName() const override {return "Serial";}
virtual bool compileable() override;
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
};
}
......@@ -16,6 +16,7 @@ public:
static constexpr char staticGetComponantChar(){return 'r';}
virtual std::string componantName() const override {return "Resistor";}
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
virtual ~Resistor() = default;
};
......
......@@ -18,6 +18,7 @@ public:
static constexpr char staticGetComponantChar(){return 'w';}
virtual std::string componantName() const override {return "Warburg";}
virtual std::string getCode(std::vector<std::string>& parameters) override;
virtual std::string getTorchScript(std::vector<std::string>& parameters) override;
virtual ~Warburg() = default;
};
......
......@@ -53,3 +53,10 @@ std::string Inductor::getCode(std::vector<std::string>& parameters)
std::string out = "std::complex<fvalue>(0, " + N + ")";
return out;
}
std::string Inductor::getTorchScript(std::vector<std::string>& parameters)
{
parameters.push_back(getUniqueName() + "_0");
std::string out = parameters.back() + "*omegas*1j";
return out;
}
/**
* Lubricant Detecter
* Copyright (C) 2021 Carl Klemm
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* version 3 as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the
* Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "log.h"
using namespace eis;
Log::Log(Level type, bool endlineI): endline(endlineI)
{
msglevel = type;
if(headers)
{
operator << ("["+getLabel(type)+"] ");
}
}
Log::~Log()
{
if(opened && endline)
{
std::cout<<'\n';
}
opened = false;
}
std::string Log::getLabel(Level level)
{
std::string label;
switch(level)
{
case DEBUG:
label = "DEBUG";
break;
case INFO:
label = "INFO ";
break;
case WARN:
label = "WARN ";
break;
case ERROR:
label = "ERROR";
break;
}
return label;
}
bool Log::headers = false;
Log::Level Log::level = WARN;
......@@ -380,6 +380,10 @@ int main(int argc, char** argv)
{
std::cout<<model.getCode();
}
else if(config.mode == MODE_TORCH_SCRIPT)
{
std::cout<<model.getTorchScript();
}
else
{
if(model.isParamSweep())
......
......@@ -2,6 +2,7 @@
#include <model.h>
#include <iostream>
#include <assert.h>
#include <sstream>
#include <string>
#include <vector>
#include <array>
......@@ -558,7 +559,7 @@ bool Model::compile()
if(!object.objectCode)
throw std::runtime_error("Unable to dlopen compiled model " + std::string(dlerror()));
std::string symbolName = "model_" + std::to_string(getUuid());
std::string symbolName = getCompiledFunctionName();
object.symbol =
reinterpret_cast<std::vector<std::complex<fvalue>>(*)(const std::vector<fvalue>&, const std::vector<fvalue>&)>
(dlsym(object.objectCode, symbolName.c_str()));
......@@ -588,8 +589,8 @@ std::string Model::getCode()
"#include <complex>\n\n"
"typedef float fvalue;\n\n"
"extern \"C\"\n{\n\n"
"std::vector<std::complex<fvalue>> model_";
out.append(std::to_string(getUuid()));
"std::vector<std::complex<fvalue>> ";
out.append(getCompiledFunctionName());
out.append("(const std::vector<fvalue>& parameters, const std::vector<fvalue> omegas)\n{\n\tassert(parameters.size() == ");
out.append(std::to_string(parameters.size()));
out.append(");\n\n");
......@@ -605,3 +606,25 @@ std::string Model::getCode()
out.append(";\n\t}\n\treturn out;\n}\n\n}\n");
return out;
}
std::string Model::getTorchScript()
{
if(!_model || !_model->compileable())
return "";
std::vector<std::string> parameters;
std::string formular = _model->getTorchScript(parameters);
std::stringstream out;
out<<"def "<<getCompiledFunctionName()<<"(parameters: torch.Tensor, omegas: torch.Tensor) -> torch.Tensor:\n";
out<<" assert parameters.size(0) is "<<parameters.size()<<"\n\n";
for(size_t i = 0; i < parameters.size(); ++i)
out<<" "<<parameters[i]<<" = parameters["<<i<<"]\n";
out<<"\n return "<<formular<<'\n';
return out.str();
}
std::string Model::getCompiledFunctionName()
{
return "model_"+std::to_string(getUuid());
}
......@@ -26,7 +26,7 @@ static struct argp_option options[] =
{"invert", 'i', 0, 0, "inverts the imaginary axis"},
{"noise", 'x', "[AMPLITUDE]", 0, "add noise to output"},
{"input-type", 't', "[STRING]", 0, "set input string type, possible values: eis, boukamp, relaxis, madap"},
{"mode", 'f', "[STRING]", 0, "mode, possible values: export, code, find-range, export-ranges"},
{"mode", 'f', "[STRING]", 0, "mode, possible values: export, code, script, find-range, export-ranges"},
{"range-distance", 'd', "[DISTANCE]", 0, "distance from a previous point where a range is considered \"new\""},
{"parallel", 'p', 0, 0, "run on multiple threads"},
{"skip-linear", 'e', 0, 0, "dont output param sweeps that create linear nyquist plots"},
......@@ -51,7 +51,8 @@ enum
MODE_FIND_RANGE,
MODE_OUTPUT_RANGE_DATAPOINTS,
MODE_INVALID,
MODE_CODE
MODE_CODE,
MODE_TORCH_SCRIPT
};
struct Config
......@@ -100,6 +101,8 @@ static int parseMode(const std::string& str)
return MODE_OUTPUT_RANGE_DATAPOINTS;
else if(str == "code")
return MODE_CODE;
else if(str == "script")
return MODE_TORCH_SCRIPT;
return MODE_INVALID;
}
......
......@@ -74,6 +74,20 @@ std::string Parallel::getCode(std::vector<std::string>& parameters)
return out;
}
std::string Parallel::getTorchScript(std::vector<std::string>& parameters)
{
std::string out = "1/(";
for(Componant* componant : componants)
{
out += "1/(" + componant->getTorchScript(parameters) + ") + ";
}
out.pop_back();
out.pop_back();
out.pop_back();
out.push_back(')');
return out;
}
Serial::Serial(std::vector<Componant*> componantsIn): componants(componantsIn)
{
}
......@@ -147,3 +161,17 @@ std::string Serial::getCode(std::vector<std::string>& parameters)
out.push_back(')');
return out;
}
std::string Serial::getTorchScript(std::vector<std::string>& parameters)
{
std::string out = "(";
for(Componant* componant : componants)
{
out += "(" + componant->getTorchScript(parameters) + ") + ";
}
out.pop_back();
out.pop_back();
out.pop_back();
out.push_back(')');
return out;
}
......@@ -53,3 +53,10 @@ std::string Resistor::getCode(std::vector<std::string>& parameters)
std::string out = "std::complex<fvalue>(" + parameters.back() + ", 0)";
return out;
}
std::string Resistor::getTorchScript(std::vector<std::string>& parameters)
{
parameters.push_back(getUniqueName() + "_0");
return parameters.back();
}
......@@ -54,3 +54,12 @@ std::string Warburg::getCode(std::vector<std::string>& parameters)
std::string out = "std::complex<fvalue>(" + N + ", 0-" + N + ")";
return out;
}
std::string Warburg::getTorchScript(std::vector<std::string>& parameters)
{
parameters.push_back(getUniqueName() + "_0");
std::string N = "(" + parameters.back() + "/torch.sqrt(omegas))";
std::string out = N + "-" + N + "*1j";
return out;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment