diff --git a/pytorch/cifar10_distributed/train_model.py b/pytorch/cifar10_distributed/train_model.py index 873ca0785aa6a4f421253daa7c529e13e017f34e..c50dcd424787491f734ea5ddb407140bccda6d37 100644 --- a/pytorch/cifar10_distributed/train_model.py +++ b/pytorch/cifar10_distributed/train_model.py @@ -10,6 +10,7 @@ from torchvision.models import resnet50 from torchvision.datasets import CIFAR10 import torchvision.transforms as transforms from torch.utils.data import DistributedSampler, DataLoader +from filelock import FileLock def parse_command_line(): parser = argparse.ArgumentParser() @@ -44,7 +45,8 @@ def load_dataset(args): # load CIFAR10 dataset splits for train and test and apply the transformations # note: you need to download the dataset once at the beginning - ds_train = CIFAR10("datasets", train=True, download=True, transform=trans) + with FileLock(os.path.expanduser("~/.dataset_lock")): + ds_train = CIFAR10("datasets", train=True, download=True, transform=trans) ds_test = CIFAR10("datasets", train=False, download=False, transform=trans) # define distributed samplers (only for distributed version) diff --git a/pytorch/mnist_distributed/train_model.py b/pytorch/mnist_distributed/train_model.py index ceeb81f8ec6e189a143598cd4453feee1b7b7517..5e1e34f83bf81f9e41bc44b96d9bd1da99eba124 100644 --- a/pytorch/mnist_distributed/train_model.py +++ b/pytorch/mnist_distributed/train_model.py @@ -12,6 +12,7 @@ from torchvision.models import resnet50 from torchvision.datasets import MNIST import torchvision.transforms as transforms from torch.utils.data import DistributedSampler, DataLoader +from filelock import FileLock class Net(nn.Module): def __init__(self): @@ -65,7 +66,8 @@ def load_dataset(args): # load MNIST dataset splits for train and test and apply the transformations # note: you need to download the dataset once at the beginning - ds_train = MNIST("datasets", train=True, download=True, transform=trans) + with FileLock(os.path.expanduser("~/.dataset_lock")): + ds_train = MNIST("datasets", train=True, download=True, transform=trans) ds_test = MNIST("datasets", train=False, download=False, transform=trans) # define distributed samplers (only for distributed version)