diff --git a/mmcls/datasets/cifar.py b/mmcls/datasets/cifar.py index a2c22d1ea..f3159ea8e 100644 --- a/mmcls/datasets/cifar.py +++ b/mmcls/datasets/cifar.py @@ -3,6 +3,8 @@ import os.path import pickle import numpy as np +import torch.distributed as dist +from mmcv.runner import get_dist_info from .base_dataset import BaseDataset from .builder import DATASETS @@ -40,13 +42,21 @@ class CIFAR10(BaseDataset): def load_annotations(self): - if not self._check_integrity(): + rank, world_size = get_dist_info() + + if rank == 0 and not self._check_integrity(): download_and_extract_archive( self.url, self.data_prefix, filename=self.filename, md5=self.tgz_md5) + if world_size > 1: + dist.barrier() + assert self._check_integrity(), \ + 'Shared storage seems unavailable. ' \ + f'Please download the dataset manually through {self.url}.' + if not self.test_mode: downloaded_list = self.train_list else: diff --git a/mmcls/datasets/mnist.py b/mmcls/datasets/mnist.py index 0bd417255..f00ef20e2 100644 --- a/mmcls/datasets/mnist.py +++ b/mmcls/datasets/mnist.py @@ -4,6 +4,8 @@ import os.path as osp import numpy as np import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info, master_only from .base_dataset import BaseDataset from .builder import DATASETS @@ -50,6 +52,15 @@ class MNIST(BaseDataset): test_image_file) or not osp.exists(test_label_file): self.download() + _, world_size = get_dist_info() + if world_size > 1: + dist.barrier() + assert osp.exists(train_image_file) and osp.exists( + train_label_file) and osp.exists( + test_image_file) and osp.exists(test_label_file), \ + 'Shared storage seems unavailable. Please download dataset ' \ + f'manually through {self.resource_prefix}.' + train_set = (read_image_file(train_image_file), read_label_file(train_label_file)) test_set = (read_image_file(test_image_file), @@ -67,6 +78,7 @@ class MNIST(BaseDataset): data_infos.append(info) return data_infos + @master_only def download(self): os.makedirs(self.data_prefix, exist_ok=True)