Skip to content
Snippets Groups Projects
Commit 77d0fd61 authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

Clean up tensoroptions usage,

add another convenience calcDrt variant
parent a13bbd25
Branches
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS})
set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto")
install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin)
set(API_HEADERS_DIR drt/)
set(API_HEADERS_DIR eisdrt/)
set(API_HEADERS
${API_HEADERS_DIR}/drt.h
)
......
......@@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe
static torch::Tensor aImag(torch::Tensor& omega)
{
torch::Tensor tau = 1.0/(omega/(2*M_PI));
torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>());
auto outAccessor = out.accessor<float, 2>();
auto omegaAccessor = omega.accessor<float, 1>();
auto tauAccessor = tau.accessor<float, 1>();
......@@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector
torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone();
return calcDrt(impedanceSpectra, omegaTensor, fm, fp);
}
// Convenience overload: compute the DRT straight from a vector of EIS data
// points. The angular-frequency tensor is produced as a side output of
// eisToComplexTensor, after which the work is delegated to the tensor-based
// calcDrt overload.
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp)
{
	torch::Tensor omegas;
	torch::Tensor spectra = eisToComplexTensor(data, &omegas);
	return calcDrt(spectra, omegas, fm, fp);
}
......@@ -20,3 +20,5 @@ struct FitParameters
// Core entry point: computes the DRT from an impedance spectrum given as a
// complex tensor together with its matching angular-frequency tensor.
// fm receives fit metrics; fp supplies the fit parameters.
torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp);
// Convenience overload: converts the given data points and fvalue frequency
// vector into tensors and delegates to the tensor-based overload above.
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm, const FitParameters& fp);
// Convenience overload: like the above, but the angular frequencies are
// extracted from the data points themselves (via eisToComplexTensor).
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp);
......@@ -7,13 +7,7 @@
torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs)
{
torch::TensorOptions options = tensorOptCpu<fvalue>();
if constexpr(std::is_same<fvalue, float>::value)
options = options.dtype(torch::kComplexFloat);
else
options = options.dtype(torch::kComplexDouble);
torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, options);
torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, tensorOptCplxCpu<fvalue>());
if(freqs)
*freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>());
......
#pragma once
#include <c10/core/ScalarType.h>
#include <torch/torch.h>
template <typename V>
......@@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true)
options = options.requires_grad(grad);
return options;
}
// Builds TensorOptions for a CPU-resident, strided complex tensor whose
// precision is derived from the real scalar type V: float maps to
// kComplexFloat, double maps to kComplexDouble.
// grad: whether tensors created from these options take part in autograd.
template <typename V>
inline torch::TensorOptions tensorOptCplxCpu(bool grad = true)
{
	static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value,
	              "This function can only be passed double or float types");
	// Pick the complex dtype that matches V at compile time.
	constexpr auto complexType =
		std::is_same<V, float>::value ? torch::kComplexFloat : torch::kComplexDouble;
	return torch::TensorOptions()
		.dtype(complexType)
		.layout(torch::kStrided)
		.device(torch::kCPU)
		.requires_grad(grad);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment