From f17d287fa8d914c9f18ad8034dc886cd0fc41421 Mon Sep 17 00:00:00 2001
From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de>
Date: Mon, 2 Dec 2024 09:39:19 +0100
Subject: [PATCH] added file locking to ensure proper download in distributed
 setting

---
 pytorch/cifar10_distributed/train_model.py | 4 +++-
 pytorch/mnist_distributed/train_model.py   | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch/cifar10_distributed/train_model.py b/pytorch/cifar10_distributed/train_model.py
index 873ca07..c50dcd4 100644
--- a/pytorch/cifar10_distributed/train_model.py
+++ b/pytorch/cifar10_distributed/train_model.py
@@ -10,6 +10,7 @@ from torchvision.models import resnet50
 from torchvision.datasets import CIFAR10
 import torchvision.transforms as transforms
 from torch.utils.data import DistributedSampler, DataLoader
+from filelock import FileLock
 
 def parse_command_line():
     parser = argparse.ArgumentParser()    
@@ -44,7 +45,8 @@ def load_dataset(args):
     
     # load CIFAR10 dataset splits for train and test and apply the transformations
     # note: you need to download the dataset once at the beginning
-    ds_train = CIFAR10("datasets", train=True,  download=True,  transform=trans)
+    with FileLock(os.path.expanduser("~/.dataset_lock")):
+        ds_train = CIFAR10("datasets", train=True,  download=True,  transform=trans)
     ds_test  = CIFAR10("datasets", train=False, download=False, transform=trans)
     
     # define distributed samplers (only for distributed version)
diff --git a/pytorch/mnist_distributed/train_model.py b/pytorch/mnist_distributed/train_model.py
index ceeb81f..5e1e34f 100644
--- a/pytorch/mnist_distributed/train_model.py
+++ b/pytorch/mnist_distributed/train_model.py
@@ -12,6 +12,7 @@ from torchvision.models import resnet50
 from torchvision.datasets import MNIST
 import torchvision.transforms as transforms
 from torch.utils.data import DistributedSampler, DataLoader
+from filelock import FileLock
 
 class Net(nn.Module):
     def __init__(self):
@@ -65,7 +66,8 @@ def load_dataset(args):
     
     # load MNIST dataset splits for train and test and apply the transformations
     # note: you need to download the dataset once at the beginning
-    ds_train = MNIST("datasets", train=True,  download=True,  transform=trans)
+    with FileLock(os.path.expanduser("~/.dataset_lock")):
+        ds_train = MNIST("datasets", train=True,  download=True,  transform=trans)
     ds_test  = MNIST("datasets", train=False, download=False, transform=trans)
     
     # define distributed samplers (only for distributed version)
-- 
GitLab