[fix] EMAHook load state dict (#507)
* fix ema load_state_dict * fix ema load_state_dict * fix for test * fix by review * fix resume and keyspull/523/head
parent
cfb884c180
commit
a6f5297727
|
@ -7,6 +7,7 @@ from typing import Dict, Optional
|
||||||
from mmengine.logging import print_log
|
from mmengine.logging import print_log
|
||||||
from mmengine.model import is_model_wrapper
|
from mmengine.model import is_model_wrapper
|
||||||
from mmengine.registry import HOOKS, MODELS
|
from mmengine.registry import HOOKS, MODELS
|
||||||
|
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
||||||
from .hook import DATA_BATCH, Hook
|
from .hook import DATA_BATCH, Hook
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,7 +172,7 @@ class EMAHook(Hook):
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
"""
|
"""
|
||||||
if 'ema_state_dict' in checkpoint:
|
if 'ema_state_dict' in checkpoint and runner._resume:
|
||||||
# The original model parameters are actually saved in ema
|
# The original model parameters are actually saved in ema
|
||||||
# field swap the weights back to resume ema state.
|
# field swap the weights back to resume ema state.
|
||||||
self._swap_ema_state_dict(checkpoint)
|
self._swap_ema_state_dict(checkpoint)
|
||||||
|
@ -180,11 +181,13 @@ class EMAHook(Hook):
|
||||||
|
|
||||||
# Support load checkpoint without ema state dict.
|
# Support load checkpoint without ema state dict.
|
||||||
else:
|
else:
|
||||||
print_log(
|
if runner._resume:
|
||||||
'There is no `ema_state_dict` in checkpoint. '
|
print_log(
|
||||||
'`EMAHook` will make a copy of `state_dict` as the '
|
'There is no `ema_state_dict` in checkpoint. '
|
||||||
'initial `ema_state_dict`', 'current', logging.WARNING)
|
'`EMAHook` will make a copy of `state_dict` as the '
|
||||||
self.ema_model.module.load_state_dict(
|
'initial `ema_state_dict`', 'current', logging.WARNING)
|
||||||
|
_load_checkpoint_to_model(
|
||||||
|
self.ema_model.module,
|
||||||
copy.deepcopy(checkpoint['state_dict']),
|
copy.deepcopy(checkpoint['state_dict']),
|
||||||
strict=self.strict_load)
|
strict=self.strict_load)
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,16 @@ class ToyModel2(BaseModel, ToyModel):
|
||||||
return super(BaseModel, self).forward(*args, **kwargs)
|
return super(BaseModel, self).forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ToyModel3(BaseModel, ToyModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(2, 2)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return super(BaseModel, self).forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module()
|
@DATASETS.register_module()
|
||||||
class DummyDataset(Dataset):
|
class DummyDataset(Dataset):
|
||||||
METAINFO = dict() # type: ignore
|
METAINFO = dict() # type: ignore
|
||||||
|
@ -203,6 +213,25 @@ class TestEMAHook(TestCase):
|
||||||
experiment_name='test5')
|
experiment_name='test5')
|
||||||
runner.test()
|
runner.test()
|
||||||
|
|
||||||
|
# Test does not load ckpt strict_loadly.
|
||||||
|
# Test load checkpoint without ema_state_dict
|
||||||
|
# Test with different size head.
|
||||||
|
runner = Runner(
|
||||||
|
model=ToyModel3(),
|
||||||
|
test_dataloader=dict(
|
||||||
|
dataset=dict(type='DummyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
test_evaluator=evaluator,
|
||||||
|
test_cfg=dict(),
|
||||||
|
work_dir=self.temp_dir.name,
|
||||||
|
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
|
||||||
|
default_hooks=dict(logger=None),
|
||||||
|
custom_hooks=[dict(type='EMAHook', strict_load=False)],
|
||||||
|
experiment_name='test5')
|
||||||
|
runner.test()
|
||||||
|
|
||||||
# Test enable ema at 5 epochs.
|
# Test enable ema at 5 epochs.
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
Loading…
Reference in New Issue