From 9d2312b4ac2370dad56456329be2691fcd60dafc Mon Sep 17 00:00:00 2001
From: Shirley Wang <43547424+ShirleyWangCVR@users.noreply.github.com>
Date: Sat, 8 Oct 2022 00:14:01 -0400
Subject: [PATCH] added if statement to account for IterableDatasets doing
 distributed training (#2151)

---
 mmseg/datasets/builder.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py
index 4d852d365..49ee63373 100644
--- a/mmseg/datasets/builder.py
+++ b/mmseg/datasets/builder.py
@@ -9,7 +9,7 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 
 from .samplers import DistributedSampler
 
@@ -129,12 +129,17 @@ def build_dataloader(dataset,
         DataLoader: A PyTorch dataloader.
     """
     rank, world_size = get_dist_info()
-    if dist:
+    if dist and not isinstance(dataset, IterableDataset):
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
+    elif dist:
+        sampler = None
+        shuffle = False
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
     else:
         sampler = None
         batch_size = num_gpus * samples_per_gpu
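
Note (illustrative sketch, not part of the patch): PyTorch's DataLoader rejects
a non-default sampler when the dataset is iterable-style, which is why the new
elif branch leaves sampler=None under dist=True. With no DistributedSampler, the
dataset itself must shard the stream across ranks and workers. A minimal sketch
of such a dataset follows; RankShardedDataset and its `source` argument are
hypothetical examples, not mmseg API.

    from mmcv.runner import get_dist_info
    from torch.utils.data import IterableDataset, get_worker_info

    class RankShardedDataset(IterableDataset):
        """Hypothetical example: shards an iterable across ranks/workers."""

        def __init__(self, source):
            self.source = source  # any iterable of samples

        def __iter__(self):
            rank, world_size = get_dist_info()
            info = get_worker_info()  # None when num_workers == 0
            num_workers = info.num_workers if info else 1
            worker_id = info.id if info else 0
            # Each (rank, worker) pair owns one interleaved shard of the
            # stream, doing for iterable datasets what DistributedSampler
            # does for map-style datasets.
            shard = rank * num_workers + worker_id
            num_shards = world_size * num_workers
            for i, sample in enumerate(self.source):
                if i % num_shards == shard:
                    yield sample

With the patched builder, a call such as build_dataloader(RankShardedDataset(stream),
samples_per_gpu=2, workers_per_gpu=2, dist=True) takes the new elif branch and
builds the DataLoader with sampler=None, rather than attaching a
DistributedSampler that DataLoader would refuse for an IterableDataset.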