Skip to content
Snippets Groups Projects
Select Git revision
  • aa59701fffeba831428d115a9049b19a7cce9630
  • master default protected
2 results

eistotorch.cpp

Blame
  • eistotorch.cpp 3.37 KiB
    //
    // libeisdrt - A library to calculate EIS Drts
    // Copyright (C) 2023 Carl Klemm <carl@uvos.xyz>
    //
    // This file is part of libeisdrt.
    //
    // libeisdrt is free software: you can redistribute it and/or modify
    // it under the terms of the GNU Lesser General Public License as published by
    // the Free Software Foundation, either version 3 of the License, or
    // (at your option) any later version.
    //
    // libeisdrt is distributed in the hope that it will be useful,
    // but WITHOUT ANY WARRANTY; without even the implied warranty of
    // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    // GNU Lesser General Public License for more details.
    //
    // You should have received a copy of the GNU Lesser General Public License
    // along with libeisdrt.  If not, see <http://www.gnu.org/licenses/>.
    //
    
    #include "eistotorch.h"
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    
    #include "tensoroptions.h"
    
    /**
     * @brief Converts a vector of data points into a one-dimensional complex tensor.
     *
     * Element i of the result holds data[i].im. The real and the imaginary component are
     * sanitized independently of each other: a component that is NaN or infinite is stored as 0.
     *
     * @param data The data points to convert.
     * @param freqs Optional output, may be nullptr. If given, it is overwritten with a newly
     * created tensor of data.size() elements, built with the options returned by
     * tensorOptCpu<fvalue>(), where element i holds data[i].omega.
     * @return A tensor of data.size() elements with dtype torch::kComplexFloat if fvalue is float
     * and torch::kComplexDouble otherwise.
     */
    torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs)
    {
    	const std::int64_t length = static_cast<std::int64_t>(data.size());

    	torch::TensorOptions options = tensorOptCpu<fvalue>();
    	if constexpr(std::is_same<fvalue, float>::value)
    		options = options.dtype(torch::kComplexFloat);
    	else
    		options = options.dtype(torch::kComplexDouble);

    	torch::Tensor output = torch::empty({length}, options);
    	if(freqs)
    		*freqs = torch::empty({length}, tensorOptCpu<fvalue>());

    	// torch::real() and torch::imag() of a complex tensor are views that share its storage,
    	// so writes through these accessors fill "output" in place.
    	torch::Tensor realView = torch::real(output);
    	torch::Tensor imagView = torch::imag(output);
    	auto realAccessor = realView.accessor<fvalue, 1>();
    	auto imagAccessor = imagView.accessor<fvalue, 1>();

    	// The frequency tensor was just created with fvalue elements, so it has to be accessed as
    	// fvalue (not as a hard-coded float). The pointer is taken from *freqs itself rather than
    	// from a temporary, so it is backed by a tensor that outlives the loop.
    	fvalue* freqData = freqs ? freqs->data_ptr<fvalue>() : nullptr;

    	for(size_t i = 0; i < data.size(); ++i)
    	{
    		fvalue realPart = data[i].im.real();
    		fvalue imagPart = data[i].im.imag();

    		// Replace non-finite components by 0, each component on its own.
    		if(!std::isfinite(realPart))
    			realPart = 0;
    		if(!std::isfinite(imagPart))
    			imagPart = 0;

    		realAccessor[i] = realPart;
    		imagAccessor[i] = imagPart;
    		if(freqData)
    			freqData[i] = data[i].omega;
    	}

    	return output;
    }
    
    torch::Tensor eisToTorch(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs)
    {
    	torch::Tensor input = torch::empty({static_cast<long int>(data.size()*2)}, tensorOptCpu<fvalue>());
    	if(freqs)
    		*freqs = torch::empty({static_cast<long int>(data.size()*2)}, tensorOptCpu<fvalue>());