Add an if statement to account for IterableDatasets when doing distributed training (#2151)

commit 9d2312b4ac
parent 6c746fad9c
Author: Shirley Wang
Date: 2022-10-08 00:14:01 -04:00
Committed by: GitHub


@@ -9,7 +9,7 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 from .samplers import DistributedSampler
@@ -129,12 +129,17 @@ def build_dataloader(dataset,
         DataLoader: A PyTorch dataloader.
     """
     rank, world_size = get_dist_info()
-    if dist:
+    if dist and not isinstance(dataset, IterableDataset):
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
+    elif dist:
+        sampler = None
+        shuffle = False
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
     else:
         sampler = None
         batch_size = num_gpus * samples_per_gpu
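
Context for the change: DistributedSampler needs len(dataset) and index-based access, which IterableDataset does not provide, and DataLoader rejects any explicit sampler for iterable-style datasets. The new elif branch therefore keeps sampler=None under dist, which means rank-level sharding becomes the dataset's own responsibility. The sketch below is a hypothetical illustration of such a dataset (the class name and the integer stream are not part of the commit), shown with the loader settings the new branch produces.

# Minimal sketch (not from the commit): an IterableDataset that shards its
# stream by rank, which is what the new `elif dist` branch leaves the dataset
# to do now that no DistributedSampler is attached.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset


class RankShardedCounter(IterableDataset):  # hypothetical example class
    """Yields integers 0..num_items-1, keeping only this rank's share."""

    def __init__(self, num_items):
        self.num_items = num_items

    def __iter__(self):
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        # Each rank walks the same stream but yields only every
        # world_size-th item, starting at its own rank offset.
        for idx in range(rank, self.num_items, world_size):
            yield torch.tensor(idx)


# Mirrors the new branch: sampler=None, shuffle=False,
# batch_size equal to samples_per_gpu.
loader = DataLoader(RankShardedCounter(100), batch_size=2,
                    sampler=None, shuffle=False)

With this commit, passing such a dataset to build_dataloader with dist=True takes the new elif branch instead of attempting to wrap it in a DistributedSampler.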