[Fix] Fix circle import in EMAHook (#523)

pull/515/head
Mashiro 2022-09-09 19:33:27 +08:00 committed by GitHub
parent 3ba50ce137
commit 0fb2b8ca8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -7,7 +7,6 @@ from typing import Dict, Optional
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from .hook import DATA_BATCH, Hook
@ -172,6 +171,7 @@ class EMAHook(Hook):
Args:
runner (Runner): The runner of the testing process.
"""
from mmengine.runner.checkpoint import load_state_dict
if 'ema_state_dict' in checkpoint and runner._resume:
# The original model parameters are actually saved in ema
# field swap the weights back to resume ema state.
@ -186,7 +186,7 @@ class EMAHook(Hook):
'There is no `ema_state_dict` in checkpoint. '
'`EMAHook` will make a copy of `state_dict` as the '
'initial `ema_state_dict`', 'current', logging.WARNING)
_load_checkpoint_to_model(
load_state_dict(
self.ema_model.module,
copy.deepcopy(checkpoint['state_dict']),
strict=self.strict_load)

View File

@ -60,7 +60,7 @@ class ToyModel3(BaseModel, ToyModel):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 2)
self.linear = nn.Linear(2, 2)
def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)
@ -226,10 +226,11 @@ class TestEMAHook(TestCase):
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
load_from=osp.join(self.temp_dir.name,
'without_ema_state_dict.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', strict_load=False)],
experiment_name='test5')
experiment_name='test5.1')
runner.test()
# Test enable ema at 5 epochs.