mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Bug] Download dataset only on rank 0 (#273)
* only download dataset on rank 0 * download only on rank 0 * fix bug * fix error message
This commit is contained in:
parent
a06827bc08
commit
bd9411d743
@ -3,6 +3,8 @@ import os.path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import get_dist_info
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
@ -40,13 +42,21 @@ class CIFAR10(BaseDataset):
|
||||
|
||||
def load_annotations(self):
|
||||
|
||||
if not self._check_integrity():
|
||||
rank, world_size = get_dist_info()
|
||||
|
||||
if rank == 0 and not self._check_integrity():
|
||||
download_and_extract_archive(
|
||||
self.url,
|
||||
self.data_prefix,
|
||||
filename=self.filename,
|
||||
md5=self.tgz_md5)
|
||||
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
assert self._check_integrity(), \
|
||||
'Shared storage seems unavailable. ' \
|
||||
f'Please download the dataset manually through {self.url}.'
|
||||
|
||||
if not self.test_mode:
|
||||
downloaded_list = self.train_list
|
||||
else:
|
||||
|
@ -4,6 +4,8 @@ import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import get_dist_info, master_only
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
@ -50,6 +52,15 @@ class MNIST(BaseDataset):
|
||||
test_image_file) or not osp.exists(test_label_file):
|
||||
self.download()
|
||||
|
||||
_, world_size = get_dist_info()
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
assert osp.exists(train_image_file) and osp.exists(
|
||||
train_label_file) and osp.exists(
|
||||
test_image_file) and osp.exists(test_label_file), \
|
||||
'Shared storage seems unavailable. Please download dataset ' \
|
||||
f'manually through {self.resource_prefix}.'
|
||||
|
||||
train_set = (read_image_file(train_image_file),
|
||||
read_label_file(train_label_file))
|
||||
test_set = (read_image_file(test_image_file),
|
||||
@ -67,6 +78,7 @@ class MNIST(BaseDataset):
|
||||
data_infos.append(info)
|
||||
return data_infos
|
||||
|
||||
@master_only
|
||||
def download(self):
|
||||
os.makedirs(self.data_prefix, exist_ok=True)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user