[Fix]: Fix is_model_wrapper and add DistSamplerSeedHook to default hooks. (#172)

* [Fix]: Fix model_wrapper and add DistSamplerSeedHook as default hook.

* add comments
RangiLyu 2022-04-08 22:18:23 +08:00 committed by GitHub
parent 93d22757cf
commit 3d830a28b6
5 changed files with 44 additions and 20 deletions

View File

@@ -20,11 +20,17 @@ class DistSamplerSeedHook(Hook):
        Args:
            runner (Runner): The runner of the training process.
        """
-        if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
-            # in case the data loader uses `SequentialSampler` in Pytorch
+        if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr(
+                runner.train_loop.dataloader.sampler, 'set_epoch'):
+            # In case the `_SingleProcessDataLoaderIter` has no sampler,
+            # or the data loader uses `SequentialSampler` in PyTorch.
            runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
-        elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
-                     'set_epoch'):
+        elif hasattr(runner.train_loop.dataloader,
+                     'batch_sampler') and hasattr(
+                         runner.train_loop.dataloader.batch_sampler.sampler,
+                         'set_epoch'):
+            # In case the `_SingleProcessDataLoaderIter` has no batch sampler.
            # The batch sampler in PyTorch wraps the sampler as its attribute.
            runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
                runner.epoch)
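
The hook automates the per-epoch reseeding that distributed training needs for proper shuffling. Roughly, it replaces the manual pattern sketched below (illustrative only; the toy dataset and the explicit `num_replicas`/`rank` stand in for a real distributed setup):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy dataset; in real training num_replicas/rank come from the process group.
dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)

for epoch in range(3):
    # Without this call, DistributedSampler reuses the same shuffle order
    # every epoch; DistSamplerSeedHook performs exactly this step per epoch.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step
```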

View File

@@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel,
 from mmengine.registry import MODEL_WRAPPERS
 from mmengine.utils import TORCH_VERSION, digit_version
 
+MODEL_WRAPPERS.register_module(module=DataParallel)
+MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
+
 
 @MODEL_WRAPPERS.register_module()
 class MMDataParallel(DataParallel):
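
Registering the vanilla `DataParallel` and `DistributedDataParallel` classes in `MODEL_WRAPPERS` lets a registry-based check recognize them alongside MMEngine's own wrappers. A minimal sketch of the idea (the exact `is_model_wrapper` implementation may differ; this only illustrates the registry lookup):

```python
import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmengine.model.wrappers import is_model_wrapper  # import also runs the registrations
from mmengine.registry import MODEL_WRAPPERS

model = nn.Linear(2, 2)

# A registry-based check boils down to an isinstance test against every
# class registered in MODEL_WRAPPERS, which now includes the torch wrappers.
wrapper_types = tuple(MODEL_WRAPPERS.module_dict.values())
assert isinstance(DataParallel(model), wrapper_types)

assert is_model_wrapper(DataParallel(model))
assert not is_model_wrapper(model)
```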

View File

@@ -1397,8 +1397,13 @@ class Runner:
         # Add comments to describe the usage of `after_load_ckpt`
         self.call_hook('after_load_ckpt', checkpoint=checkpoint)
 
+        if is_model_wrapper(self.model):
+            model = self.model.module
+        else:
+            model = self.model
+
         checkpoint = _load_checkpoint_to_model(
-            self.model, checkpoint, strict, revise_keys=revise_keys)
+            model, checkpoint, strict, revise_keys=revise_keys)
 
         self._has_loaded = True
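
Unwrapping before loading matters because `DataParallel`/`DistributedDataParallel` prefix every state-dict key with `module.`, so a checkpoint saved from a bare model does not match the wrapper's keys. A standalone illustration in plain PyTorch (independent of the Runner):

```python
import torch.nn as nn
from torch.nn.parallel import DataParallel

model = nn.Linear(4, 2)
wrapped = DataParallel(model)

print(list(model.state_dict()))    # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

# A checkpoint saved from the bare model therefore loads cleanly only into
# the inner module, which is what unwrapping via is_model_wrapper enables.
wrapped.module.load_state_dict(model.state_dict())
```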

View File

@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
                                      is_model_wrapper)
@@ -44,6 +46,12 @@ def test_is_model_wrapper():
     mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
     assert is_model_wrapper(mmddp)
 
+    torch_dp = DataParallel(model)
+    assert is_model_wrapper(torch_dp)
+
+    torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
+    assert is_model_wrapper(torch_ddp)
+
     # test model wrapper registry
     @MODEL_WRAPPERS.register_module()
     class ModelWrapper(object):
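
The test can construct `DistributedDataParallel` only because it passes `process_group=MagicMock()`; outside of tests, DDP needs an initialized process group first. A minimal single-process sketch with the gloo backend (the address and port values are illustrative):

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

# A one-process "world" is enough to satisfy DDP's process-group requirement.
dist.init_process_group(backend='gloo', rank=0, world_size=1)
ddp_model = DistributedDataParallel(nn.Linear(2, 2))
dist.destroy_process_group()
```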

View File

@@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
 from mmengine.config import Config
 from mmengine.data import DefaultSampler
 from mmengine.evaluator import BaseMetric, Evaluator
-from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
-                            ParamSchedulerHook)
+from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
+                            LoggerHook, OptimizerHook, ParamSchedulerHook)
 from mmengine.hooks.checkpoint_hook import CheckpointHook
 from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim.scheduler import MultiStepLR, StepLR
@@ -913,33 +913,35 @@ class TestRunner(TestCase):
 
         # register five hooks by default
         runner.register_default_hooks()
-        self.assertEqual(len(runner._hooks), 5)
-        # the forth registered hook should be `ParamSchedulerHook`
-        self.assertTrue(isinstance(runner._hooks[3], ParamSchedulerHook))
+        self.assertEqual(len(runner._hooks), 6)
+        # the third registered hook should be `DistSamplerSeedHook`
+        self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
+        # the fifth registered hook should be `ParamSchedulerHook`
+        self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
 
         runner._hooks = []
         # remove `ParamSchedulerHook` from default hooks
         runner.register_default_hooks(hooks=dict(timer=None))
-        self.assertEqual(len(runner._hooks), 4)
+        self.assertEqual(len(runner._hooks), 5)
         # `ParamSchedulerHook` was popped so the forth is `CheckpointHook`
-        self.assertTrue(isinstance(runner._hooks[3], CheckpointHook))
+        self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
 
         # add a new default hook
         runner._hooks = []
         runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_custom_hooks'
         runner = Runner.from_cfg(cfg)
-        self.assertEqual(len(runner._hooks), 5)
+        self.assertEqual(len(runner._hooks), 6)
 
         custom_hooks = [dict(type='ToyHook')]
         runner.register_custom_hooks(custom_hooks)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_register_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -949,9 +951,9 @@ class TestRunner(TestCase):
         runner._hooks = []
         custom_hooks = [dict(type='ToyHook')]
         runner.register_hooks(custom_hooks=custom_hooks)
-        # five default hooks + custom hook (ToyHook)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        # six default hooks + custom hook (ToyHook)
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_loop(self):
         # test custom loop with additional hook
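
For reference, the counts asserted above all follow from the override mechanics of `register_default_hooks` and `register_custom_hooks`, condensed here as a sketch (assumes a constructed `runner`, e.g. via `Runner.from_cfg(cfg)` as in these tests; `ToyHook` is the test-local hook registered there):

```python
# The defaults now include DistSamplerSeedHook, giving six hooks.
runner._hooks = []
runner.register_default_hooks()

# Setting a default key to None drops that hook (here the timer hook),
# leaving five.
runner._hooks = []
runner.register_default_hooks(hooks=dict(timer=None))

# An unknown key appends an extra hook built from its config dict.
runner._hooks = []
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))

# Custom hooks are registered separately and appended after the defaults.
runner.register_custom_hooks([dict(type='ToyHook')])
```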