Skip to content
Snippets Groups Projects
Verified Commit f17d287f authored by Jannis Klinkenberg's avatar Jannis Klinkenberg
Browse files

added file locking to ensure proper download in distributed setting

parent 6fe19f38
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ from torchvision.models import resnet50
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DistributedSampler, DataLoader
from filelock import FileLock
def parse_command_line():
parser = argparse.ArgumentParser()
......@@ -44,6 +45,7 @@ def load_dataset(args):
# load CIFAR10 dataset splits for train and test and apply the transformations
# note: you need to download the dataset once at the beginning
with FileLock(os.path.expanduser("~/.dataset_lock")):
ds_train = CIFAR10("datasets", train=True, download=True, transform=trans)
ds_test = CIFAR10("datasets", train=False, download=False, transform=trans)
......
......@@ -12,6 +12,7 @@ from torchvision.models import resnet50
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DistributedSampler, DataLoader
from filelock import FileLock
class Net(nn.Module):
def __init__(self):
......@@ -65,6 +66,7 @@ def load_dataset(args):
# load MNIST dataset splits for train and test and apply the transformations
# note: you need to download the dataset once at the beginning
with FileLock(os.path.expanduser("~/.dataset_lock")):
ds_train = MNIST("datasets", train=True, download=True, transform=trans)
ds_test = MNIST("datasets", train=False, download=False, transform=trans)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment