Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Enhance] Support random seed for distributed sampler (#1411)
* support random seed for distributed sampler
* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py
* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py
* change dist sampler
* change dist sampler
* fix docstring in sync_random_seed
parent 41d5c13df1
commit f15a21a30d
mmseg/core/utils/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
 from .layer_decay_optimizer_constructor import \
     LearningRateDecayOptimizerConstructor
 from .misc import add_prefix

-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]
mmseg/core/utils/dist_util.py (new file, 46 lines)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def check_dist_init():
    return dist.is_available() and dist.is_initialized()


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed. All workers must call
    this function, otherwise it will deadlock. This method is generally used
    in `DistributedSampler`, because the seed should be identical across all
    processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """

    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
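For context, a minimal sketch of how sync_random_seed could be called directly from a training script; the script name and launch command are hypothetical and not part of this commit, and it assumes the process group is set up by torchrun plus init_process_group:

# hypothetical_train.py -- illustrative only, not part of this commit
import torch
import torch.distributed as dist

from mmseg.core.utils import check_dist_init, sync_random_seed


def main():
    # Assumed launch: torchrun --nproc_per_node=2 hypothetical_train.py
    if not check_dist_init():
        dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Every rank must call this, otherwise rank 0's broadcast deadlocks.
    # Rank 0 draws (or keeps) the seed and broadcasts it to all other ranks.
    seed = sync_random_seed(seed=None, device='cuda')
    print(f'rank {dist.get_rank()} uses seed {seed}')


if __name__ == '__main__':
    main()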
mmseg/datasets/builder.py
@@ -9,7 +9,9 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedSampler

 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -129,7 +131,7 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle)
+            dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
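A hedged sketch of how the seed would reach the sampler through build_dataloader; the keyword arguments other than seed, dist, and shuffle are assumed from the surrounding hunk rather than shown in this diff, and cfg stands in for a real config object:

# illustrative only -- assumes build_dataloader accepts samples_per_gpu,
# workers_per_gpu, dist, shuffle and seed keyword arguments, which the hunk
# above implies but does not fully show.
from mmseg.datasets import build_dataloader, build_dataset

dataset = build_dataset(cfg.data.train)   # cfg is a hypothetical config
loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=2,
    dist=True,      # take the distributed branch shown above
    shuffle=True,
    seed=42)        # forwarded to DistributedSampler(..., seed=seed)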
mmseg/datasets/samplers/__init__.py (new file, 4 lines)
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler

__all__ = ['DistributedSampler']
mmseg/datasets/samplers/distributed_sampler.py (new file, 71 lines)
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
from typing import Iterator, Optional

import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmseg.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    Args:
        dataset (Dataset): the dataset to be loaded.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, world_size is retrieved from
            the current distributed group.
        rank (int, optional): Rank of the current process within
            num_replicas. By default, rank is retrieved from the current
            distributed group.
        shuffle (bool): If True (default), sampler will shuffle the indices.
        seed (int): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed=0) -> None:
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self) -> Iterator:
        """
        Yields:
            Iterator: iterator of indices for rank.
        """
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
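A minimal sketch of how this sampler would typically be driven in a training loop; the toy dataset and DataLoader settings below are illustrative, not taken from this commit. Calling set_epoch before each epoch changes self.epoch, so g.manual_seed(self.epoch + self.seed) produces a new but rank-consistent shuffle every epoch:

# illustrative only -- a toy TensorDataset stands in for a real mmseg dataset
import torch
from torch.utils.data import DataLoader, TensorDataset

from mmseg.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(100))   # placeholder dataset
# num_replicas/rank are normally filled in from the process group; they are
# passed explicitly here so the sketch also illustrates the arguments.
sampler = DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=True, seed=42)

loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(3):
    # Must be called before iterating so every rank reshuffles with the same
    # (epoch + seed) and the per-rank shards stay non-overlapping.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward/backward would go here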