mirror of https://github.com/alibaba/EasyCV.git
43 lines
1.6 KiB
Python
43 lines
1.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from mmcv.parallel import is_module_wrapper
|
|
from mmcv.runner.hooks import Hook
|
|
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module()
|
|
class YOLOXModeSwitchHook(Hook):
|
|
"""Switch the mode of YOLOX during training.
|
|
|
|
This hook turns off the mosaic and mixup data augmentation and switches
|
|
to use L1 loss in bbox_head.
|
|
|
|
Args:
|
|
no_aug_epochs (int): The number of latter epochs in the end of the
|
|
training to close the data augmentation and switch to L1 loss.
|
|
Default: 15.
|
|
skip_type_keys (list[str], optional): Sequence of type string to be
|
|
skip pipeline. Default: ('Mosaic', 'RandomAffine', 'MixUp')
|
|
"""
|
|
|
|
def __init__(self,
|
|
no_aug_epochs=15,
|
|
skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
|
|
**kwargs):
|
|
super(YOLOXModeSwitchHook, self).__init__()
|
|
self.no_aug_epochs = no_aug_epochs
|
|
self.skip_type_keys = skip_type_keys
|
|
|
|
def before_train_epoch(self, runner):
|
|
"""Close mosaic and mixup augmentation and switches to use L1 loss."""
|
|
epoch = runner.epoch
|
|
train_loader = runner.data_loader
|
|
model = runner.model
|
|
if is_module_wrapper(model):
|
|
model = model.module
|
|
if (epoch + 1) == runner.max_epochs - self.no_aug_epochs:
|
|
runner.logger.info('No mosaic and mixup aug now!')
|
|
train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
|
|
runner.logger.info('Add additional L1 loss now!')
|
|
model.head.use_l1 = True
|