mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature]: Set different seed to different rank (#298)
* [Feature]: Set different seed for diff rank
* [Feature]: Add log
* [Fix]: Fix lint
* [Fix]: Fix docstring
* [Fix]: Fix sampler seed
* [Fix]: Fix log bug
* [Fix]: Change diff_seed to diff_rank_seed
* [Fix]: Fix lint
This commit is contained in:
parent
12f7d3a0d3
commit
03d5c17ba6
@ -20,8 +20,11 @@ class BaseLoop(metaclass=ABCMeta):
|
||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
|
||||
self._runner = runner
|
||||
if isinstance(dataloader, dict):
|
||||
# Determine whether or not different ranks use different seed.
|
||||
diff_rank_seed = runner._randomness_cfg.get(
|
||||
'diff_rank_seed', False)
|
||||
self.dataloader = runner.build_dataloader(
|
||||
dataloader, seed=runner.seed)
|
||||
dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
|
||||
else:
|
||||
self.dataloader = dataloader
|
||||
|
||||
|
@ -649,11 +649,16 @@ class Runner:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE,
|
||||
(soft_limit, hard_limit))
|
||||
|
||||
def set_randomness(self, seed, deterministic: bool = False) -> None:
|
||||
def set_randomness(self,
|
||||
seed,
|
||||
diff_rank_seed: bool = False,
|
||||
deterministic: bool = False) -> None:
|
||||
"""Set random seed to guarantee reproducible results.
|
||||
|
||||
Args:
|
||||
seed (int): A number to set random modules.
|
||||
diff_rank_seed (bool): Whether or not to set different seeds according
|
||||
to global rank. Defaults to False.
|
||||
deterministic (bool): Whether to set the deterministic option for
|
||||
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
||||
to True and `torch.backends.cudnn.benchmark` to False.
|
||||
@ -666,6 +671,9 @@ class Runner:
|
||||
if self._seed is None:
|
||||
self._seed = sync_random_seed()
|
||||
|
||||
if diff_rank_seed:
|
||||
# set different seeds for different ranks
|
||||
self._seed = self._seed + get_rank()
|
||||
random.seed(self._seed)
|
||||
np.random.seed(self._seed)
|
||||
torch.manual_seed(self._seed)
|
||||
@ -1254,7 +1262,8 @@ class Runner:
|
||||
|
||||
@staticmethod
|
||||
def build_dataloader(dataloader: Union[DataLoader, Dict],
|
||||
seed: Optional[int] = None) -> DataLoader:
|
||||
seed: Optional[int] = None,
|
||||
diff_rank_seed: bool = False) -> DataLoader:
|
||||
"""Build dataloader.
|
||||
|
||||
The method builds three components:
|
||||
@ -1277,6 +1286,11 @@ class Runner:
|
||||
build Dataloader object. If ``dataloader`` is a Dataloader
|
||||
object, just returns itself.
|
||||
seed (int, optional): Random seed. Defaults to None.
|
||||
diff_rank_seed (bool): Whether or not to set different seeds for
|
||||
different ranks. If True, the seed passed to sampler is set
|
||||
to None, in order to synchronize the seeds used in samplers
|
||||
across different ranks.
|
||||
|
||||
|
||||
Returns:
|
||||
DataLoader: DataLoader built from ``dataloader_cfg``.
|
||||
@ -1300,8 +1314,10 @@ class Runner:
|
||||
# build sampler
|
||||
sampler_cfg = dataloader_cfg.pop('sampler')
|
||||
if isinstance(sampler_cfg, dict):
|
||||
sampler_seed = None if diff_rank_seed else seed
|
||||
sampler = DATA_SAMPLERS.build(
|
||||
sampler_cfg, default_args=dict(dataset=dataset, seed=seed))
|
||||
sampler_cfg,
|
||||
default_args=dict(dataset=dataset, seed=sampler_seed))
|
||||
else:
|
||||
# fallback to raise error in dataloader
|
||||
# if `sampler_cfg` is not a valid type
|
||||
|
@ -998,6 +998,11 @@ class TestRunner(TestCase):
|
||||
self.assertIsInstance(dataloader.sampler, DefaultSampler)
|
||||
self.assertEqual(dataloader.sampler.seed, seed)
|
||||
|
||||
# diff_rank_seed is True
|
||||
dataloader = runner.build_dataloader(
|
||||
cfg, seed=seed, diff_rank_seed=True)
|
||||
self.assertNotEqual(dataloader.sampler.seed, seed)
|
||||
|
||||
def test_build_train_loop(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_train_loop'
|
||||
|
Loading…
x
Reference in New Issue
Block a user