From 77d0fd61f81f7584c2a76e8b6941d542f11be683 Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm <philipp@uvos.xyz>
Date: Fri, 12 May 2023 15:54:03 +0200
Subject: [PATCH] Clean up tensoroptions usage, add another convenience calcDrt
 variant

---
 CMakeLists.txt  |  2 +-
 drt.cpp         | 11 ++++++++++-
 eisdrt/drt.h    |  2 ++
 eistotorch.cpp  |  8 +-------
 tensoroptions.h | 17 +++++++++++++++++
 5 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0a133f0..12f3e86 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -35,7 +35,7 @@ target_include_directories(${PROJECT_NAME}_test PRIVATE . ${TORCH_INCLUDE_DIRS})
 set_target_properties(${PROJECT_NAME} PROPERTIES COMPILE_FLAGS "-Wall -O2 -march=native -g" LINK_FLAGS "-flto")
 install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin)
 
-set(API_HEADERS_DIR drt/)
+set(API_HEADERS_DIR eisdrt/)
 set(API_HEADERS
 	${API_HEADERS_DIR}/drt.h
 )
diff --git a/drt.cpp b/drt.cpp
index f7429ee..c997d96 100644
--- a/drt.cpp
+++ b/drt.cpp
@@ -22,7 +22,7 @@ static torch::Tensor guesStartingPoint(torch::Tensor& omega, torch::Tensor& impe
 static torch::Tensor aImag(torch::Tensor& omega)
 {
 	torch::Tensor tau = 1.0/(omega/(2*M_PI));
-	torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, torch::TensorOptions().dtype(torch::kFloat32));
+	torch::Tensor out = torch::zeros({omega.numel(), omega.numel()}, tensorOptCpu<fvalue>());
 	auto outAccessor = out.accessor<float, 2>();
 	auto omegaAccessor = omega.accessor<float, 1>();
 	auto tauAccessor = tau.accessor<float, 1>();
@@ -168,3 +168,12 @@ torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector
 	torch::Tensor omegaTensor = fvalueVectorToTensor(const_cast<std::vector<fvalue>&>(omegaVector)).clone();
 	return calcDrt(impedanceSpectra, omegaTensor, fm, fp);
 }
+
+torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm,  const FitParameters& fp)
+{
+	torch::Tensor omegaTensor;
+	torch::Tensor impedanceSpectra = eisToComplexTensor(data, &omegaTensor);
+	return calcDrt(impedanceSpectra, omegaTensor, fm, fp);
+}
+
+
diff --git a/eisdrt/drt.h b/eisdrt/drt.h
index e3b9c6a..79f154b 100644
--- a/eisdrt/drt.h
+++ b/eisdrt/drt.h
@@ -20,3 +20,5 @@ struct FitParameters
 torch::Tensor calcDrt(torch::Tensor& impedanceSpectra, torch::Tensor& omegaTensor, FitMetics& fm, const FitParameters& fp);
 
 torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, const std::vector<fvalue>& omegaVector, FitMetics& fm,  const FitParameters& fp);
+
+torch::Tensor calcDrt(const std::vector<eis::DataPoint>& data, FitMetics& fm,  const FitParameters& fp);
diff --git a/eistotorch.cpp b/eistotorch.cpp
index fc2d882..1d5b6da 100644
--- a/eistotorch.cpp
+++ b/eistotorch.cpp
@@ -7,13 +7,7 @@
 
 torch::Tensor eisToComplexTensor(const std::vector<eis::DataPoint>& data, torch::Tensor* freqs)
 {
-	torch::TensorOptions options = tensorOptCpu<fvalue>();
-
-	if constexpr(std::is_same<fvalue, float>::value)
-		options = options.dtype(torch::kComplexFloat);
-	else
-		options = options.dtype(torch::kComplexDouble);
-	torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, options);
+	torch::Tensor output = torch::empty({static_cast<long int>(data.size())}, tensorOptCplxCpu<fvalue>());
 	if(freqs)
 		*freqs = torch::empty({static_cast<long int>(data.size())}, tensorOptCpu<fvalue>());
 
diff --git a/tensoroptions.h b/tensoroptions.h
index 7fc1701..8587994 100644
--- a/tensoroptions.h
+++ b/tensoroptions.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <c10/core/ScalarType.h>
 #include <torch/torch.h>
 
 template <typename V>
@@ -16,3 +17,19 @@ inline torch::TensorOptions tensorOptCpu(bool grad = true)
 	options = options.requires_grad(grad);
 	return options;
 }
+
+template <typename V>
+inline torch::TensorOptions tensorOptCplxCpu(bool grad = true)
+{
+	static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value,
+				  "This function can only be passed double or float types");
+	torch::TensorOptions options;
+	if constexpr(std::is_same<V, float>::value)
+		options = options.dtype(torch::kComplexFloat);
+	else
+		options = options.dtype(torch::kComplexDouble);
+	options = options.layout(torch::kStrided);
+	options = options.device(torch::kCPU);
+	options = options.requires_grad(grad);
+	return options;
+}
-- 
GitLab