Select Git revision
drt.cpp 6.45 KiB
#include "eisdrt/drt.h"

#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <Eigen/Core>
#include <eisgenerator/eistype.h>

#include "tensoroptions.h"
#include "eigentorchconversions.h"
#include "eistotorch.h"
#include "LBFG/LBFGSB.h"
/**
 * Build the initial parameter vector for the optimiser.
 *
 * The result has omega's shape, except that the first dimension is one
 * element longer. Every entry is zero except the final one. The final entry
 * is set to the magnitude of the last element of impedanceSpectra.
 * NOTE(review): the extra trailing slot presumably holds an offset/series
 * term next to the per-tau values — confirm against RtFunct.
 *
 * @param omega            frequency tensor; must have at least one dimension.
 * @param impedanceSpectra measured spectrum; its last element seeds the
 *                         trailing slot.
 * @return zero-filled tensor (built with tensorOptCpu<fvalue>(false)) with
 *         its last entry populated.
 */
static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impedanceSpectra)
{
	// Copy omega's shape and reserve one additional element along dim 0.
	std::vector<int64_t> shape(omega.sizes().begin(), omega.sizes().end());
	shape.front() += 1;

	torch::Tensor guess = torch::zeros(shape, tensorOptCpu<fvalue>(false));
	guess[-1] = torch::abs(impedanceSpectra[-1]);
	return guess;
}
/**
 * Build the imaginary-part kernel matrix for the discretised DRT problem.
 *
 * One relaxation time is used per frequency: tau_j = 1 / (omega_j / (2*pi)).
 * The matrix is therefore square, n x n with n = omega.numel(). Entry (i, j) is
 *
 *     0.5 * (omega_i*tau_j) / (1 + (omega_i*tau_j)^2) * w_j
 *
 * w_j is a log-spacing weight:
 *   - ln(tau_1 / tau_0) for the first column;
 *   - ln(tau_{n-1} / tau_{n-2}) for the last column;
 *   - ln(tau_{j+1} / tau_{j-1}) for every other column.
 *
 * @param omega 1-D float32 tensor of angular frequencies. It is read through
 *              accessor<float, 1>, so any other dtype or rank is rejected
 *              by torch.
 * @return n x n tensor. It is written through accessor<float, 2>; this
 *         assumes tensorOptCpu<fvalue>() yields float32 — TODO confirm.
 *         An empty omega yields an empty 0 x 0 tensor.
 * @throws std::invalid_argument if omega holds exactly one element. The
 *         log-spacing weight needs at least two relaxation times; the
 *         previous version read past the end of tau in that case.
 */
static torch::Tensor aImag(torch::Tensor& omega)
{
	const int64_t n = omega.numel();
	if(n == 1)
		throw std::invalid_argument("aImag: omega must contain at least two frequencies");

	torch::Tensor tau = 1.0/(omega/(2*M_PI));
	torch::Tensor out = torch::zeros({n, n}, tensorOptCpu<fvalue>());
	auto outAccessor = out.accessor<float, 2>();
	auto omegaAccessor = omega.accessor<float, 1>();
	auto tauAccessor = tau.accessor<float, 1>();

	// The quadrature weight depends only on the column, so evaluate the
	// logarithm once per tau instead of once per matrix entry.
	std::vector<float> logWeight(static_cast<size_t>(n));
	for(int64_t j = 0; j < n; ++j)
	{
		const int64_t lower = (j == 0) ? j : j - 1;
		const int64_t upper = (j == n - 1) ? j : j + 1;
		logWeight[static_cast<size_t>(j)] = std::log(tauAccessor[upper]/tauAccessor[lower]);
	}

	for(int64_t i = 0; i < n; ++i)
	{
		for(int64_t j = 0; j < n; ++j)
		{
			// The product is formed in float, as before, then widened so the
			// kernel itself is evaluated in double before being stored.
			const double x = omegaAccessor[i]*tauAccessor[j];
			outAccessor[i][j] = static_cast<float>(0.5*x/(1.0 + x*x))*logWeight[static_cast<size_t>(j)];
		}
	}
	return out;
}
/**
 * Build the real-part kernel matrix for the discretised DRT problem.
 *
 * One relaxation time is used per frequency: tau_j = 1 / (omega_j / (2*pi)).
 * The matrix is therefore square, n x n with n = omega.numel(). Entry (i, j) is
 *
 *     -0.5 / (1 + (omega_i*tau_j)^2) * w_j
 *
 * w_j is a log-spacing weight:
 *   - ln(tau_1 / tau_0) for the first column;
 *   - ln(tau_{n-1} / tau_{n-2}) for the last column;
 *   - ln(tau_{j+1} / tau_{j-1}) for every other column.
 *
 * @param omega 1-D float32 tensor of angular frequencies. It is read through
 *              accessor<float, 1>, so any other dtype or rank is rejected
 *              by torch.
 * @return n x n float32 tensor. An empty omega yields an empty 0 x 0 tensor.
 * @throws std::invalid_argument if omega holds exactly one element. The
 *         log-spacing weight needs at least two relaxation times; the
 *         previous version read past the end of tau in that case.
 */
static torch::Tensor aReal(torch::Tensor& omega)
{
	const int64_t n = omega.numel();
	if(n == 1)
		throw std::invalid_argument("aReal: omega must contain at least two frequencies");

	torch::Tensor tau = 1.0/(omega/(2*M_PI));
	torch::Tensor out = torch::zeros({n, n}, torch::TensorOptions().dtype(torch::kFloat32));
	auto outAccessor = out.accessor<float, 2>();
	auto omegaAccessor = omega.accessor<float, 1>();
	auto tauAccessor = tau.accessor<float, 1>();

	// The quadrature weight depends only on the column, so evaluate the
	// logarithm once per tau instead of once per matrix entry.
	std::vector<float> logWeight(static_cast<size_t>(n));
	for(int64_t j = 0; j < n; ++j)
	{
		const int64_t lower = (j == 0) ? j : j - 1;
		const int64_t upper = (j == n - 1) ? j : j + 1;
		logWeight[static_cast<size_t>(j)] = std::log(tauAccessor[upper]/tauAccessor[lower]);
	}

	for(int64_t i = 0; i < n; ++i)
	{
		for(int64_t j = 0; j < n; ++j)
		{
			// The product is formed in float, as before, then widened so the
			// kernel itself is evaluated in double before being stored.
			const double x = omegaAccessor[i]*tauAccessor[j];
			outAccessor[i][j] = static_cast<float>(-0.5/(1.0 + x*x))*logWeight[static_cast<size_t>(j)];
		}
	}
	return out;
}
class RtFunct
{
private: