diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a133f0f15a3ac14550240b053036fa16d6dbfc9..12f3e86410133cef5fd5400bf34dcb89f8ed7a6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS}) set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto") install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin) -set(API_HEADERS_DIR drt/) +set(API_HEADERS_DIR eisdrt/) set(API_HEADERS ${API_HEADERS_DIR}/drt.h ) diff --git a/drt.cpp b/drt.cpp index f7429ee1a4173f6e8229ee896559e61249d1368f..c997d96c9b0bc07a84d39fe1e7246c1ab6aad91d 100644 --- a/drt.cpp +++ b/drt.cpp @@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe static torch::Tensor aImag(torch::Tensor& omega) { torch::Tensor tau = 1.0/(omega/(2*M_PI)); - torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32)); + torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>()); auto outAccessor = out.accessor<float, 2>(); auto omegaAccessor = omega.accessor<float, 1>(); auto tauAccessor = tau.accessor<float, 1>(); @@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone(); return calcDrt(impedanceSpectra, omegaTensor, fm, fp); } + +torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp) +{ + torch::Tensor omegaTensor; + torch::Tensor impedanceSpectra = eisToComplexTensor(data, &omegaTensor); + return calcDrt(impedanceSpectra, omegaTensor, fm, fp); +} + + diff --git a/eisdrt/drt.h b/eisdrt/drt.h index e3b9c6a22648eb8c0a7b961ad287d46e7b36467d..79f154b173a4b941bc6cb58782a279f4431b9a90 100644 --- a/eisdrt/drt.h +++ b/eisdrt/drt.h @@ -20,3 +20,5 @@ struct FitParameters torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp); torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm, const FitParameters& fp); + +torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm, const FitParameters& fp); diff --git a/eistotorch.cpp b/eistotorch.cpp index fc2d88238f9d1ed24d0ba6c386215afa9bc79023..1d5b6dad4b411ae1ba2ce08d40df967a2c25412e 100644 --- a/eistotorch.cpp +++ b/eistotorch.cpp @@ -7,13 +7,7 @@ torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs) { - torch::TensorOptions options = tensorOptCpu<fvalue>(); - - if constexpr(std::is_same<fvalue, float>::value) - options = options.dtype(torch::kComplexFloat); - else - options = options.dtype(torch::kComplexDouble); - torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, options); + torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, tensorOptCplxCpu<fvalue>()); if(freqs) *freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>()); diff --git a/tensoroptions.h b/tensoroptions.h index 7fc1701552df81b2c8ca61adefd9c33db9a2a531..858799488ed56f3052a75ff9c608f887ac468275 100644 --- a/tensoroptions.h +++ b/tensoroptions.h @@ -1,4 +1,5 @@ #pragma once +#include <c10/core/ScalarType.h> #include <torch/torch.h> template <typename V> @@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true) options = options.requires_grad(grad); return options; } + +template <typename V> +inline torch::TensorOptions tensorOptCplxCpu(bool grad = true) +{ + static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value, + "This function can only be passed double or float types"); + torch::TensorOptions options; + if constexpr(std::is_same<V, float>::value) + options = options.dtype(torch::kComplexFloat); + else + options = options.dtype(torch::kComplexDouble); + options = options.layout(torch::kStrided); + options = options.device(torch::kCPU); + options = options.requires_grad(grad); + return options; +}