[Fix]: Fix is_model_wrapper and add DistSamplerSeedHook to default hooks. (#172)
* [Fix]: Fix model_wrapper and add DistSamplerSeedHook as default hook. * add comments pull/167/head
parent
93d22757cf
commit
3d830a28b6
|
@ -20,11 +20,17 @@ class DistSamplerSeedHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
|
||||
# in case the data loader uses `SequentialSampler` in Pytorch
|
||||
if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr(
|
||||
runner.train_loop.dataloader.sampler, 'set_epoch'):
|
||||
# In case the `_SingleProcessDataLoaderIter` has no sampler,
|
||||
# or data loader uses `SequentialSampler` in Pytorch.
|
||||
runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
|
||||
elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
|
||||
'set_epoch'):
|
||||
|
||||
elif hasattr(runner.train_loop.dataloader,
|
||||
'batch_sampler') and hasattr(
|
||||
runner.train_loop.dataloader.batch_sampler.sampler,
|
||||
'set_epoch'):
|
||||
# In case the `_SingleProcessDataLoaderIter` has no batch sampler.
|
||||
# The batch sampler in PyTorch wraps the sampler as one of its attributes.
|
||||
runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
|
||||
runner.epoch)
|
||||
|
|
|
@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel,
|
|||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
|
||||
MODEL_WRAPPERS.register_module(module=DataParallel)
|
||||
MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class MMDataParallel(DataParallel):
|
||||
|
|
|
@ -1397,8 +1397,13 @@ class Runner:
|
|||
# Add comments to describe the usage of `after_load_ckpt`
|
||||
self.call_hook('after_load_ckpt', checkpoint=checkpoint)
|
||||
|
||||
if is_model_wrapper(self.model):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
checkpoint = _load_checkpoint_to_model(
|
||||
self.model, checkpoint, strict, revise_keys=revise_keys)
|
||||
model, checkpoint, strict, revise_keys=revise_keys)
|
||||
|
||||
self._has_loaded = True
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DataParallel
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
|
||||
is_model_wrapper)
|
||||
|
@ -44,6 +46,12 @@ def test_is_model_wrapper():
|
|||
mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
|
||||
assert is_model_wrapper(mmddp)
|
||||
|
||||
torch_dp = DataParallel(model)
|
||||
assert is_model_wrapper(torch_dp)
|
||||
|
||||
torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
|
||||
assert is_model_wrapper(torch_ddp)
|
||||
|
||||
# test model wrapper registry
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class ModelWrapper(object):
|
||||
|
|
|
@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
|
|||
from mmengine.config import Config
|
||||
from mmengine.data import DefaultSampler
|
||||
from mmengine.evaluator import BaseMetric, Evaluator
|
||||
from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
|
||||
ParamSchedulerHook)
|
||||
from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
|
||||
LoggerHook, OptimizerHook, ParamSchedulerHook)
|
||||
from mmengine.hooks.checkpoint_hook import CheckpointHook
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.optim.scheduler import MultiStepLR, StepLR
|
||||
|
@ -913,33 +913,35 @@ class TestRunner(TestCase):
|
|||
|
||||
# register six hooks by default
|
||||
runner.register_default_hooks()
|
||||
self.assertEqual(len(runner._hooks), 5)
|
||||
# the fourth registered hook should be `ParamSchedulerHook`
|
||||
self.assertTrue(isinstance(runner._hooks[3], ParamSchedulerHook))
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
# the third registered hook should be `DistSamplerSeedHook`
|
||||
self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
|
||||
# the fifth registered hook should be `ParamSchedulerHook`
|
||||
self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
|
||||
|
||||
runner._hooks = []
|
||||
# remove the `timer` hook from default hooks
|
||||
runner.register_default_hooks(hooks=dict(timer=None))
|
||||
self.assertEqual(len(runner._hooks), 4)
|
||||
self.assertEqual(len(runner._hooks), 5)
|
||||
# the `timer` hook was popped so the fifth is `CheckpointHook`
|
||||
self.assertTrue(isinstance(runner._hooks[3], CheckpointHook))
|
||||
self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
|
||||
|
||||
# add a new default hook
|
||||
runner._hooks = []
|
||||
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
|
||||
def test_custom_hooks(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_custom_hooks'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
self.assertEqual(len(runner._hooks), 5)
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
custom_hooks = [dict(type='ToyHook')]
|
||||
runner.register_custom_hooks(custom_hooks)
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
|
||||
def test_register_hooks(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
|
@ -949,9 +951,9 @@ class TestRunner(TestCase):
|
|||
runner._hooks = []
|
||||
custom_hooks = [dict(type='ToyHook')]
|
||||
runner.register_hooks(custom_hooks=custom_hooks)
|
||||
# five default hooks + custom hook (ToyHook)
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
||||
# six default hooks + custom hook (ToyHook)
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
|
||||
def test_custom_loop(self):
|
||||
# test custom loop with additional hook
|
||||
|
|
Loading…
Reference in New Issue