From f17d287fa8d914c9f18ad8034dc886cd0fc41421 Mon Sep 17 00:00:00 2001 From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de> Date: Mon, 2 Dec 2024 09:39:19 +0100 Subject: [PATCH] added file locking to ensure proper download in distributed setting --- pytorch/cifar10_distributed/train_model.py | 4 +++- pytorch/mnist_distributed/train_model.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch/cifar10_distributed/train_model.py b/pytorch/cifar10_distributed/train_model.py index 873ca07..c50dcd4 100644 --- a/pytorch/cifar10_distributed/train_model.py +++ b/pytorch/cifar10_distributed/train_model.py @@ -10,6 +10,7 @@ from torchvision.models import resnet50 from torchvision.datasets import CIFAR10 import torchvision.transforms as transforms from torch.utils.data import DistributedSampler, DataLoader +from filelock import FileLock def parse_command_line(): parser = argparse.ArgumentParser() @@ -44,7 +45,8 @@ def load_dataset(args): # load CIFAR10 dataset splits for train and test and apply the transformations # note: you need to download the dataset once at the beginning - ds_train = CIFAR10("datasets", train=True, download=True, transform=trans) + with FileLock(os.path.expanduser("~/.dataset_lock")): + ds_train = CIFAR10("datasets", train=True, download=True, transform=trans) ds_test = CIFAR10("datasets", train=False, download=False, transform=trans) # define distributed samplers (only for distributed version) diff --git a/pytorch/mnist_distributed/train_model.py b/pytorch/mnist_distributed/train_model.py index ceeb81f..5e1e34f 100644 --- a/pytorch/mnist_distributed/train_model.py +++ b/pytorch/mnist_distributed/train_model.py @@ -12,6 +12,7 @@ from torchvision.models import resnet50 from torchvision.datasets import MNIST import torchvision.transforms as transforms from torch.utils.data import DistributedSampler, DataLoader +from filelock import FileLock class Net(nn.Module): def __init__(self): @@ -65,7 +66,8 @@ def load_dataset(args): # load MNIST dataset splits for train and test and apply the transformations # note: you need to download the dataset once at the beginning - ds_train = MNIST("datasets", train=True, download=True, transform=trans) + with FileLock(os.path.expanduser("~/.dataset_lock")): + ds_train = MNIST("datasets", train=True, download=True, transform=trans) ds_test = MNIST("datasets", train=False, download=False, transform=trans) # define distributed samplers (only for distributed version) -- GitLab