diff --git a/mmyolo/engine/__init__.py b/mmyolo/engine/__init__.py
new file mode 100644
index 00000000..b2e0a126
--- /dev/null
+++ b/mmyolo/engine/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hooks import *  # noqa: F401,F403
+from .optimizers import *  # noqa: F401,F403
diff --git a/mmyolo/engine/hooks/__init__.py b/mmyolo/engine/hooks/__init__.py
new file mode 100644
index 00000000..d338af42
--- /dev/null
+++ b/mmyolo/engine/hooks/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
+from .yolox_mode_switch_hook import YOLOXModeSwitchHook
+
+__all__ = ['YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook']
diff --git a/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
new file mode 100644
index 00000000..60d91a6d
--- /dev/null
+++ b/mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional
+
+import numpy as np
+from mmengine.hooks import ParamSchedulerHook
+from mmengine.runner import Runner
+
+from mmyolo.registry import HOOKS
+
+
+def linear_fn(lr_factor, max_epochs):
+    return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor
+
+
+def cosine_fn(lr_factor, max_epochs):
+    return lambda x: (
+        (1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1
+
+
+@HOOKS.register_module()
+class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
+    """A hook to update learning rate and momentum in optimizer of YOLOv5."""
+    priority = 9
+
+    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}
+
+    def __init__(self,
+                 scheduler_type: str = 'linear',
+                 lr_factor: float = 0.01,
+                 max_epochs: int = 300,
+                 warmup_epochs: int = 3,
+                 warmup_bias_lr: float = 0.1,
+                 warmup_momentum: float = 0.8,
+                 warmup_mim_iter: int = 1000,
+                 **kwargs) -> None:
+
+        assert scheduler_type in self.scheduler_maps
+
+        self.warmup_epochs = warmup_epochs
+        self.warmup_bias_lr = warmup_bias_lr
+        self.warmup_momentum = warmup_momentum
+        self.warmup_mim_iter = warmup_mim_iter
+
+        kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
+        self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)
+
+        self._warmup_end = False
+        self._base_lr = None
+        self._base_momentum = None
+
+    def before_train(self, runner: Runner) -> None:
+        optimizer = runner.optim_wrapper.optimizer
+        for group in optimizer.param_groups:
+            # If the param is never scheduled, record the current value
+            # as the initial value.
+            group.setdefault('initial_lr', group['lr'])
+            group.setdefault('initial_momentum', group.get('momentum', -1))
+
+        self._base_lr = [
+            group['initial_lr'] for group in optimizer.param_groups
+        ]
+        self._base_momentum = [
+            group['initial_momentum'] for group in optimizer.param_groups
+        ]
+
+    def before_train_iter(self,
+                          runner: Runner,
+                          batch_idx: int,
+                          data_batch: Optional[dict] = None) -> None:
+        cur_iters = runner.iter
+        cur_epoch = runner.epoch
+        optimizer = runner.optim_wrapper.optimizer
+
+        # The minimum warmup is self.warmup_mim_iter
+        warmup_total_iters = max(
+            round(self.warmup_epochs * len(runner.train_dataloader)),
+            self.warmup_mim_iter)
+
+        if cur_iters <= warmup_total_iters:
+            xp = [0, warmup_total_iters]
+            for group_idx, param in enumerate(optimizer.param_groups):
+                if group_idx == 2:
+                    # bias learning rate will be handled specially
+                    yp = [
+                        self.warmup_bias_lr,
+                        self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
+                    ]
+                else:
+                    yp = [
+                        0.0,
+                        self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
+                    ]
+                param['lr'] = np.interp(cur_iters, xp, yp)
+
+                if 'momentum' in param:
+                    param['momentum'] = np.interp(
+                        cur_iters, xp,
+                        [self.warmup_momentum, self._base_momentum[group_idx]])
+        else:
+            self._warmup_end = True
+
+    def after_train_epoch(self, runner: Runner) -> None:
+        if not self._warmup_end:
+            return
+
+        cur_epoch = runner.epoch
+        optimizer = runner.optim_wrapper.optimizer
+        for group_idx, param in enumerate(optimizer.param_groups):
+            param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
+                cur_epoch)
diff --git a/mmyolo/engine/hooks/yolox_mode_switch_hook.py b/mmyolo/engine/hooks/yolox_mode_switch_hook.py
new file mode 100644
index 00000000..a5ea25dd
--- /dev/null
+++ b/mmyolo/engine/hooks/yolox_mode_switch_hook.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+from mmengine.runner import Runner
+
+from mmyolo.registry import HOOKS
+
+
+@HOOKS.register_module()
+class YOLOXModeSwitchHook(Hook):
+    """Switch the mode of YOLOX during training.
+
+    This hook turns off the mosaic and mixup data augmentation and switches
+    to use L1 loss in bbox_head.
+
+    Args:
+        num_last_epochs (int): The number of final training epochs during
+            which the mosaic and mixup augmentation is turned off and the
+            L1 loss is switched on. Defaults to 15.
+    """
+
+    def __init__(self,
+                 num_last_epochs: int = 15,
+                 new_train_pipeline=None) -> None:
+        self.num_last_epochs = num_last_epochs
+        self.new_train_pipeline_cfg = new_train_pipeline
+
+    def before_train_epoch(self, runner) -> None:
+        """Close mosaic and mixup augmentation and switch to use L1 loss."""
+        epoch = runner.epoch
+        model = runner.model
+        if is_model_wrapper(model):
+            model = model.module
+
+        if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
+            runner.logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}')
+
+            train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
+            train_dataloader_cfg.dataset.pipeline = self.new_train_pipeline_cfg
+            # Note: Why rebuild the dataset?
+            # Since build_dataloader makes a deep copy of the dataset, it
+            # can lead to potential risks, such as the data of the global
+            # FileClient instance becoming disordered.
+            # This problem needs to be solved in the future.
+            new_train_dataloader = Runner.build_dataloader(
+                train_dataloader_cfg)
+            runner.train_loop.dataloader = new_train_dataloader
+
+            runner.logger.info('recreate the dataloader!')
+            runner.logger.info('Add additional bbox reg loss now!')
+            model.bbox_head.use_bbox_aux = True
diff --git a/mmyolo/engine/optimizers/__init__.py b/mmyolo/engine/optimizers/__init__.py
new file mode 100644
index 00000000..3ad91894
--- /dev/null
+++ b/mmyolo/engine/optimizers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
+
+__all__ = ['YOLOv5OptimizerConstructor']
diff --git a/mmyolo/engine/optimizers/yolov5_optim_constructor.py b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
new file mode 100644
index 00000000..bbb77f13
--- /dev/null
+++ b/mmyolo/engine/optimizers/yolov5_optim_constructor.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch.nn as nn
+from mmengine.dist import get_world_size
+from mmengine.logging import print_log
+from mmengine.model import is_model_wrapper
+from mmengine.optim import OptimWrapper
+
+from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
+                             OPTIMIZERS)
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class YOLOv5OptimizerConstructor:
+    """YOLOv5 constructor for optimizers.
+
+    It has the following functions:
+
+        - divides the optimizer parameters into 3 groups:
+          Conv, Bias and BN
+
+        - supports `weight_decay` adaptation based on
+          `batch_size_pre_gpu`
+
+    Args:
+        optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
+            Positional fields are
+
+                - ``type``: class name of the OptimizerWrapper
+                - ``optimizer``: The configuration of optimizer.
+
+            Optional fields are
+
+                - any arguments of the corresponding optimizer wrapper type,
+                  e.g., accumulative_counts, clip_grad, etc.
+
+            The positional fields of ``optimizer`` are
+
+                - `type`: class name of the optimizer.
+
+            Optional fields are
+
+                - any arguments of the corresponding optimizer type, e.g.,
+                  lr, weight_decay, momentum, etc.
+
+        paramwise_cfg (dict, optional): Parameter-wise options. Must include
+            `base_total_batch_size` if not None. If the total input batch
+            is smaller than `base_total_batch_size`, the `weight_decay`
+            parameter is kept unchanged; otherwise it is scaled linearly.
+
+    Example:
+        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+        >>> optim_wrapper_cfg = dict(
+        >>>     type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
+        >>>     momentum=0.9, weight_decay=0.0001, batch_size_pre_gpu=16))
+        >>> paramwise_cfg = dict(base_total_batch_size=64)
+        >>> optim_wrapper_builder = YOLOv5OptimizerConstructor(
+        >>>     optim_wrapper_cfg, paramwise_cfg)
+        >>> optim_wrapper = optim_wrapper_builder(model)
+    """
+
+    def __init__(self,
+                 optim_wrapper_cfg: dict,
+                 paramwise_cfg: Optional[dict] = None) -> None:
+        if paramwise_cfg is None:
+            paramwise_cfg = {'base_total_batch_size': 64}
+        assert 'base_total_batch_size' in paramwise_cfg
+
+        if not isinstance(optim_wrapper_cfg, dict):
+            raise TypeError('optim_wrapper_cfg should be a dict, '
+                            f'but got {type(optim_wrapper_cfg)}')
+        assert 'optimizer' in optim_wrapper_cfg, (
+            '`optim_wrapper_cfg` must contain "optimizer" config')
+
+        self.optim_wrapper_cfg = optim_wrapper_cfg
+        self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
+        self.base_total_batch_size = paramwise_cfg['base_total_batch_size']
+
+    def __call__(self, model: nn.Module) -> OptimWrapper:
+        if is_model_wrapper(model):
+            model = model.module
+        optimizer_cfg = self.optimizer_cfg.copy()
+        weight_decay = optimizer_cfg.pop('weight_decay', 0)
+
+        if 'batch_size_pre_gpu' in optimizer_cfg:
+            batch_size_pre_gpu = optimizer_cfg.pop('batch_size_pre_gpu')
+            # No scaling if total_batch_size is less than
+            # base_total_batch_size, otherwise linear scaling.
+            total_batch_size = get_world_size() * batch_size_pre_gpu
+            accumulate = max(
+                round(self.base_total_batch_size / total_batch_size), 1)
+            scale_factor = total_batch_size * \
+                accumulate / self.base_total_batch_size
+
+            if scale_factor != 1:
+                weight_decay *= scale_factor
+                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
+
+        params_groups = [], [], []
+
+        for v in model.modules():
+            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+                params_groups[2].append(v.bias)
+            # Includes SyncBatchNorm
+            if isinstance(v, nn.modules.batchnorm._NormBase):
+                params_groups[1].append(v.weight)
+            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+                params_groups[0].append(v.weight)
+
+        # Note: Make sure bias is in the last parameter group
+        optimizer_cfg['params'] = []
+        # conv
+        optimizer_cfg['params'].append({
+            'params': params_groups[0],
+            'weight_decay': weight_decay
+        })
+        # bn
+        optimizer_cfg['params'].append({'params': params_groups[1]})
+        # bias
+        optimizer_cfg['params'].append({'params': params_groups[2]})
+
+        del params_groups
+
+        optimizer = OPTIMIZERS.build(optimizer_cfg)
+        optim_wrapper = OPTIM_WRAPPERS.build(
+            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
+        return optim_wrapper
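
Reviewer note (not part of the patch): a minimal sketch of how the new optimizer constructor and parameter scheduler hook would typically be wired into an mmengine-style config. The field names follow the code added above; the concrete hyperparameter values (lr, momentum, batch size, epochs) are illustrative placeholders rather than values taken from this patch.

# Illustrative config sketch only; values are placeholders.
optim_wrapper = dict(
    type='OptimWrapper',
    constructor='YOLOv5OptimizerConstructor',
    optimizer=dict(
        type='SGD',
        lr=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        nesterov=True,
        # Triggers the weight_decay scaling branch in __call__.
        batch_size_pre_gpu=16))

default_hooks = dict(
    # Replaces the default ParamSchedulerHook with the YOLOv5-style one.
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=0.01,
        max_epochs=300))

Since the scheduler function is evaluated against `max_epochs`, that value should match the length of the training schedule configured for the run.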
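
Reviewer note (not part of the patch): a small standalone check of the weight_decay scaling arithmetic in YOLOv5OptimizerConstructor.__call__, assuming 8 GPUs with batch_size_pre_gpu=16 and the default base_total_batch_size=64; the numbers are only an example.

# Reproduces the scaling rule from __call__ with example numbers.
base_total_batch_size = 64
world_size = 8            # assumed number of GPUs
batch_size_pre_gpu = 16   # field name as used in this patch
weight_decay = 0.0005

total_batch_size = world_size * batch_size_pre_gpu                    # 128
accumulate = max(round(base_total_batch_size / total_batch_size), 1)  # 1
scale_factor = total_batch_size * accumulate / base_total_batch_size  # 2.0
if scale_factor != 1:
    weight_decay *= scale_factor
print(weight_decay)  # 0.001 -> weight decay doubled for the larger batch

With a total batch of 16 (single GPU), accumulate becomes 4 and scale_factor 1.0, so weight_decay stays unchanged, matching the docstring's claim that batches smaller than `base_total_batch_size` are not scaled.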
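
Reviewer note (not part of the patch): YOLOXModeSwitchHook is typically registered through custom_hooks. A minimal sketch, assuming the config defines a weaker pipeline (without Mosaic/MixUp) under the hypothetical name train_pipeline_stage2; both that name and the priority value are illustrative.

# Illustrative registration of the mode switch hook.
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=15,
        # Placeholder for a pipeline without Mosaic/MixUp defined
        # elsewhere in the config.
        new_train_pipeline=train_pipeline_stage2,
        priority=48)
]

The priority shown here is only meant to make the hook run before NORMAL-priority hooks (lower values run earlier in mmengine); pick whatever ordering the surrounding config requires.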