From 25b9c15dbe4010a5dd050ade6753bb50d790131e Mon Sep 17 00:00:00 2001
From: Nassim Bouteldja <nbouteldja@ukaachen.de>
Date: Wed, 13 Apr 2022 21:39:45 +0200
Subject: [PATCH] Upload New File

---
 nnUnet/to_torch.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 nnUnet/to_torch.py

diff --git a/nnUnet/to_torch.py b/nnUnet/to_torch.py
new file mode 100644
index 0000000..ab68035
--- /dev/null
+++ b/nnUnet/to_torch.py
@@ -0,0 +1,31 @@
+#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+#    Licensed under the Apache License, Version 2.0 (the "License");
+#    you may not use this file except in compliance with the License.
+#    You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS,
+#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#    See the License for the specific language governing permissions and
+#    limitations under the License.
+
+import torch
+
+
+def maybe_to_torch(d):
+    if isinstance(d, list):
+        d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
+    elif not isinstance(d, torch.Tensor):
+        d = torch.from_numpy(d).float()
+    return d
+
+
+def to_cuda(data, non_blocking=True, gpu_id=0):
+    if isinstance(data, list):
+        data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
+    else:
+        data = data.cuda(gpu_id, non_blocking=non_blocking)
+    return data
-- 
GitLab