[Enhance] Support random seed for distributed sampler (#1411)
* support random seed for distributed sampler
* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py
* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py
* change dist sampler
* change dist sampler
* fix docstring in sync_random_seed

pull/1801/head
parent 41d5c13df1
commit f15a21a30d
mmseg/core/utils/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
 from .layer_decay_optimizer_constructor import \
     LearningRateDecayOptimizerConstructor
 from .misc import add_prefix

-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]

mmseg/core/utils/dist_util.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def check_dist_init():
+    return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function, otherwise it will deadlock. This method is generally used in
+    `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, optional): The seed. Defaults to None.
+        device (str): The device to put the seed on.
+            Defaults to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    """
+
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()

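A minimal sketch (not part of the commit) of what sync_random_seed guarantees: every rank ends up with the value drawn on rank 0. It assumes a CPU-only 'gloo' process group launched via torch.multiprocessing, so device='cpu' is passed instead of the default 'cuda'; the rendezvous address/port and the _worker helper are illustrative choices, not repository code.

import os

import torch.distributed as dist
import torch.multiprocessing as mp

from mmseg.core.utils import sync_random_seed


def _worker(rank, world_size):
    # Minimal single-node rendezvous for the demo process group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    # seed=None, so rank 0 draws a random seed and broadcasts it to all ranks.
    seed = sync_random_seed(device='cpu')
    print(f'rank {rank} got seed {seed}')  # same value printed on every rank
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2, ), nprocs=2)
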
mmseg/datasets/builder.py
@@ -9,7 +9,9 @@ import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedSampler

 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -129,7 +131,7 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle)
+            dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu

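A hedged sketch of the wiring in the `if dist:` branch above: the seed given to build_dataloader becomes the sampler's seed. The snippet runs in a single process under the assumption that no process group is initialized, so num_replicas and rank are passed explicitly and sync_random_seed simply returns the provided seed; TensorDataset stands in for a real mmseg dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset

from mmseg.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(10))

# Mirrors `DistributedSampler(dataset, world_size, rank, shuffle=shuffle,
# seed=seed)` from the hunk above, with world_size=1 and rank=0.
sampler = DistributedSampler(
    dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

sampler.set_epoch(0)
print([batch[0].tolist() for batch in loader])  # identical order on every run
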
mmseg/datasets/samplers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distributed_sampler import DistributedSampler
+
+__all__ = ['DistributedSampler']

mmseg/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+from typing import Iterator, Optional
+
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+from mmseg.core.utils import sync_random_seed
+
+
+class DistributedSampler(_DistributedSampler):
+    """DistributedSampler inheriting from
+    `torch.utils.data.DistributedSampler`.
+
+    Args:
+        dataset (Dataset): The dataset to be loaded.
+        num_replicas (int, optional): Number of processes participating in
+            distributed training. By default, world_size is retrieved from the
+            current distributed group.
+        rank (int, optional): Rank of the current process within num_replicas.
+            By default, rank is retrieved from the current distributed group.
+        shuffle (bool): If True (default), sampler will shuffle the indices.
+        seed (int): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+    """
+
+    def __init__(self,
+                 dataset: Dataset,
+                 num_replicas: Optional[int] = None,
+                 rank: Optional[int] = None,
+                 shuffle: bool = True,
+                 seed=0) -> None:
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
+
+    def __iter__(self) -> Iterator:
+        """
+        Yields:
+            Iterator: iterator of indices for rank.
+        """
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)

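A single-process sketch of the shuffling contract the class documents: samplers that share a seed draw the same permutation for a given epoch and then take disjoint slices of it. Two instances stand in for rank 0 and rank 1; num_replicas and rank are passed explicitly so no process group is needed (sync_random_seed then just returns the seed), and the dataset size of 8 is an arbitrary choice that divides evenly.

import torch
from torch.utils.data import TensorDataset

from mmseg.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(8))
samplers = [
    DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True, seed=7)
    for r in (0, 1)
]

for epoch in (0, 1):
    for s in samplers:
        s.set_epoch(epoch)  # feeds g.manual_seed(self.epoch + self.seed)
    idx0, idx1 = (list(iter(s)) for s in samplers)
    assert not set(idx0) & set(idx1)  # the two ranks never share an index
    assert sorted(idx0 + idx1) == list(range(8))  # together they cover the data
    print(epoch, idx0, idx1)  # the ordering differs across epochs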