diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py
index 760701be6..be84a89e6 100644
--- a/mmseg/apis/train.py
+++ b/mmseg/apis/train.py
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_optimizer, build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg

 from mmseg import digit_version
@@ -128,6 +129,12 @@ def train_segmentor(model,
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
                                    cfg.get('momentum_config', None))
+    if distributed:
+        # when training by epoch in distributed mode, `DistSamplerSeedHook`
+        # gives the distributed sampler a different seed at each epoch, so
+        # the dataset is reshuffled every epoch to help avoid overfitting.
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())

     # an ugly walkaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
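
For context, a minimal sketch (not part of this diff) of the sampler behavior the hook addresses: `torch.utils.data.DistributedSampler` derives its shuffle order from its seed plus the current epoch, so it must be told the epoch number before each pass; `DistSamplerSeedHook.before_epoch` does this on the runner's data loader. The dataset, rank, and world size below are hypothetical stand-ins so the snippet runs without launching a real distributed job.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))
# num_replicas/rank are hard-coded only to make the sketch runnable
# outside an actual torch.distributed process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    # This is what DistSamplerSeedHook effectively does for an
    # EpochBasedRunner; without set_epoch, the shuffle order is
    # identical in every epoch.
    sampler.set_epoch(epoch)
    print(epoch, [batch[0].tolist() for batch in loader])
```

Without the `set_epoch` call, both printed epochs yield the same sample order, which is why the hook is only needed (and only registered) for `EpochBasedRunner` under distributed training; iteration-based runners draw from a single continuous shuffle.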