Skip to content
Snippets Groups Projects
Commit 77d0fd61 authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

Clean up tensoroptions usage,

add another conveniance calcDrt variant
parent a13bbd25
Branches
No related tags found
No related merge requests found
...@@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS}) ...@@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS})
set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto") set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto")
install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin) install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin)
set(API_HEADERS_DIR drt/) set(API_HEADERS_DIR eisdrt/)
set(API_HEADERS set(API_HEADERS
${API_HEADERS_DIR}/drt.h ${API_HEADERS_DIR}/drt.h
) )
......
...@@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe ...@@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe
static torch::Tensor aImag(torch::Tensor& omega) static torch::Tensor aImag(torch::Tensor& omega)
{ {
torch::Tensor tau = 1.0/(omega/(2*M_PI)); torch::Tensor tau = 1.0/(omega/(2*M_PI));
torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32)); torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>());
auto outAccessor = out.accessor<float, 2>(); auto outAccessor = out.accessor<float, 2>();
auto omegaAccessor = omega.accessor<float, 1>(); auto omegaAccessor = omega.accessor<float, 1>();
auto tauAccessor = tau.accessor<float, 1>(); auto tauAccessor = tau.accessor<float, 1>();
...@@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector ...@@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector
torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone(); torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone();
return calcDrt(impedanceSpectra, omegaTensor, fm, fp); return calcDrt(impedanceSpectra, omegaTensor, fm, fp);
} }
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp)
{
torch::Tensor omegaTensor;
torch::Tensor impedanceSpectra = eisToComplexTensor(data, &omegaTensor);
return calcDrt(impedanceSpectra, omegaTensor, fm, fp);
}
...@@ -20,3 +20,5 @@ struct FitParameters ...@@ -20,3 +20,5 @@ struct FitParameters
torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp); torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp);
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm, const FitParameters& fp); torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm, const FitParameters& fp);
torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp);
...@@ -7,13 +7,7 @@ ...@@ -7,13 +7,7 @@
torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs) torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs)
{ {
torch::TensorOptions options = tensorOptCpu<fvalue>(); torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, tensorOptCplxCpu<fvalue>());
if constexpr(std::is_same<fvalue, float>::value)
options = options.dtype(torch::kComplexFloat);
else
options = options.dtype(torch::kComplexDouble);
torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, options);
if(freqs) if(freqs)
*freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>()); *freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>());
......
#pragma once #pragma once
#include <c10/core/ScalarType.h>
#include <torch/torch.h> #include <torch/torch.h>
template <typename V> template <typename V>
...@@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true) ...@@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true)
options = options.requires_grad(grad); options = options.requires_grad(grad);
return options; return options;
} }
template <typename V>
inline torch::TensorOptions tensorOptCplxCpu(bool grad = true)
{
static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value,
"This function can only be passed double or float types");
torch::TensorOptions options;
if constexpr(std::is_same<V, float>::value)
options = options.dtype(torch::kComplexFloat);
else
options = options.dtype(torch::kComplexDouble);
options = options.layout(torch::kStrided);
options = options.device(torch::kCPU);
options = options.requires_grad(grad);
return options;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment