From 03d5c17ba6a8ce7650ceebdb60e670a01b061e9c Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Fri, 24 Jun 2022 14:28:16 +0800 Subject: [PATCH] [Feature]: Set different seed to different rank (#298) * [Feature]: Set different seed for diff rank * [Feature]: Add log * [Fix]: Fix lint * [Fix]: Fix docstring * [Fix]: Fix sampler seed * [Fix]: Fix log bug * [Fix]: Change diff_seed to diff_rank_seed * [Fix]: Fix lint --- mmengine/runner/base_loop.py | 5 ++++- mmengine/runner/runner.py | 22 +++++++++++++++++++--- tests/test_runner/test_runner.py | 5 +++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index e3916a24..48a6f69c 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -20,8 +20,11 @@ class BaseLoop(metaclass=ABCMeta): def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None: self._runner = runner if isinstance(dataloader, dict): + # Determine whether or not different ranks use different seeds. + diff_rank_seed = runner._randomness_cfg.get( + 'diff_rank_seed', False) self.dataloader = runner.build_dataloader( - dataloader, seed=runner.seed) + dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed) else: self.dataloader = dataloader diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index f754e88c..f20de612 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -649,11 +649,16 @@ class Runner: resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) - def set_randomness(self, seed, deterministic: bool = False) -> None: + def set_randomness(self, + seed, + diff_rank_seed: bool = False, + deterministic: bool = False) -> None: """Set random seed to guarantee reproducible results. Args: seed (int): A number to set random modules. + diff_rank_seed (bool): Whether to set different seeds according + to global rank. Defaults to False. 
deterministic (bool): Whether to set the deterministic option for CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. @@ -666,6 +671,9 @@ class Runner: if self._seed is None: self._seed = sync_random_seed() + if diff_rank_seed: + # set different seeds for different ranks + self._seed = self._seed + get_rank() random.seed(self._seed) np.random.seed(self._seed) torch.manual_seed(self._seed) @@ -1254,7 +1262,8 @@ class Runner: @staticmethod def build_dataloader(dataloader: Union[DataLoader, Dict], - seed: Optional[int] = None) -> DataLoader: + seed: Optional[int] = None, + diff_rank_seed: bool = False) -> DataLoader: """Build dataloader. The method builds three components: @@ -1277,6 +1286,11 @@ class Runner: build Dataloader object. If ``dataloader`` is a Dataloader object, just returns itself. seed (int, optional): Random seed. Defaults to None. + diff_rank_seed (bool): Whether to set different seeds for + different ranks. If True, the seed passed to the sampler is set + to None, in order to synchronize the seeds used in samplers + across different ranks. + Returns: Dataloader: DataLoader build from ``dataloader_cfg``. 
@@ -1300,8 +1314,10 @@ class Runner: # build sampler sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): + sampler_seed = None if diff_rank_seed else seed sampler = DATA_SAMPLERS.build( - sampler_cfg, default_args=dict(dataset=dataset, seed=seed)) + sampler_cfg, + default_args=dict(dataset=dataset, seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 39454511..5c75094e 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -998,6 +998,11 @@ class TestRunner(TestCase): self.assertIsInstance(dataloader.sampler, DefaultSampler) self.assertEqual(dataloader.sampler.seed, seed) + # diff_rank_seed is True + dataloader = runner.build_dataloader( + cfg, seed=seed, diff_rank_seed=True) + self.assertNotEqual(dataloader.sampler.seed, seed) + def test_build_train_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_build_train_loop'