[Enhance] Support evaluate on both EMA and non-EMA models. (#1204)
* [Enhance] Support evaluate on both EMA and original models.
* Fix lint
parent: d80ec5a4b8
commit: 7b9a1010f5
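In a user config, the behaviour added by this commit would be enabled roughly as sketched below. The `EMAHook` type and the `evaluate_on_ema` / `evaluate_on_origin` flags come from this commit; the `momentum` value is only an illustrative extra keyword forwarded to the averaged model.

```python
# Sketch of an MMClassification config enabling evaluation on both models.
custom_hooks = [
    dict(
        type='EMAHook',
        momentum=2e-4,            # assumed illustrative value, passed through **kwargs
        evaluate_on_ema=True,     # validate/test on the EMA weights (default)
        evaluate_on_origin=True,  # additionally re-run val/test on the original weights
    )
]
```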
```diff
@@ -122,7 +122,7 @@ and WandB. More details can be found in the [Visualizer section](#visualizer).
 Many above functionalities are implemented by hooks, and you can also plug-in other custom hooks by modifying
 `custom_hooks` field. Here are some hooks in MMEngine and MMClassification that you can use directly, such as:
 
-- [EMAHook](mmengine.hooks.EMAHook)
+- [EMAHook](mmcls.engine.hooks.EMAHook)
 - [SyncBuffersHook](mmengine.hooks.SyncBuffersHook)
 - [EmptyCacheHook](mmengine.hooks.EmptyCacheHook)
 - [ClassNumCheckHook](mmcls.engine.hooks.ClassNumCheckHook)
```
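Concretely, the `custom_hooks` field mentioned in the documentation above is just a list of hook configs; a sketch follows (the hook names are the ones listed in the docs, but this particular combination is only illustrative):

```python
# Illustrative `custom_hooks` list plugging in hooks from the list above.
custom_hooks = [
    dict(type='ClassNumCheckHook'),                 # sanity-check the number of classes
    dict(type='SyncBuffersHook'),                   # sync model buffers across processes
    dict(type='EMAHook', evaluate_on_origin=True),  # EMA plus evaluation on the original model
]
```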
```diff
@@ -33,6 +33,7 @@ Hooks
    VisualizationHook
    PrepareProtoBeforeValLoopHook
    SetAdaptiveMarginsHook
+   EMAHook
 
 .. module:: mmcls.engine.optimizers
 
```
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .class_num_check_hook import ClassNumCheckHook
+from .ema_hook import EMAHook
 from .margin_head_hooks import SetAdaptiveMarginsHook
 from .precise_bn_hook import PreciseBNHook
 from .retriever_hooks import PrepareProtoBeforeValLoopHook
@@ -9,5 +10,5 @@ from .visualization_hook import VisualizationHook
 __all__ = [
     'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
     'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
-    'SetAdaptiveMarginsHook'
+    'SetAdaptiveMarginsHook', 'EMAHook'
 ]
```
mmcls/engine/hooks/ema_hook.py (new file, 216 lines)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools
import warnings
from typing import Dict, Optional

from mmengine.hooks import EMAHook as BaseEMAHook
from mmengine.logging import MMLogger
from mmengine.runner import Runner

from mmcls.registry import HOOKS


@HOOKS.register_module()
class EMAHook(BaseEMAHook):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts
    ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the
    ``evaluate_on_ema`` is enabled, and if you want to do validation and
    testing on both original and EMA models, please set both arguments
    ``True``.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in ema field after
          train.
        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in :mod:`mmengine.model.averaged_model`.
            Defaults to 'ExponentialMovingAverage'.
        strict_load (bool): Whether to strictly enforce that the keys of
            ``state_dict`` in checkpoint match the keys returned by
            ``self.module.state_dict``. Defaults to False.
            Changed in v0.3.0.
        begin_iter (int): The number of iteration to enable ``EMAHook``.
            Defaults to 0.
        begin_epoch (int): The number of epoch to enable ``EMAHook``.
            Defaults to 0.
        evaluate_on_ema (bool): Whether to evaluate (validate and test)
            on EMA model during val-loop and test-loop. Defaults to True.
        evaluate_on_origin (bool): Whether to evaluate (validate and test)
            on the original model during val-loop and test-loop.
            Defaults to False.
        **kwargs: Keyword arguments passed to subclasses of
            :obj:`BaseAveragedModel`
    """

    priority = 'NORMAL'

    def __init__(self,
                 ema_type: str = 'ExponentialMovingAverage',
                 strict_load: bool = False,
                 begin_iter: int = 0,
                 begin_epoch: int = 0,
                 evaluate_on_ema: bool = True,
                 evaluate_on_origin: bool = False,
                 **kwargs):
        super().__init__(
            ema_type=ema_type,
            strict_load=strict_load,
            begin_iter=begin_iter,
            begin_epoch=begin_epoch,
            **kwargs)

        if not evaluate_on_ema and not evaluate_on_origin:
            warnings.warn(
                'Automatically set `evaluate_on_origin=True` since the '
                '`evaluate_on_ema` is disabled. If you want to disable '
                'all validation, please modify the `val_interval` of '
                'the `train_cfg`.', UserWarning)
            evaluate_on_origin = True

        self.evaluate_on_ema = evaluate_on_ema
        self.evaluate_on_origin = evaluate_on_origin
        self.load_ema_from_ckpt = False

    def before_train(self, runner) -> None:
        super().before_train(runner)
        if not runner._resume and self.load_ema_from_ckpt:
            # If loaded EMA state dict but not want to resume training
            # overwrite the EMA state dict with the source model.
            MMLogger.get_current_instance().info(
                'Load from a checkpoint with EMA parameters but not '
                'resume training. Initialize the model parameters with '
                'EMA parameters')
            for p_ema, p_src in zip(self._ema_params, self._src_params):
                p_src.data.copy_(p_ema.data)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after
        validation.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

        if self.evaluate_on_ema and self.evaluate_on_origin:
            # Re-evaluate if evaluate on both ema and origin.
            val_loop = runner.val_loop

            runner.model.eval()
            for idx, data_batch in enumerate(val_loop.dataloader):
                val_loop.run_iter(idx, data_batch)

            # compute metrics
            origin_metrics = val_loop.evaluator.evaluate(
                len(val_loop.dataloader.dataset))

            for k, v in origin_metrics.items():
                runner.message_hub.update_scalar(f'val/{k}_origin', v)

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        test.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()
            MMLogger.get_current_instance().info('Start testing on EMA model.')
        else:
            MMLogger.get_current_instance().info(
                'Start testing on the original model.')

    def after_test_epoch(self,
                         runner: Runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after test.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

        if self.evaluate_on_ema and self.evaluate_on_origin:
            # Re-evaluate if evaluate on both ema and origin.
            MMLogger.get_current_instance().info(
                'Start testing on the original model.')
            test_loop = runner.test_loop

            runner.model.eval()
            for idx, data_batch in enumerate(test_loop.dataloader):
                test_loop.run_iter(idx, data_batch)

            # compute metrics
            origin_metrics = test_loop.evaluator.evaluate(
                len(test_loop.dataloader.dataset))

            for k, v in origin_metrics.items():
                runner.message_hub.update_scalar(f'test/{k}_origin', v)

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        Args:
            runner (Runner): The runner of the testing process.
        """
        from mmengine.runner.checkpoint import load_state_dict
        if 'ema_state_dict' in checkpoint:
            # The original model parameters are actually saved in ema
            # field swap the weights back to resume ema state.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(
                checkpoint['ema_state_dict'], strict=self.strict_load)
            self.load_ema_from_ckpt = True

        # Support load checkpoint without ema state dict.
        else:
            load_state_dict(
                self.ema_model.module,
                copy.deepcopy(checkpoint['state_dict']),
                strict=self.strict_load)

    @property
    def _src_params(self):
        if self.ema_model.update_buffers:
            return itertools.chain(self.src_model.parameters(),
                                   self.src_model.buffers())
        else:
            return self.src_model.parameters()

    @property
    def _ema_params(self):
        if self.ema_model.update_buffers:
            return itertools.chain(self.ema_model.module.parameters(),
                                   self.ema_model.module.buffers())
        else:
            return self.ema_model.module.parameters()
```
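For intuition, the averaging and the parameter swap that `EMAHook` relies on (via mmengine's averaged-model machinery) can be sketched in plain PyTorch. This is an illustration of the idea, not the mmengine implementation; the update rule follows the common exponential-moving-average convention and the momentum value is an assumption.

```python
import torch


def ema_update(avg_params, src_params, momentum=2e-4):
    # Exponential moving average: pull the averaged copy slightly towards
    # the current source parameters at every update step.
    with torch.no_grad():
        for avg, src in zip(avg_params, src_params):
            avg.mul_(1 - momentum).add_(src, alpha=momentum)


def swap_params(params_a, params_b):
    # Swap two parameter sets in place. EMAHook performs an analogous swap
    # before/after val and test epochs so that evaluation runs on the EMA
    # weights while the trainable weights stay untouched.
    with torch.no_grad():
        for a, b in zip(params_a, params_b):
            tmp = a.detach().clone()
            a.copy_(b)
            b.copy_(tmp)
```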
tests/test_engine/test_hooks/test_ema_hook.py (new file, 223 lines)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
from collections import OrderedDict
from unittest import TestCase
from unittest.mock import ANY, MagicMock, call

import torch
import torch.nn as nn
from mmengine.evaluator import Evaluator
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.testing import assert_allclose
from torch.utils.data import Dataset

from mmcls.engine import EMAHook


class SimpleModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.para = nn.Parameter(torch.zeros(1))

    def forward(self, *args, mode='tensor', **kwargs):
        if mode == 'predict':
            return self.para.clone()
        elif mode == 'loss':
            return {'loss': self.para.mean()}


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(6, 2)
    label = torch.ones(6)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


class TestEMAHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        state_dict = OrderedDict(
            meta=dict(epoch=1, iter=2),
            # The actual ema para
            state_dict={'para': torch.tensor([1.])},
            # The actual original para
            ema_state_dict={'module.para': torch.tensor([2.])},
        )
        self.ckpt = osp.join(self.temp_dir.name, 'ema.pth')
        torch.save(state_dict, self.ckpt)

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_load_state_dict(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = SimpleModel().to(device)
        ema_hook = EMAHook()
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            work_dir=self.temp_dir.name,
            resume=False,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[ema_hook],
            default_scope='mmcls',
            experiment_name='load_state_dict')
        runner.train()
        assert_allclose(runner.model.para, torch.tensor([1.], device=device))

    def test_evaluate_on_ema(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = SimpleModel().to(device)

        # Test validate on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmcls',
            experiment_name='validate_on_ema')
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test test on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmcls',
            experiment_name='test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test validate on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmcls',
            experiment_name='validate_on_ema_false',
        )
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test test on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmcls',
            experiment_name='test_on_ema_false',
        )
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test evaluate_on_ema=False
        evaluator = Evaluator([MagicMock()])
        with self.assertWarnsRegex(UserWarning, 'evaluate_on_origin'):
            runner = Runner(
                model=model,
                test_dataloader=dict(
                    dataset=DummyDataset(),
                    sampler=dict(type='DefaultSampler', shuffle=False),
                    batch_size=3,
                    num_workers=0),
                test_evaluator=evaluator,
                test_cfg=dict(),
                work_dir=self.temp_dir.name,
                load_from=self.ckpt,
                default_hooks=dict(logger=None),
                custom_hooks=[dict(type='EMAHook', evaluate_on_ema=False)],
                default_scope='mmcls',
                experiment_name='not_test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([1.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)
```
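The new tests can be run on their own, for example as below (assuming `pytest` is installed in the environment; the invocation is illustrative):

```python
# Programmatic equivalent of `pytest -v tests/test_engine/test_hooks/test_ema_hook.py`.
import pytest

pytest.main(['-v', 'tests/test_engine/test_hooks/test_ema_hook.py'])
```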