Skip to content
Snippets Groups Projects
Select Git revision
  • 77d0fd61f81f7584c2a76e8b6941d542f11be683
  • master default protected
2 results

drt.cpp

Blame
  • drt.cpp 6.45 KiB
    #include "eisdrt/drt.h"
    
    #include <ATen/ops/ones.h>
    #include <ATen/ops/zeros.h>
    #include <Eigen/Core>
    #include <eisgenerator/eistype.h>
    
    #include "tensoroptions.h"
    #include "eigentorchconversions.h"
    #include "eistotorch.h"
    #include "LBFG/LBFGSB.h"
    
    /*
     * Builds the initial parameter vector handed to the optimizer.
     *
     * The result is a zero-filled tensor with the same shape as omega, except
     * that dimension 0 is one element longer. The last entry along dimension 0
     * is set to the absolute value of the last entry of impedanceSpectra.
     *
     * omega:            only its shape is read; must have at least one dimension.
     * impedanceSpectra: only its last entry along dimension 0 is read.
     *
     * NOTE(review): the extra trailing slot presumably holds a series-resistance
     * term next to the per-frequency values - confirm against the caller.
     */
    static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impedanceSpectra)
    {
    	// Same shape as omega, but with one additional element along dimension 0.
    	std::vector<int64_t> shape = omega.sizes().vec();
    	shape[0] += 1;

    	// presumably the `false` argument switches gradient tracking off - TODO confirm in tensoroptions.h
    	torch::Tensor guess = torch::zeros(shape, tensorOptCpu<fvalue>(false));

    	// Seed the trailing slot with the magnitude of the last spectrum entry.
    	const torch::Tensor lastMagnitude = torch::abs(impedanceSpectra[-1]);
    	guess[-1] = lastMagnitude;

    	return guess;
    }
    
    /*
     * Builds the square discretization matrix A with
     *
     *   A[i][j] = 0.5 * (omega_i*tau_j) / (1 + (omega_i*tau_j)^2) * log(tau_hi/tau_lo)
     *
     * where tau = 2*pi/omega and (lo, hi) are the neighbours of j on the tau
     * grid: (j-1, j+1) for interior points, (0, 1) for the first point and
     * (n-2, n-1) for the last one.
     * Going by the name, this is presumably the imaginary-part kernel of the
     * DRT model - confirm against the caller.
     *
     * omega: 1-D float32 tensor (accessed through accessor<float, 1>).
     *
     * Returns an n x n tensor with n = omega.numel(). For n == 0 the result is
     * empty. For n == 1 the tau grid has no neighbour to difference against,
     * so the log-spacing is log(1) == 0 and the single entry is 0 (the
     * original code read one element past the end of tau in this case).
     */
    static torch::Tensor aImag(torch::Tensor& omega)
    {
    	torch::Tensor tau = 1.0/(omega/(2*M_PI));
    	// NOTE(review): the options come from tensorOptCpu<fvalue>() while the accessor below
    	// is hard-wired to float - this only matches if fvalue is float; confirm.
    	torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>());
    	auto outAccessor = out.accessor<float, 2>();
    	auto omegaAccessor = omega.accessor<float, 1>();
    	auto tauAccessor = tau.accessor<float, 1>();

    	const int64_t rows = out.size(0);
    	const int64_t cols = out.size(1);

    	// The log-spacing factor depends on the column only, so compute it once per column
    	// instead of once per matrix element. Clamping the neighbour indices to the grid
    	// reproduces the first/last-column special cases and keeps a single-element grid in bounds.
    	std::vector<float> logSpacing(static_cast<size_t>(cols));
    	for(int64_t j = 0; j < cols; ++j)
    	{
    		const int64_t lower = j > 0 ? j-1 : j;
    		const int64_t upper = j+1 < cols ? j+1 : j;
    		logSpacing[j] = std::log(tauAccessor[upper]/tauAccessor[lower]);
    	}

    	for(int64_t i = 0; i < rows; ++i)
    	{
    		for(int64_t j = 0; j < cols; ++j)
    		{
    			// The product is formed in float, the kernel itself is evaluated in double
    			// and rounded to float before the spacing factor is applied.
    			const double omegaTau = omegaAccessor[i]*tauAccessor[j];
    			const float kernel = static_cast<float>(0.5*omegaTau/(1+omegaTau*omegaTau));
    			outAccessor[i][j] = kernel*logSpacing[j];
    		}
    	}
    	return out;
    }
    
    /*
     * Builds the square discretization matrix A with
     *
     *   A[i][j] = -0.5 / (1 + (omega_i*tau_j)^2) * log(tau_hi/tau_lo)
     *
     * where tau = 2*pi/omega and (lo, hi) are the neighbours of j on the tau
     * grid: (j-1, j+1) for interior points, (0, 1) for the first point and
     * (n-2, n-1) for the last one.
     * Going by the name, this is presumably the real-part kernel of the DRT
     * model - confirm against the caller.
     *
     * omega: 1-D float32 tensor (accessed through accessor<float, 1>).
     *
     * Returns an n x n float32 tensor with n = omega.numel(). For n == 0 the
     * result is empty. For n == 1 the tau grid has no neighbour to difference
     * against, so the log-spacing is log(1) == 0 and the single entry is zero
     * (the original code read one element past the end of tau in this case).
     */
    static torch::Tensor aReal(torch::Tensor& omega)
    {
    	torch::Tensor tau = 1.0/(omega/(2*M_PI));
    	torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32));
    	auto outAccessor = out.accessor<float, 2>();
    	auto omegaAccessor = omega.accessor<float, 1>();
    	auto tauAccessor = tau.accessor<float, 1>();

    	const int64_t rows = out.size(0);
    	const int64_t cols = out.size(1);

    	// The log-spacing factor depends on the column only, so compute it once per column
    	// instead of once per matrix element. Clamping the neighbour indices to the grid
    	// reproduces the first/last-column special cases and keeps a single-element grid in bounds.
    	std::vector<float> logSpacing(static_cast<size_t>(cols));
    	for(int64_t j = 0; j < cols; ++j)
    	{
    		const int64_t lower = j > 0 ? j-1 : j;
    		const int64_t upper = j+1 < cols ? j+1 : j;
    		logSpacing[j] = std::log(tauAccessor[upper]/tauAccessor[lower]);
    	}

    	for(int64_t i = 0; i < rows; ++i)
    	{
    		for(int64_t j = 0; j < cols; ++j)
    		{
    			// The product is formed in float, the kernel itself is evaluated in double
    			// and rounded to float before the spacing factor is applied.
    			const double omegaTau = omegaAccessor[i]*tauAccessor[j];
    			const float kernel = static_cast<float>(-0.5/(1+omegaTau*omegaTau));
    			outAccessor[i][j] = kernel*logSpacing[j];
    		}
    	}
    	return out;
    }
    
    class RtFunct
    {
    private: