# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmcls.models.utils import RandomBatchAugment
from mmcls.registry import HOOKS, MODEL_WRAPPERS, MODELS

@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """Switch training recipes during the training loop, currently including
    the train pipeline, batch augmentations and loss.

    Args:
        schedule (list): Every item of the schedule list should be a dict,
            and the dict should have ``action_epoch`` and some of the
            ``train_pipeline``, ``batch_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): The epoch at which to switch to this
              training recipe.
            - ``train_pipeline`` (list, optional): The new data pipeline of
              the train dataset. If not specified, keep the original
              settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations used during training. See :mod:`Batch
              Augmentations <mmcls.models.utils.batch_augments>` for more
              details. If None, disable batch augmentations. If not
              specified, keep the original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files:

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            recipe = deepcopy(recipe)
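            # Build each component eagerly (pipeline, batch augments, loss)
            # so that invalid configs fail when the hook is constructed
            # rather than in the middle of training.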
            if 'train_pipeline' in recipe:
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
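        # Sort the recipes by `action_epoch` so they can be applied (and
        # replayed on resume) in chronological order.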
        self.schedule = OrderedDict(sorted(recipes.items()))

    def before_train(self, runner) -> None:
        """Setting before the whole training process. If resuming from a
        checkpoint, apply all the switches scheduled before the current
        epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
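        # The schedule is sorted by epoch, so replay every recipe whose
        # switch point has already passed and stop at the first future one.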
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """Switch the training recipe if a switch is scheduled at this
        epoch."""
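        # `runner.epoch` is the number of finished epochs at this point,
        # so the upcoming epoch is `runner.epoch + 1`.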
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """Do all the switches defined in the specified recipe."""
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """Switch the batch augmentations of the model."""
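        # Unwrap distributed model wrappers (e.g. DDP) to reach the
        # data preprocessor of the actual model.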
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """Switch the pipeline of the dataset of the train loader."""

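        # Dataset wrappers (concat, repeat, etc.) are traversed recursively
        # until a dataset exposing a `pipeline` attribute is found.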
        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual datasets
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrappers
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)

        # Restart the dataloader iterator so that the new pipeline takes
        # effect even when `persistent_workers=True`.
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """Switch the loss module of the model."""
        model = runner.model
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')