mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
added if statement to account for IterableDatasets doing distributed training (#2151)
This commit is contained in:
parent
6c746fad9c
commit
9d2312b4ac
@ -9,7 +9,7 @@ import torch
|
|||||||
from mmcv.parallel import collate
|
from mmcv.parallel import collate
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
from mmcv.utils import Registry, build_from_cfg, digit_version
|
from mmcv.utils import Registry, build_from_cfg, digit_version
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
from .samplers import DistributedSampler
|
from .samplers import DistributedSampler
|
||||||
|
|
||||||
@ -129,12 +129,17 @@ def build_dataloader(dataset,
|
|||||||
DataLoader: A PyTorch dataloader.
|
DataLoader: A PyTorch dataloader.
|
||||||
"""
|
"""
|
||||||
rank, world_size = get_dist_info()
|
rank, world_size = get_dist_info()
|
||||||
if dist:
|
if dist and not isinstance(dataset, IterableDataset):
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
dataset, world_size, rank, shuffle=shuffle, seed=seed)
|
dataset, world_size, rank, shuffle=shuffle, seed=seed)
|
||||||
shuffle = False
|
shuffle = False
|
||||||
batch_size = samples_per_gpu
|
batch_size = samples_per_gpu
|
||||||
num_workers = workers_per_gpu
|
num_workers = workers_per_gpu
|
||||||
|
elif dist:
|
||||||
|
sampler = None
|
||||||
|
shuffle = False
|
||||||
|
batch_size = samples_per_gpu
|
||||||
|
num_workers = workers_per_gpu
|
||||||
else:
|
else:
|
||||||
sampler = None
|
sampler = None
|
||||||
batch_size = num_gpus * samples_per_gpu
|
batch_size = num_gpus * samples_per_gpu
|
||||||
|
Loading…
x
Reference in New Issue
Block a user