[Enhance] Support random seed for distributed sampler (#1411)

* support random seed for distributed sampler

* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py

* change dist sampler

* fix docstring in sync_random_seed
MengzhangLI 2022-03-28 23:50:39 +08:00 committed by GitHub
parent 41d5c13df1
commit f15a21a30d
5 changed files with 130 additions and 3 deletions

mmseg/core/utils/__init__.py

@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
from .layer_decay_optimizer_constructor import \
    LearningRateDecayOptimizerConstructor
from .misc import add_prefix

-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]

mmseg/core/utils/dist_util.py

@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def check_dist_init():
    return dist.is_available() and dist.is_initialized()


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed. All workers must call
    this function, otherwise it will deadlock. This method is generally used in
    `DistributedSampler`, because the seed should be identical across all
    processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """

    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
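
For orientation, a small usage sketch (not part of the diff). It assumes a distributed process group has already been set up, e.g. via mmcv's init_dist, and that each process has a CUDA device available for the broadcast buffer:

# Illustrative sketch only; not part of this commit.
from mmcv.runner import init_dist

from mmseg.core.utils import check_dist_init, sync_random_seed

init_dist('pytorch')               # or any other launcher supported by mmcv
assert check_dist_init()

seed = sync_random_seed()          # rank 0 draws a random seed, every rank receives it
fixed_seed = sync_random_seed(42)  # a user-given seed is broadcast unchanged, so every rank gets 42

Every rank must make the same calls, since the broadcast is collective; skipping it on any rank would deadlock, as the docstring warns.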

mmseg/datasets/builder.py

@@ -9,7 +9,9 @@ import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedSampler

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
@@ -129,7 +131,7 @@ def build_dataloader(dataset,
    rank, world_size = get_dist_info()
    if dist:
        sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle)
+            dataset, world_size, rank, shuffle=shuffle, seed=seed)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
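
A hedged example of the effect at the API level (not part of the diff): build_dataloader already accepted a seed argument, and after this change that seed also reaches the sampler, so the shuffle order becomes reproducible and identical across ranks. train_set is a stand-in for any dataset built with build_dataset; keyword names follow the parameters visible in this file.

# Illustrative call only; train_set is a placeholder dataset.
from mmseg.datasets import build_dataloader

loader = build_dataloader(
    train_set,
    samples_per_gpu=2,
    workers_per_gpu=2,
    dist=True,      # uses the mmseg DistributedSampler imported from .samplers above
    shuffle=True,
    seed=42)        # now also forwarded as DistributedSampler(..., seed=seed)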

mmseg/datasets/samplers/__init__.py

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler

__all__ = ['DistributedSampler']

mmseg/datasets/samplers/distributed_sampler.py

@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
from typing import Iterator, Optional

import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmseg.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    Args:
        dataset (Dataset): The dataset to be loaded.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, world_size is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within num_replicas.
            By default, rank is retrieved from the current distributed group.
        shuffle (bool): If True (default), sampler will shuffle the indices.
        seed (int): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed=0) -> None:
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self) -> Iterator:
        """
        Yields:
            Iterator: iterator of indices for rank.
        """
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
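
To close, a minimal sketch of how the new sampler is typically driven (not part of the diff). train_set and max_epochs are placeholders, and a distributed process group is assumed to be initialized so num_replicas and rank can be inferred from it:

# Illustrative sketch only; not part of this commit.
from torch.utils.data import DataLoader

from mmseg.datasets.samplers import DistributedSampler

sampler = DistributedSampler(train_set, shuffle=True, seed=42)  # seed is synced across ranks in __init__
loader = DataLoader(train_set, batch_size=2, sampler=sampler)

for epoch in range(max_epochs):
    # Same contract as torch's DistributedSampler: set_epoch() re-shuffles
    # each epoch, while the shared seed keeps every rank's permutation
    # identical, so the strided per-rank slices never overlap.
    sampler.set_epoch(epoch)
    for batch in loader:
        ...  # training step

Because self.seed comes from sync_random_seed, even a sampler constructed with the default seed=0 on one rank and a random seed elsewhere would still end up with the same value everywhere, which is exactly what the per-rank subsampling in __iter__ relies on.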