From 77d0fd61f81f7584c2a76e8b6941d542f11be683 Mon Sep 17 00:00:00 2001 From: Carl Philipp Klemm <philipp@uvos.xyz> Date: Fri, 12 May 2023 15:54:03 +0200 Subject: [PATCH] Clean up tensoroptions usage, add another conveniance calcDrt variant --- CMakeLists.txt | 2 +- drt.cpp | 11 ++++++++++- eisdrt/drt.h | 2 ++ eistotorch.cpp | 8 +------- tensoroptions.h | 17 +++++++++++++++++ 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a133f0..12f3e86 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS}) set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto") install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin) -set(API_HEADERS_DIR drt/) +set(API_HEADERS_DIR eisdrt/) set(API_HEADERS ${API_HEADERS_DIR}/drt.h ) diff --git a/drt.cpp b/drt.cpp index f7429ee..c997d96 100644 --- a/drt.cpp +++ b/drt.cpp @@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe static torch::Tensor aImag(torch::Tensor& omega) { torch::Tensor tau = 1.0/(omega/(2*M_PI)); - torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>()); auto outAccessor = out.accessor<float, 2>(); auto omegaAccessor = omega.accessor<float, 1>(); auto tauAccessor = tau.accessor<float, 1>(); @@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone(); return calcDrt(impedanceSpectra, omegaTensor, fm, fp); } + +torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp) +{ + torch::Tensor omegaTensor; + torch::Tensor impedanceSpectra = eisToComplexTensor(data, &omegaTensor); + return calcDrt(impedanceSpectra, omegaTensor, fm, fp); +} + + diff --git a/eisdrt/drt.h b/eisdrt/drt.h index e3b9c6a..79f154b 100644 --- a/eisdrt/drt.h +++ b/eisdrt/drt.h @@ -20,3 +20,5 @@ struct FitParameters torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp); torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm, const FitParameters& fp); + +torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp); diff --git a/eistotorch.cpp b/eistotorch.cpp index fc2d882..1d5b6da 100644 --- a/eistotorch.cpp +++ b/eistotorch.cpp @@ -7,13 +7,7 @@ torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs) { - torch::TensorOptions options = tensorOptCpu<fvalue>(); - - if constexpr(std::is_same<fvalue, float>::value) - options = options.dtype(torch::kComplexFloat); - else - options = options.dtype(torch::kComplexDouble); - torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, options); + torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, tensorOptCplxCpu<fvalue>()); if(freqs) *freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>()); diff --git a/tensoroptions.h b/tensoroptions.h index 7fc1701..8587994 100644 --- a/tensoroptions.h +++ b/tensoroptions.h @@ -1,4 +1,5 @@ #pragma once +#include <c10/core/ScalarType.h> #include <torch/torch.h> template <typename V> @@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true) options = options.requires_grad(grad); return options; } + +template <typename V> +inline torch::TensorOptions tensorOptCplxCpu(bool grad = true) +{ + static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value, + "This function can only be passed double or float types"); + torch::TensorOptions options; + if constexpr(std::is_same<V, float>::value) + options = options.dtype(torch::kComplexFloat); + else + options = options.dtype(torch::kComplexDouble); + options = options.layout(torch::kStrided); + options = options.device(torch::kCPU); + options = options.requires_grad(grad); + return options; +} -- GitLab