Add DistSamplerSeedHook for when runner is EpochBasedRunner (#1449)

* Add DistSamplerSeedHook for when runner is EpochBasedRunner

* add comment
Miao Zheng 2022-04-06 22:04:11 +08:00 committed by GitHub
parent 549616888e
commit 4bc2a30ea0


@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_optimizer, build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
 from mmseg import digit_version
@@ -128,6 +129,12 @@ def train_segmentor(model,
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
                                    cfg.get('momentum_config', None))
+    if distributed:
+        # When training by epoch in distributed mode, use `DistSamplerSeedHook`
+        # to give the distributed sampler a different seed every epoch; this
+        # reshuffles the dataset each epoch and helps avoid overfitting.
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())
 
     # an ugly workaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
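
For context, `DistSamplerSeedHook` works by calling `set_epoch()` on the data loader's sampler before every epoch, so that `DistributedSampler` derives a new shuffle seed each time. The snippet below is a minimal, self-contained sketch in plain PyTorch (not mmcv code; the toy dataset and epoch count are made up) showing why that call matters:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# Passing num_replicas/rank explicitly avoids needing an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # the same call DistSamplerSeedHook issues before each epoch
    print(epoch, list(sampler))  # a different permutation of indices 0..7 every epoch

Without the `set_epoch()` call, every epoch would reuse the shuffle order of epoch 0 on every rank, which is exactly what the hook registration above prevents when the runner is an `EpochBasedRunner`.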