From 8bb2ebd70373d237ad8622fc8b9f9f4b3e658aab Mon Sep 17 00:00:00 2001
From: Carl Philipp Klemm <philipp@uvos.xyz>
Date: Mon, 15 May 2023 16:41:04 +0200
Subject: [PATCH] add eigen to torch conversions

---
 eigentorchconversions.h | 97 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 97 insertions(+)
 create mode 100644 eigentorchconversions.h

diff --git a/eigentorchconversions.h b/eigentorchconversions.h
new file mode 100644
index 0000000..79b78fe
--- /dev/null
+++ b/eigentorchconversions.h
@@ -0,0 +1,97 @@
+#pragma once
+
+#include <climits>
+#include <sys/types.h>
+#include <torch/torch.h>
+#include <Eigen/Dense>
+#include <torch/types.h>
+#include <vector>
+#include <cassert>
+
+#include "tensoroptions.h"
+
+// Returns true if the tensor's dtype matches the C++ type V.
+template <typename V>
+bool checkTorchType(const torch::Tensor& tensor)
+{
+	static_assert(std::is_same<V, float>::value || std::is_same<V, double>::value ||
+		std::is_same<V, int64_t>::value || std::is_same<V, int32_t>::value || std::is_same<V, int8_t>::value,
+		"This function does not work with this type");
+	if constexpr(std::is_same<V, float>::value)
+		return tensor.dtype() == torch::kFloat32;
+	else if constexpr(std::is_same<V, double>::value)
+		return tensor.dtype() == torch::kFloat64;
+	else if constexpr(std::is_same<V, int64_t>::value)
+		return tensor.dtype() == torch::kInt64;
+	else if constexpr(std::is_same<V, int32_t>::value)
+		return tensor.dtype() == torch::kInt32;
+	else if constexpr(std::is_same<V, int8_t>::value)
+		return tensor.dtype() == torch::kInt8;
+}
+
+// Row-major Eigen matrix, matching LibTorch's default memory layout.
+template <typename V>
+using MatrixXrm = typename Eigen::Matrix<V, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+// Converts a column-major Eigen matrix to a torch::Tensor; always copies the data.
+template <typename V>
+torch::Tensor eigen2libtorch(Eigen::MatrixX<V> &M)
+{
+	Eigen::Matrix<V, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> E(M);
+	std::vector<int64_t> dims = {E.rows(), E.cols()};
+	auto T = torch::from_blob(E.data(), dims, tensorOptCpu<V>(false)).clone();
+	return T;
+}
+
+// Converts a row-major Eigen matrix to a torch::Tensor; if copydata is false
+// the tensor aliases E's memory and E must outlive it.
+template <typename V>
+torch::Tensor eigen2libtorch(MatrixXrm<V> &E, bool copydata = true)
+{
+	std::vector<int64_t> dims = {E.rows(), E.cols()};
+	auto T = torch::from_blob(E.data(), dims, tensorOptCpu<V>(false));
+	if (copydata)
+		return T.clone();
+	else
+		return T;
+}
+
+// Converts an Eigen vector to a 1d torch::Tensor; if copydata is false
+// the tensor aliases E's memory and E must outlive it.
+template <typename V>
+torch::Tensor eigenVector2libtorch(Eigen::Vector<V, Eigen::Dynamic> &E, bool copydata = true)
+{
+	std::vector<int64_t> dims = {E.rows()};
+	auto T = torch::from_blob(E.data(), dims, tensorOptCpu<V>(false));
+	if (copydata)
+		return T.clone();
+	else
+		return T;
+}
+
+// Converts a 2d torch::Tensor to a column-major Eigen matrix (copies the data).
+template<typename V>
+Eigen::Matrix<V, Eigen::Dynamic, Eigen::Dynamic> libtorch2eigenMatrix(torch::Tensor &Tin)
+{
+	/*
+	 LibTorch tensors are row-major while Eigen defaults to column-major;
+	 MatrixXrm uses Eigen::RowMajor so the Map matches the tensor layout.
+	 */
+	assert(checkTorchType<V>(Tin));
+	Tin = Tin.contiguous();
+	auto T = Tin.to(torch::kCPU);
+	Eigen::Map<MatrixXrm<V>> E(T.data_ptr<V>(), T.size(0), T.size(1));
+	return E;
+}
+
+// Converts a 1d torch::Tensor to an Eigen vector (copies the data).
+template<typename V>
+Eigen::Vector<V, Eigen::Dynamic> libtorch2eigenVector(torch::Tensor &Tin)
+{
+	assert(Tin.sizes().size() == 1);
+	assert(checkTorchType<V>(Tin));
+	Tin = Tin.contiguous();
+	auto T = Tin.to(torch::kCPU);
+	Eigen::Map<Eigen::Vector<V, Eigen::Dynamic>> E(T.data_ptr<V>(), T.numel());
+	return E;
+}
--
GitLab
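
Below is a minimal usage sketch of the conversion helpers added by this patch. It is not part of the patch itself: it assumes eigentorchconversions.h and the tensoroptions.h header it depends on (for tensorOptCpu) are available on the include path, and that the program is built against LibTorch and Eigen; the matrix and vector values are arbitrary illustrations.

#include <iostream>
#include <torch/torch.h>
#include <Eigen/Dense>

#include "eigentorchconversions.h"

int main()
{
	// Round-trip a small matrix: Eigen -> LibTorch -> Eigen.
	Eigen::MatrixXf M(2, 3);
	M << 1, 2, 3,
	     4, 5, 6;

	torch::Tensor T = eigen2libtorch(M); // copies M into a 2x3 float tensor
	T = T * 2;                           // any LibTorch op can be applied here

	Eigen::MatrixXf back = libtorch2eigenMatrix<float>(T);
	std::cout << back << '\n';           // prints the doubled matrix

	// The 1d helpers work the same way for vectors; copydata defaults to true,
	// so the returned tensor owns its own copy of the data.
	Eigen::VectorXf v(3);
	v << 1, 2, 3;
	torch::Tensor tv = eigenVector2libtorch(v);
	Eigen::VectorXf vback = libtorch2eigenVector<float>(tv);
	std::cout << vback.transpose() << '\n';

	return 0;
}

Note that when copydata is false in the row-major overloads, the returned tensor merely aliases the Eigen object's storage, so the Eigen object must outlive the tensor.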