Mirror of https://github.com/open-mmlab/mmsegmentation.git
Add DistSamplerSeedHook for when runner is EpochBasedRunner (#1449)
* Add DistSamplerSeedHook for when runner is EpochBasedRunner
* add comment
parent 549616888e
commit 4bc2a30ea0
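The change below makes `train_segmentor` register `DistSamplerSeedHook` whenever the runner is epoch-based. For background on why the hook matters: PyTorch's `DistributedSampler` shuffles with a generator seeded from `seed + epoch`, so unless something calls `set_epoch()` before each epoch, every epoch replays the identical permutation on every rank. A minimal, self-contained sketch of that mechanism (the toy dataset, the explicit `num_replicas`/`rank` arguments, and the loop body are illustrative, not taken from this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy map-style dataset; any dataset behaves the same way (illustrative only).
dataset = TensorDataset(torch.arange(16).float())

# Passing num_replicas and rank explicitly avoids needing an initialized
# process group for this standalone demonstration.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    # DistributedSampler seeds its shuffle generator with (seed + epoch);
    # skipping set_epoch() would replay the same order every epoch.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # a real training step would go here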
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_optimizer, build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
 
 from mmseg import digit_version
@@ -128,6 +129,12 @@ def train_segmentor(model,
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
                                    cfg.get('momentum_config', None))
+    if distributed:
+        # When training distributed by epoch, use `DistSamplerSeedHook` to
+        # set a different seed on the distributed sampler for each epoch;
+        # this reshuffles the dataset every epoch and helps avoid overfitting.
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())
 
     # an ugly workaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
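For reference, the hook this commit registers is small: in mmcv it runs before each epoch and forwards the runner's epoch counter to the data loader's sampler. The sketch below is a simplified rendering of that behavior, not a verbatim copy of the mmcv source (`SamplerSeedSketchHook` is a hypothetical name):

```python
from mmcv.runner import Hook

class SamplerSeedSketchHook(Hook):
    """Simplified sketch of DistSamplerSeedHook's behavior (hypothetical)."""

    def before_epoch(self, runner):
        # Forward the current epoch to the sampler so DistributedSampler
        # draws a fresh shuffling permutation on every rank.
        sampler = runner.data_loader.sampler
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(runner.epoch)
```

The `isinstance(runner, EpochBasedRunner)` guard restricts the hook to runners that actually have epoch boundaries for `before_epoch` to fire on.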