Add DistSamplerSeedHook for when runner is EpochBasedRunner (#1449)

* Add DistSamplerSeedHook for when runner is EpochBasedRunner

* add comment
Miao Zheng 2022-04-06 22:04:11 +08:00 committed by GitHub
parent 549616888e
commit 4bc2a30ea0


@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_optimizer, build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg

 from mmseg import digit_version
@@ -128,6 +129,12 @@ def train_segmentor(model,
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
                                    cfg.get('momentum_config', None))
+    if distributed:
+        # When training by epoch in distributed mode, use `DistSamplerSeedHook`
+        # to set a different seed for the distributed sampler at each epoch;
+        # this reshuffles the dataset every epoch and helps avoid overfitting.
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())

     # an ugly workaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
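
For reference, the hook being registered here is small: before every training epoch it pushes the runner's current epoch into the data loader's distributed sampler. The sketch below is a simplified rendering of what mmcv's `DistSamplerSeedHook` does; the class name `SamplerSeedSketchHook` is made up to avoid clashing with the real registered hook, so consult mmcv for the exact implementation.

    from mmcv.runner import Hook

    class SamplerSeedSketchHook(Hook):
        """Simplified sketch of mmcv's DistSamplerSeedHook."""

        def before_epoch(self, runner):
            # DistributedSampler derives its shuffle order from seed + epoch,
            # so the current epoch must reach the sampler before iteration.
            if hasattr(runner.data_loader.sampler, 'set_epoch'):
                runner.data_loader.sampler.set_epoch(runner.epoch)
            elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
                # some loaders hide the sampler inside a batch sampler
                runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)

This also explains the `isinstance(runner, EpochBasedRunner)` guard in the patch: the hook fires on epoch boundaries, which only an epoch-based runner exposes.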
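The PyTorch behavior the new comment refers to: `torch.utils.data.distributed.DistributedSampler` seeds its shuffle generator with `seed + epoch`, so a sampler whose epoch is never advanced replays the identical permutation every epoch. A small standalone check (no process group is needed when `num_replicas` and `rank` are passed explicitly):

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    # Epoch never advanced: both passes replay the same permutation.
    assert list(sampler) == list(sampler)

    # Advancing the epoch re-seeds the shuffle, giving a fresh order.
    sampler.set_epoch(0)
    first = list(sampler)
    sampler.set_epoch(1)
    second = list(sampler)
    print(first, second)  # almost surely different permutations

`DistSamplerSeedHook` automates exactly this `set_epoch` call once per epoch.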