Mirror of https://github.com/open-mmlab/mmsegmentation.git
Added an if statement to account for IterableDatasets when doing distributed training (#2151)
parent: 6c746fad9c
commit: 9d2312b4ac
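Background (not stated in the diff, but it is why the guard below is needed): a DistributedSampler partitions a dataset by index, so it requires a map-style dataset with __len__; PyTorch's DataLoader also rejects any explicit sampler when the dataset is an IterableDataset. A minimal sketch of the failure this commit avoids, assuming plain PyTorch (the StreamDataset class is hypothetical, not from mmsegmentation):

    from torch.utils.data import IterableDataset
    from torch.utils.data.distributed import DistributedSampler

    class StreamDataset(IterableDataset):
        # Hypothetical streaming dataset: defines __iter__ only, no __len__.
        def __iter__(self):
            return iter(range(100))

    try:
        # DistributedSampler.__init__ calls len(dataset) to split indices,
        # so it raises TypeError for a dataset that has no __len__.
        DistributedSampler(StreamDataset(), num_replicas=2, rank=0)
    except TypeError as err:
        print(err)  # object of type 'StreamDataset' has no len()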
@@ -9,7 +9,7 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 
 from .samplers import DistributedSampler
@@ -129,12 +129,17 @@ def build_dataloader(dataset,
         DataLoader: A PyTorch dataloader.
     """
     rank, world_size = get_dist_info()
-    if dist:
+    if dist and not isinstance(dataset, IterableDataset):
         sampler = DistributedSampler(
             dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
+    elif dist:
+        sampler = None
+        shuffle = False
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
     else:
         sampler = None
         batch_size = num_gpus * samples_per_gpu
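With this change, build_dataloader leaves sampler=None for an IterableDataset under distributed training, so sharding the stream across ranks becomes the dataset's own job. A minimal sketch of one way to do that, assuming mmcv.runner.get_dist_info for the rank and world size (the ShardedStream class and its round-robin scheme are illustrative, not part of this commit):

    from mmcv.runner import get_dist_info
    from torch.utils.data import IterableDataset

    class ShardedStream(IterableDataset):
        # Hypothetical stream: each rank keeps every world_size-th sample.
        def __init__(self, source):
            self.source = source  # any iterable of samples
            self.rank, self.world_size = get_dist_info()

        def __iter__(self):
            for i, sample in enumerate(self.source):
                # Round-robin sharding: rank r yields samples r, r+world_size, ...
                if i % self.world_size == self.rank:
                    yield sample

Note that with workers_per_gpu > 1, each rank's DataLoader workers would additionally need torch.utils.data.get_worker_info() inside __iter__ to avoid duplicating samples across worker processes.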