EasyCV/easycv/hooks/yolox_mode_switch_hook.py

43 lines
1.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import Hook
from .registry import HOOKS
@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX for the final epochs of training.

    Once training reaches the last ``no_aug_epochs`` epochs, this hook asks
    the training dataset to skip the pipeline transforms named in
    ``skip_type_keys`` (by default the mosaic, random-affine and mixup
    augmentations) and enables the additional L1 loss by setting
    ``model.head.use_l1 = True``.

    The switch fires once, at the first epoch for which
    ``epoch + 1 >= max_epochs - no_aug_epochs``. In uninterrupted training
    this is the same epoch an exact-equality check would pick. It is also
    applied when training is resumed from a checkpoint taken after that point.

    Args:
        no_aug_epochs (int): The number of final training epochs during which
            data augmentation is disabled and the L1 loss is used.
            Default: 15.
        skip_type_keys (Sequence[str]): Type names of the pipeline transforms
            to skip once the switch happens.
            Default: ('MMMosaic', 'MMRandomAffine', 'MMMixUp').
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self,
                 no_aug_epochs=15,
                 skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
                 **kwargs):
        super(YOLOXModeSwitchHook, self).__init__()
        self.no_aug_epochs = no_aug_epochs
        self.skip_type_keys = skip_type_keys
        # The trigger condition stays true for every epoch after the switch
        # point, so this flag makes the switch happen exactly once.
        self._has_switched = False

    def before_train_epoch(self, runner):
        """Close mosaic and mixup augmentation and switch to L1 loss.

        Args:
            runner: The training runner. This method reads ``runner.epoch``,
                ``runner.max_epochs``, ``runner.data_loader``,
                ``runner.model`` and ``runner.logger``.
        """
        if self._has_switched:
            return
        # ``>=`` rather than ``==`` so that a run resumed past the switch
        # epoch still turns the augmentations off and the L1 loss on.
        if runner.epoch + 1 < runner.max_epochs - self.no_aug_epochs:
            return

        model = runner.model
        # Unwrap the model (as detected by mmcv's is_module_wrapper) so the
        # ``head`` attribute of the underlying module is reachable.
        if is_module_wrapper(model):
            model = model.module

        runner.logger.info('No mosaic and mixup aug now!')
        # NOTE(review): if the dataloader keeps persistent worker processes,
        # their copies of the dataset may not observe this update — confirm.
        runner.data_loader.dataset.update_skip_type_keys(self.skip_type_keys)
        runner.logger.info('Add additional L1 loss now!')
        model.head.use_l1 = True
        self._has_switched = True