Added an if statement to account for IterableDatasets in distributed training (#2151)

Shirley Wang 2022-10-08 00:14:01 -04:00 committed by GitHub
parent 6c746fad9c
commit 9d2312b4ac


@@ -9,7 +9,7 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 from .samplers import DistributedSampler
@@ -129,12 +129,17 @@ def build_dataloader(dataset,
         DataLoader: A PyTorch dataloader.
     """
     rank, world_size = get_dist_info()
-    if dist:
+    if dist and not isinstance(dataset, IterableDataset):
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
+    elif dist:
+        sampler = None
+        shuffle = False
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
     else:
         sampler = None
         batch_size = num_gpus * samples_per_gpu
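
A minimal sketch (not part of the diff) of why the new `elif dist:` branch leaves the sampler as None: DistributedSampler indexes into a map-style dataset, which an IterableDataset cannot support, so under distributed training the stream itself has to decide which items belong to each rank. The ShardedStream class and its arguments below are hypothetical names, not part of the repository.

from torch.utils.data import DataLoader, IterableDataset


class ShardedStream(IterableDataset):
    """Hypothetical stream dataset that yields only this rank's items."""

    def __init__(self, items, rank, world_size):
        self.items = items
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Round-robin sharding by rank; a real pipeline may also need to
        # split work across DataLoader workers via get_worker_info().
        for i, item in enumerate(self.items):
            if i % self.world_size == self.rank:
                yield item


if __name__ == '__main__':
    dataset = ShardedStream(list(range(10)), rank=0, world_size=2)
    # Mirrors the new elif branch: no sampler for an IterableDataset.
    loader = DataLoader(dataset, batch_size=2, sampler=None, shuffle=False)
    for batch in loader:
        print(batch)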