From 063288240a181de3b2cd8cde9420dc5b515060a1 Mon Sep 17 00:00:00 2001 From: "fangyixiao.vendor" Date: Fri, 20 May 2022 06:26:32 +0000 Subject: [PATCH] [Refactor] refactor hooks --- mmselfsup/core/hooks/__init__.py | 6 +- mmselfsup/core/hooks/deepcluster_hook.py | 52 ++-- mmselfsup/core/hooks/densecl_hook.py | 12 +- mmselfsup/core/hooks/momentum_update_hook.py | 48 ---- mmselfsup/core/hooks/odc_hook.py | 24 +- mmselfsup/core/hooks/optimizer_hook.py | 261 ------------------- mmselfsup/core/hooks/simsiam_hook.py | 18 +- mmselfsup/core/hooks/swav_hook.py | 24 +- 8 files changed, 68 insertions(+), 377 deletions(-) delete mode 100644 mmselfsup/core/hooks/momentum_update_hook.py delete mode 100644 mmselfsup/core/hooks/optimizer_hook.py diff --git a/mmselfsup/core/hooks/__init__.py b/mmselfsup/core/hooks/__init__.py index 288b3e48..147d254f 100644 --- a/mmselfsup/core/hooks/__init__.py +++ b/mmselfsup/core/hooks/__init__.py @@ -1,14 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deepcluster_hook import DeepClusterHook from .densecl_hook import DenseCLHook -from .momentum_update_hook import MomentumUpdateHook from .odc_hook import ODCHook -from .optimizer_hook import DistOptimizerHook, GradAccumFp16OptimizerHook from .simsiam_hook import SimSiamHook from .swav_hook import SwAVHook __all__ = [ - 'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook', - 'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook', - 'SwAVHook' + 'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook' ] diff --git a/mmselfsup/core/hooks/deepcluster_hook.py b/mmselfsup/core/hooks/deepcluster_hook.py index a1321d0f..17c2d923 100644 --- a/mmselfsup/core/hooks/deepcluster_hook.py +++ b/mmselfsup/core/hooks/deepcluster_hook.py @@ -1,13 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + import numpy as np import torch import torch.distributed as dist -from mmcv.runner import HOOKS, Hook +from mmengine.hooks import Hook from mmengine.logging import print_log +from torch.utils.data import DataLoader +from mmselfsup.registry import HOOKS from mmselfsup.utils import Extractor from mmselfsup.utils import clustering as _clustering -from mmselfsup.utils import get_root_logger @HOOKS.register_module() @@ -32,33 +35,16 @@ class DeepClusterHook(Hook): def __init__( self, - extractor, - clustering, - unif_sampling, - reweight, - reweight_pow, - init_memory=False, # for ODC - initial=True, - interval=1, - dist_mode=True, - data_loaders=None): - - logger = get_root_logger() - if 'imgs_per_gpu' in extractor: - logger.warning('"imgs_per_gpu" is deprecated. ' - 'Please use "samples_per_gpu" instead') - if 'samples_per_gpu' in extractor: - logger.warning( - f'Got "imgs_per_gpu"={extractor["imgs_per_gpu"]} and ' - f'"samples_per_gpu"={extractor["samples_per_gpu"]}, ' - f'"imgs_per_gpu"={extractor["imgs_per_gpu"]} is used in ' - f'this experiments') - else: - logger.warning( - 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' - f'{extractor["imgs_per_gpu"]} in this experiments') - extractor['samples_per_gpu'] = extractor['imgs_per_gpu'] - + extractor: Dict, + clustering: Dict, + unif_sampling: bool, + reweight: bool, + reweight_pow: float, + init_memory: Optional[bool] = False, # for ODC + initial: Optional[bool] = True, + interval: Optional[int] = 1, + dist_mode: Optional[bool] = True, + data_loaders: Optional[DataLoader] = None) -> None: self.extractor = Extractor(dist_mode=dist_mode, **extractor) self.clustering_type = clustering.pop('type') self.clustering_cfg = clustering @@ -71,16 +57,16 @@ class DeepClusterHook(Hook): self.dist_mode = dist_mode self.data_loaders = data_loaders - def before_run(self, runner): + def before_run(self, runner) -> None: if self.initial: self.deepcluster(runner) - def after_train_epoch(self, runner): + def after_train_epoch(self, runner) -> None: if not self.every_n_epochs(runner, self.interval): return self.deepcluster(runner) - def deepcluster(self, runner): + def deepcluster(self, runner) -> None: # step 1: get features runner.model.eval() features = self.extractor(runner) @@ -130,7 +116,7 @@ class DeepClusterHook(Hook): if self.init_memory: runner.model.module.memory_bank.init_memory(features, new_labels) - def evaluate(self, runner, new_labels): + def evaluate(self, runner, new_labels: np.ndarray) -> None: histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k) empty_cls = (histogram == 0).sum() minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max() diff --git a/mmselfsup/core/hooks/densecl_hook.py b/mmselfsup/core/hooks/densecl_hook.py index c4ac75fe..05712fe7 100644 --- a/mmselfsup/core/hooks/densecl_hook.py +++ b/mmselfsup/core/hooks/densecl_hook.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmcv.runner import HOOKS, Hook +from typing import Optional + +from mmengine.hooks import Hook + +from mmselfsup.registry import HOOKS @HOOKS.register_module() @@ -14,15 +18,15 @@ class DenseCLHook(Hook): ``loss_lambda=0``. Defaults to 1000. """ - def __init__(self, start_iters=1000, **kwargs): + def __init__(self, start_iters: Optional[int] = 1000) -> None: self.start_iters = start_iters - def before_run(self, runner): + def before_run(self, runner) -> None: assert hasattr(runner.model.module, 'loss_lambda'), \ "The runner must have attribute \"loss_lambda\" in DenseCL." self.loss_lambda = runner.model.module.loss_lambda - def before_train_iter(self, runner): + def before_train_iter(self, runner) -> None: assert hasattr(runner.model.module, 'loss_lambda'), \ "The runner must have attribute \"loss_lambda\" in DenseCL." cur_iter = runner.iter diff --git a/mmselfsup/core/hooks/momentum_update_hook.py b/mmselfsup/core/hooks/momentum_update_hook.py deleted file mode 100644 index a211ec5f..00000000 --- a/mmselfsup/core/hooks/momentum_update_hook.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from math import cos, pi - -from mmcv.parallel import is_module_wrapper -from mmcv.runner import HOOKS, Hook - - -@HOOKS.register_module(name=['BYOLHook', 'MomentumUpdateHook']) -class MomentumUpdateHook(Hook): - """Hook for updating momentum parameter, used by BYOL, MoCoV3, etc. - - This hook includes momentum adjustment following: - - .. math:: - m = 1 - (1 - m_0) * (cos(pi * k / K) + 1) / 2 - - where :math:`k` is the current step, :math:`K` is the total steps. - - Args: - end_momentum (float): The final momentum coefficient - for the target network. Defaults to 1. - update_interval (int, optional): The momentum update interval of the - weights. Defaults to 1. - """ - - def __init__(self, end_momentum=1., update_interval=1, **kwargs): - self.end_momentum = end_momentum - self.update_interval = update_interval - - def before_train_iter(self, runner): - assert hasattr(runner.model.module, 'momentum'), \ - "The runner must have attribute \"momentum\" in algorithms." - assert hasattr(runner.model.module, 'base_momentum'), \ - "The runner must have attribute \"base_momentum\" in algorithms." - if self.every_n_iters(runner, self.update_interval): - cur_iter = runner.iter - max_iter = runner.max_iters - base_m = runner.model.module.base_momentum - m = self.end_momentum - (self.end_momentum - base_m) * ( - cos(pi * cur_iter / float(max_iter)) + 1) / 2 - runner.model.module.momentum = m - - def after_train_iter(self, runner): - if self.every_n_iters(runner, self.update_interval): - if is_module_wrapper(runner.model): - runner.model.module.momentum_update() - else: - runner.model.momentum_update() diff --git a/mmselfsup/core/hooks/odc_hook.py b/mmselfsup/core/hooks/odc_hook.py index 27b08429..a756bc95 100644 --- a/mmselfsup/core/hooks/odc_hook.py +++ b/mmselfsup/core/hooks/odc_hook.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import numpy as np -from mmcv.runner import HOOKS, Hook +from mmengine.hooks import Hook from mmengine.logging import print_log +from mmselfsup.registry import HOOKS + @HOOKS.register_module() class ODCHook(Hook): @@ -22,12 +26,12 @@ class ODCHook(Hook): """ def __init__(self, - centroids_update_interval, - deal_with_small_clusters_interval, - evaluate_interval, - reweight, - reweight_pow, - dist_mode=True): + centroids_update_interval: int, + deal_with_small_clusters_interval: int, + evaluate_interval: int, + reweight: bool, + reweight_pow: float, + dist_mode: Optional[bool] = True) -> None: assert dist_mode, 'non-dist mode is not implemented' self.centroids_update_interval = centroids_update_interval self.deal_with_small_clusters_interval = \ @@ -36,7 +40,7 @@ class ODCHook(Hook): self.reweight = reweight self.reweight_pow = reweight_pow - def after_train_iter(self, runner): + def after_train_iter(self, runner) -> None: # centroids update if self.every_n_iters(runner, self.centroids_update_interval): runner.model.module.memory_bank.update_centroids_memory() @@ -55,7 +59,7 @@ class ODCHook(Hook): new_labels = new_labels.cpu() self.evaluate(runner, new_labels.numpy()) - def after_train_epoch(self, runner): + def after_train_epoch(self, runner) -> None: # save cluster if self.every_n_epochs(runner, 10) and runner.rank == 0: new_labels = runner.model.module.memory_bank.label_bank @@ -64,7 +68,7 @@ class ODCHook(Hook): np.save(f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy', new_labels.numpy()) - def evaluate(self, runner, new_labels): + def evaluate(self, runner, new_labels: np.ndarray) -> None: histogram = np.bincount( new_labels, minlength=runner.model.module.memory_bank.num_classes) empty_cls = (histogram == 0).sum() diff --git a/mmselfsup/core/hooks/optimizer_hook.py b/mmselfsup/core/hooks/optimizer_hook.py deleted file mode 100644 index b0f4f757..00000000 --- a/mmselfsup/core/hooks/optimizer_hook.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook, - allreduce_grads) -from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version - - -@HOOKS.register_module() -class DistOptimizerHook(OptimizerHook): - """Optimizer hook for distributed training. - - This hook can accumulate gradients every n intervals and freeze some - layers for some iters at the beginning. - - Args: - update_interval (int, optional): The update interval of the weights, - set > 1 to accumulate the grad. Defaults to 1. - grad_clip (dict, optional): Dict to config the value of grad clip. - E.g., grad_clip = dict(max_norm=10). Defaults to None. - coalesce (bool, optional): Whether allreduce parameters as a whole. - Defaults to True. - bucket_size_mb (int, optional): Size of bucket, the unit is MB. - Defaults to -1. - frozen_layers_cfg (dict, optional): Dict to config frozen layers. - The key-value pair is layer name and its frozen iters. If frozen, - the layer gradient would be set to None. Defaults to dict(). - """ - - def __init__(self, - update_interval=1, - grad_clip=None, - coalesce=True, - bucket_size_mb=-1, - frozen_layers_cfg=dict()): - self.grad_clip = grad_clip - self.coalesce = coalesce - self.bucket_size_mb = bucket_size_mb - self.update_interval = update_interval - self.frozen_layers_cfg = frozen_layers_cfg - self.initialized = False - - def has_batch_norm(self, module): - if isinstance(module, _BatchNorm): - return True - for m in module.children(): - if self.has_batch_norm(m): - return True - return False - - def _init(self, runner): - if runner.iter % self.update_interval != 0: - runner.logger.warning( - 'Resume iter number is not divisible by update_interval in ' - 'GradientCumulativeOptimizerHook, which means the gradient of ' - 'some iters is lost and the result may be influenced slightly.' - ) - - if self.has_batch_norm(runner.model) and self.update_interval > 1: - runner.logger.warning( - 'GradientCumulativeOptimizerHook may slightly decrease ' - 'performance if the model has BatchNorm layers.') - - residual_iters = runner.max_iters - - self.divisible_iters = ( - residual_iters // self.update_interval * self.update_interval) - self.remainder_iters = residual_iters - self.divisible_iters - - self.initialized = True - - def before_run(self, runner): - runner.optimizer.zero_grad() - - def after_train_iter(self, runner): - # In some cases, MMCV's GradientCumulativeOptimizerHook will - # cause the loss_factor to be zero and we fix this bug in our - # implementation. - - if not self.initialized: - self._init(runner) - - if runner.iter < self.divisible_iters: - loss_factor = self.update_interval - else: - loss_factor = self.remainder_iters - - runner.outputs['loss'] /= loss_factor - runner.outputs['loss'].backward() - - if (self.every_n_iters(runner, self.update_interval) - or self.is_last_iter(runner)): - - # cancel gradient of certain layer for n iters - # according to frozen_layers_cfg dict - for layer, iters in self.frozen_layers_cfg.items(): - if runner.iter < iters: - for name, p in runner.model.module.named_parameters(): - if layer in name: - p.grad = None - - if self.grad_clip is not None: - grad_norm = self.clip_grads(runner.model.parameters()) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update({'grad_norm': float(grad_norm)}, - runner.outputs['num_samples']) - - runner.optimizer.step() - runner.optimizer.zero_grad() - - -if (TORCH_VERSION != 'parrots' - and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): - - @HOOKS.register_module() - class GradAccumFp16OptimizerHook(Fp16OptimizerHook): - """Fp16 optimizer hook (using PyTorch's implementation). - - This hook can accumulate gradients every n intervals and freeze some - layers for some iters at the beginning. - If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, - to take care of the optimization procedure. - - Args: - update_interval (int, optional): The update interval of the - weights, set > 1 to accumulate the grad. Defaults to 1. - frozen_layers_cfg (dict, optional): Dict to config frozen layers. - The key-value pair is layer name and its frozen iters. If - frozen, the layer gradient would be set to None. - Defaults to dict(). - """ - - def __init__(self, - update_interval=1, - frozen_layers_cfg=dict(), - **kwargs): - super(GradAccumFp16OptimizerHook, self).__init__(**kwargs) - self.update_interval = update_interval - self.frozen_layers_cfg = frozen_layers_cfg - - def after_train_iter(self, runner): - runner.outputs['loss'] /= self.update_interval - self.loss_scaler.scale(runner.outputs['loss']).backward() - - if self.every_n_iters(runner, self.update_interval): - - # cancel gradient of certain layer for n iters - # according to frozen_layers_cfg dict - for layer, iters in self.frozen_layers_cfg.items(): - if runner.iter < iters: - for name, p in runner.model.module.named_parameters(): - if layer in name: - p.grad = None - - # copy fp16 grads in the model to fp32 params in the optimizer - self.loss_scaler.unscale_(runner.optimizer) - - if self.grad_clip is not None: - grad_norm = self.clip_grads(runner.model.parameters()) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update( - {'grad_norm': float(grad_norm)}, - runner.outputs['num_samples']) - - # backward and update scaler - self.loss_scaler.step(runner.optimizer) - self.loss_scaler.update(self._scale_update_param) - - # save state_dict of loss_scaler - runner.meta.setdefault( - 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() - - # clear grads - runner.model.zero_grad() - runner.optimizer.zero_grad() - -else: - - @HOOKS.register_module() - class GradAccumFp16OptimizerHook(Fp16OptimizerHook): - """Fp16 optimizer hook (using mmcv's implementation). - - This hook can accumulate gradients every n intervals and freeze some - layers for some iters at the beginning. - - Args: - update_interval (int, optional): The update interval of the - weights, set > 1 to accumulate the grad. Defaults to 1. - frozen_layers_cfg (dict, optional): Dict to config frozen layers. - The key-value pair is layer name and its frozen iters. If - frozen, the layer gradient would be set to None. - Defaults to dict(). - """ - - def __init__(self, - update_interval=1, - frozen_layers_cfg=dict(), - **kwargs): - super(GradAccumFp16OptimizerHook, self).__init__(**kwargs) - self.update_interval = update_interval - self.frozen_layers_cfg = frozen_layers_cfg - - def after_train_iter(self, runner): - runner.outputs['loss'] /= self.update_interval - - # scale the loss value - scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale - scaled_loss.backward() - - if self.every_n_iters(runner, self.update_interval): - - # cancel gradient of certain layer for n iters - # according to frozen_layers_cfg dict - for layer, iters in self.frozen_layers_cfg.items(): - if runner.iter < iters: - for name, p in runner.model.module.named_parameters(): - if layer in name: - p.grad = None - - # copy fp16 grads in the model to fp32 params in the optimizer - fp32_weights = [] - for param_group in runner.optimizer.param_groups: - fp32_weights += param_group['params'] - self.copy_grads_to_fp32(runner.model, fp32_weights) - # allreduce grads - if self.distributed: - allreduce_grads(fp32_weights, self.coalesce, - self.bucket_size_mb) - - has_overflow = self.loss_scaler.has_overflow(fp32_weights) - # if has overflow, skip this iteration - if not has_overflow: - # scale the gradients back - for param in fp32_weights: - if param.grad is not None: - param.grad.div_(self.loss_scaler.loss_scale) - if self.grad_clip is not None: - grad_norm = self.clip_grads(fp32_weights) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update( - {'grad_norm': float(grad_norm)}, - runner.outputs['num_samples']) - # update fp32 params - runner.optimizer.step() - # copy fp32 params to the fp16 model - self.copy_params_to_fp16(runner.model, fp32_weights) - else: - runner.logger.warning( - 'Check overflow, downscale loss scale ' - f'to {self.loss_scaler.cur_scale}') - - self.loss_scaler.update_scale(has_overflow) - - # save state_dict of loss_scaler - runner.meta.setdefault( - 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() - - # clear grads - runner.model.zero_grad() - runner.optimizer.zero_grad() diff --git a/mmselfsup/core/hooks/simsiam_hook.py b/mmselfsup/core/hooks/simsiam_hook.py index 0aa07292..0f407b56 100644 --- a/mmselfsup/core/hooks/simsiam_hook.py +++ b/mmselfsup/core/hooks/simsiam_hook.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmcv.runner import HOOKS, Hook +from typing import Optional + +from mmengine.hooks import Hook + +from mmselfsup.registry import HOOKS @HOOKS.register_module() @@ -15,12 +19,16 @@ class SimSiamHook(Hook): Defaults to True. """ - def __init__(self, fix_pred_lr, lr, adjust_by_epoch=True, **kwargs): + def __init__(self, + fix_pred_lr: bool, + lr: float, + adjust_by_epoch: Optional[bool] = True) -> None: self.fix_pred_lr = fix_pred_lr self.lr = lr self.adjust_by_epoch = adjust_by_epoch - def before_train_iter(self, runner): + def before_train_iter(self, runner) -> None: + """fix lr of predictor by iter.""" if self.adjust_by_epoch: return else: @@ -29,8 +37,8 @@ class SimSiamHook(Hook): if 'fix_lr' in param_group and param_group['fix_lr']: param_group['lr'] = self.lr - def before_train_epoch(self, runner): - """fix lr of predictor.""" + def before_train_epoch(self, runner) -> None: + """fix lr of predictor by epoch.""" if self.fix_pred_lr: for param_group in runner.optimizer.param_groups: if 'fix_lr' in param_group and param_group['fix_lr']: diff --git a/mmselfsup/core/hooks/swav_hook.py b/mmselfsup/core/hooks/swav_hook.py index 821e6348..91c89378 100644 --- a/mmselfsup/core/hooks/swav_hook.py +++ b/mmselfsup/core/hooks/swav_hook.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +from typing import List, Optional import torch import torch.distributed as dist -from mmcv.runner import HOOKS, Hook +from mmengine.hooks import Hook + +from mmselfsup.registry import HOOKS @HOOKS.register_module() @@ -29,13 +32,12 @@ class SwAVHook(Hook): """ def __init__(self, - batch_size, - epoch_queue_starts=15, - crops_for_assign=[0, 1], - feat_dim=128, - queue_length=0, - interval=1, - **kwargs): + batch_size: int, + epoch_queue_starts: Optional[int] = 15, + crops_for_assign: Optional[List[int]] = [0, 1], + feat_dim: Optional[int] = 128, + queue_length: Optional[int] = 0, + interval: Optional[int] = 1): self.batch_size = batch_size * dist.get_world_size()\ if dist.is_initialized() else batch_size self.epoch_queue_starts = epoch_queue_starts @@ -45,7 +47,7 @@ class SwAVHook(Hook): self.interval = interval self.queue = None - def before_run(self, runner): + def before_run(self, runner) -> None: if dist.is_initialized(): self.queue_path = osp.join(runner.work_dir, 'queue' + str(dist.get_rank()) + '.pth') @@ -58,7 +60,7 @@ class SwAVHook(Hook): # the queue needs to be divisible by the batch size self.queue_length -= self.queue_length % self.batch_size - def before_train_epoch(self, runner): + def before_train_epoch(self, runner) -> None: # optionally starts a queue if self.queue_length > 0 \ and runner.epoch >= self.epoch_queue_starts \ @@ -73,7 +75,7 @@ class SwAVHook(Hook): runner.model.module.head.queue = self.queue runner.model.module.head.use_queue = False - def after_train_epoch(self, runner): + def after_train_epoch(self, runner) -> None: self.queue = runner.model.module.head.queue if self.queue is not None and self.every_n_epochs(