[Refactor] refactor hooks

Repository: https://github.com/open-mmlab/mmselfsup.git
Parent commit: 9565da23a4
This commit:   063288240a
__init__.py of the hooks package:

@@ -1,14 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .deepcluster_hook import DeepClusterHook
 from .densecl_hook import DenseCLHook
-from .momentum_update_hook import MomentumUpdateHook
 from .odc_hook import ODCHook
-from .optimizer_hook import DistOptimizerHook, GradAccumFp16OptimizerHook
 from .simsiam_hook import SimSiamHook
 from .swav_hook import SwAVHook

 __all__ = [
-    'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook',
-    'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook',
-    'SwAVHook'
+    'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook'
 ]
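The hooks kept in this package are still enabled through a config's custom_hooks
list. A minimal sketch of that usage, assuming the usual OpenMMLab convention;
the lr value for SimSiamHook is a made-up placeholder and not part of this commit:

# config fragment, for illustration only
custom_hooks = [
    dict(type='DenseCLHook', start_iters=1000),
    dict(type='SimSiamHook', fix_pred_lr=True, lr=0.05),
]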
deepcluster_hook.py:

@@ -1,13 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
+
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.runner import HOOKS, Hook
+from mmengine.hooks import Hook
 from mmengine.logging import print_log
+from torch.utils.data import DataLoader

+from mmselfsup.registry import HOOKS
 from mmselfsup.utils import Extractor
 from mmselfsup.utils import clustering as _clustering
-from mmselfsup.utils import get_root_logger


 @HOOKS.register_module()
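The same import migration (mmcv.runner to mmengine.hooks plus the project-level
registry) repeats in every hook below. A minimal sketch of the post-refactor
pattern, assuming this convention; the hook itself is hypothetical and not part
of the commit:

from mmengine.hooks import Hook

from mmselfsup.registry import HOOKS


@HOOKS.register_module()
class ToyHook(Hook):
    """Hypothetical hook that only shows the new base class and registry."""

    def before_run(self, runner) -> None:
        runner.logger.info('ToyHook was built from the mmselfsup HOOKS registry.')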
@@ -32,33 +35,16 @@ class DeepClusterHook(Hook):

     def __init__(
             self,
-            extractor,
-            clustering,
-            unif_sampling,
-            reweight,
-            reweight_pow,
-            init_memory=False,  # for ODC
-            initial=True,
-            interval=1,
-            dist_mode=True,
-            data_loaders=None):
+            extractor: Dict,
+            clustering: Dict,
+            unif_sampling: bool,
+            reweight: bool,
+            reweight_pow: float,
+            init_memory: Optional[bool] = False,  # for ODC
+            initial: Optional[bool] = True,
+            interval: Optional[int] = 1,
+            dist_mode: Optional[bool] = True,
+            data_loaders: Optional[DataLoader] = None) -> None:
-
-        logger = get_root_logger()
-        if 'imgs_per_gpu' in extractor:
-            logger.warning('"imgs_per_gpu" is deprecated. '
-                           'Please use "samples_per_gpu" instead')
-            if 'samples_per_gpu' in extractor:
-                logger.warning(
-                    f'Got "imgs_per_gpu"={extractor["imgs_per_gpu"]} and '
-                    f'"samples_per_gpu"={extractor["samples_per_gpu"]}, '
-                    f'"imgs_per_gpu"={extractor["imgs_per_gpu"]} is used in '
-                    f'this experiments')
-            else:
-                logger.warning(
-                    'Automatically set "samples_per_gpu"="imgs_per_gpu"='
-                    f'{extractor["imgs_per_gpu"]} in this experiments')
-            extractor['samples_per_gpu'] = extractor['imgs_per_gpu']
-
         self.extractor = Extractor(dist_mode=dist_mode, **extractor)
         self.clustering_type = clustering.pop('type')
         self.clustering_cfg = clustering
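For reference, a hook config matching the newly typed signature might look like
the sketch below; every value is a made-up placeholder, only the argument names
and types come from the diff above:

deepcluster_hook = dict(
    type='DeepClusterHook',
    extractor=dict(),  # Dict forwarded to mmselfsup.utils.Extractor
    clustering=dict(type='Kmeans', k=10000),  # 'type' is popped, the rest is cfg
    unif_sampling=True,
    reweight=False,
    reweight_pow=0.5,
    init_memory=False,  # set True only when combined with ODC
    initial=True,
    interval=1,
    dist_mode=True)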
@@ -71,16 +57,16 @@ class DeepClusterHook(Hook):
         self.dist_mode = dist_mode
         self.data_loaders = data_loaders

-    def before_run(self, runner):
+    def before_run(self, runner) -> None:
         if self.initial:
             self.deepcluster(runner)

-    def after_train_epoch(self, runner):
+    def after_train_epoch(self, runner) -> None:
         if not self.every_n_epochs(runner, self.interval):
             return
         self.deepcluster(runner)

-    def deepcluster(self, runner):
+    def deepcluster(self, runner) -> None:
         # step 1: get features
         runner.model.eval()
         features = self.extractor(runner)
@@ -130,7 +116,7 @@ class DeepClusterHook(Hook):
         if self.init_memory:
             runner.model.module.memory_bank.init_memory(features, new_labels)

-    def evaluate(self, runner, new_labels):
+    def evaluate(self, runner, new_labels: np.ndarray) -> None:
         histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k)
         empty_cls = (histogram == 0).sum()
         minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
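The statistics computed in evaluate() are plain NumPy bookkeeping over the
pseudo-labels. A standalone illustration with made-up inputs (k and new_labels
are placeholders, not values from the repository):

import numpy as np

k = 8
new_labels = np.random.randint(0, k, size=1000)
histogram = np.bincount(new_labels, minlength=k)   # cluster sizes
empty_cls = (histogram == 0).sum()                 # number of empty clusters
minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
print(empty_cls, minimal_cls_size, maximal_cls_size)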
densecl_hook.py:

@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import HOOKS, Hook
+from typing import Optional
+
+from mmengine.hooks import Hook
+
+from mmselfsup.registry import HOOKS


 @HOOKS.register_module()
@@ -14,15 +18,15 @@ class DenseCLHook(Hook):
            ``loss_lambda=0``. Defaults to 1000.
     """

-    def __init__(self, start_iters=1000, **kwargs):
+    def __init__(self, start_iters: Optional[int] = 1000) -> None:
         self.start_iters = start_iters

-    def before_run(self, runner):
+    def before_run(self, runner) -> None:
         assert hasattr(runner.model.module, 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
         self.loss_lambda = runner.model.module.loss_lambda

-    def before_train_iter(self, runner):
+    def before_train_iter(self, runner) -> None:
         assert hasattr(runner.model.module, 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
         cur_iter = runner.iter
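The docstring fragment above describes a warm-up in which loss_lambda is held at
0 until start_iters and the value saved in before_run is used afterwards. A
standalone sketch of that schedule, written as a plain function under that
assumption rather than copied from the hook itself:

def densecl_loss_lambda(cur_iter: int, start_iters: int, loss_lambda: float) -> float:
    # Before start_iters only the single-level loss contributes.
    return 0.0 if cur_iter < start_iters else loss_lambda

assert densecl_loss_lambda(500, 1000, 0.5) == 0.0
assert densecl_loss_lambda(2000, 1000, 0.5) == 0.5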
momentum_update_hook.py (file removed):

@@ -1,48 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from math import cos, pi
-
-from mmcv.parallel import is_module_wrapper
-from mmcv.runner import HOOKS, Hook
-
-
-@HOOKS.register_module(name=['BYOLHook', 'MomentumUpdateHook'])
-class MomentumUpdateHook(Hook):
-    """Hook for updating momentum parameter, used by BYOL, MoCoV3, etc.
-
-    This hook includes momentum adjustment following:
-
-    .. math::
-        m = 1 - (1 - m_0) * (cos(pi * k / K) + 1) / 2
-
-    where :math:`k` is the current step, :math:`K` is the total steps.
-
-    Args:
-        end_momentum (float): The final momentum coefficient
-            for the target network. Defaults to 1.
-        update_interval (int, optional): The momentum update interval of the
-            weights. Defaults to 1.
-    """
-
-    def __init__(self, end_momentum=1., update_interval=1, **kwargs):
-        self.end_momentum = end_momentum
-        self.update_interval = update_interval
-
-    def before_train_iter(self, runner):
-        assert hasattr(runner.model.module, 'momentum'), \
-            "The runner must have attribute \"momentum\" in algorithms."
-        assert hasattr(runner.model.module, 'base_momentum'), \
-            "The runner must have attribute \"base_momentum\" in algorithms."
-        if self.every_n_iters(runner, self.update_interval):
-            cur_iter = runner.iter
-            max_iter = runner.max_iters
-            base_m = runner.model.module.base_momentum
-            m = self.end_momentum - (self.end_momentum - base_m) * (
-                cos(pi * cur_iter / float(max_iter)) + 1) / 2
-            runner.model.module.momentum = m
-
-    def after_train_iter(self, runner):
-        if self.every_n_iters(runner, self.update_interval):
-            if is_module_wrapper(runner.model):
-                runner.model.module.momentum_update()
-            else:
-                runner.model.momentum_update()
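The removed hook's cosine schedule is the one that drives the EMA target
networks of BYOL and MoCo v3. A standalone sketch of the formula from the
docstring above (the values printed are illustrative):

from math import cos, pi

def momentum_at_step(k: int, K: int, base_momentum: float,
                     end_momentum: float = 1.0) -> float:
    # m = end - (end - base) * (cos(pi * k / K) + 1) / 2, which matches the
    # docstring form 1 - (1 - m_0) * (cos(pi * k / K) + 1) / 2 when
    # end_momentum == 1 and m_0 == base_momentum.
    return end_momentum - (end_momentum - base_momentum) * (
        cos(pi * k / float(K)) + 1) / 2

print(momentum_at_step(0, 100, 0.996))    # 0.996 at the first step
print(momentum_at_step(100, 100, 0.996))  # 1.0 at the final step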
odc_hook.py:

@@ -1,8 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 import numpy as np
-from mmcv.runner import HOOKS, Hook
+from mmengine.hooks import Hook
 from mmengine.logging import print_log

+from mmselfsup.registry import HOOKS
+

 @HOOKS.register_module()
 class ODCHook(Hook):
@@ -22,12 +26,12 @@ class ODCHook(Hook):
     """

     def __init__(self,
-                 centroids_update_interval,
-                 deal_with_small_clusters_interval,
-                 evaluate_interval,
-                 reweight,
-                 reweight_pow,
-                 dist_mode=True):
+                 centroids_update_interval: int,
+                 deal_with_small_clusters_interval: int,
+                 evaluate_interval: int,
+                 reweight: bool,
+                 reweight_pow: float,
+                 dist_mode: Optional[bool] = True) -> None:
         assert dist_mode, 'non-dist mode is not implemented'
         self.centroids_update_interval = centroids_update_interval
         self.deal_with_small_clusters_interval = \
@@ -36,7 +40,7 @@ class ODCHook(Hook):
         self.reweight = reweight
         self.reweight_pow = reweight_pow

-    def after_train_iter(self, runner):
+    def after_train_iter(self, runner) -> None:
         # centroids update
         if self.every_n_iters(runner, self.centroids_update_interval):
             runner.model.module.memory_bank.update_centroids_memory()
@@ -55,7 +59,7 @@ class ODCHook(Hook):
                 new_labels = new_labels.cpu()
             self.evaluate(runner, new_labels.numpy())

-    def after_train_epoch(self, runner):
+    def after_train_epoch(self, runner) -> None:
         # save cluster
         if self.every_n_epochs(runner, 10) and runner.rank == 0:
             new_labels = runner.model.module.memory_bank.label_bank
@@ -64,7 +68,7 @@ class ODCHook(Hook):
             np.save(f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy',
                     new_labels.numpy())

-    def evaluate(self, runner, new_labels):
+    def evaluate(self, runner, new_labels: np.ndarray) -> None:
         histogram = np.bincount(
             new_labels, minlength=runner.model.module.memory_bank.num_classes)
         empty_cls = (histogram == 0).sum()
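As with DeepClusterHook, a config fragment matching the typed ODCHook signature
could be sketched as below; the interval values are made-up placeholders, only
the argument names come from the diff above:

odc_hook = dict(
    type='ODCHook',
    centroids_update_interval=10,
    deal_with_small_clusters_interval=20,
    evaluate_interval=50,
    reweight=True,
    reweight_pow=0.5)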
optimizer_hook.py (file removed):

@@ -1,261 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
-                         allreduce_grads)
-from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
-
-
-@HOOKS.register_module()
-class DistOptimizerHook(OptimizerHook):
-    """Optimizer hook for distributed training.
-
-    This hook can accumulate gradients every n intervals and freeze some
-    layers for some iters at the beginning.
-
-    Args:
-        update_interval (int, optional): The update interval of the weights,
-            set > 1 to accumulate the grad. Defaults to 1.
-        grad_clip (dict, optional): Dict to config the value of grad clip.
-            E.g., grad_clip = dict(max_norm=10). Defaults to None.
-        coalesce (bool, optional): Whether allreduce parameters as a whole.
-            Defaults to True.
-        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
-            Defaults to -1.
-        frozen_layers_cfg (dict, optional): Dict to config frozen layers.
-            The key-value pair is layer name and its frozen iters. If frozen,
-            the layer gradient would be set to None. Defaults to dict().
-    """
-
-    def __init__(self,
-                 update_interval=1,
-                 grad_clip=None,
-                 coalesce=True,
-                 bucket_size_mb=-1,
-                 frozen_layers_cfg=dict()):
-        self.grad_clip = grad_clip
-        self.coalesce = coalesce
-        self.bucket_size_mb = bucket_size_mb
-        self.update_interval = update_interval
-        self.frozen_layers_cfg = frozen_layers_cfg
-        self.initialized = False
-
-    def has_batch_norm(self, module):
-        if isinstance(module, _BatchNorm):
-            return True
-        for m in module.children():
-            if self.has_batch_norm(m):
-                return True
-        return False
-
-    def _init(self, runner):
-        if runner.iter % self.update_interval != 0:
-            runner.logger.warning(
-                'Resume iter number is not divisible by update_interval in '
-                'GradientCumulativeOptimizerHook, which means the gradient of '
-                'some iters is lost and the result may be influenced slightly.'
-            )
-
-        if self.has_batch_norm(runner.model) and self.update_interval > 1:
-            runner.logger.warning(
-                'GradientCumulativeOptimizerHook may slightly decrease '
-                'performance if the model has BatchNorm layers.')
-
-        residual_iters = runner.max_iters
-
-        self.divisible_iters = (
-            residual_iters // self.update_interval * self.update_interval)
-        self.remainder_iters = residual_iters - self.divisible_iters
-
-        self.initialized = True
-
-    def before_run(self, runner):
-        runner.optimizer.zero_grad()
-
-    def after_train_iter(self, runner):
-        # In some cases, MMCV's GradientCumulativeOptimizerHook will
-        # cause the loss_factor to be zero and we fix this bug in our
-        # implementation.
-
-        if not self.initialized:
-            self._init(runner)
-
-        if runner.iter < self.divisible_iters:
-            loss_factor = self.update_interval
-        else:
-            loss_factor = self.remainder_iters
-
-        runner.outputs['loss'] /= loss_factor
-        runner.outputs['loss'].backward()
-
-        if (self.every_n_iters(runner, self.update_interval)
-                or self.is_last_iter(runner)):
-
-            # cancel gradient of certain layer for n iters
-            # according to frozen_layers_cfg dict
-            for layer, iters in self.frozen_layers_cfg.items():
-                if runner.iter < iters:
-                    for name, p in runner.model.module.named_parameters():
-                        if layer in name:
-                            p.grad = None
-
-            if self.grad_clip is not None:
-                grad_norm = self.clip_grads(runner.model.parameters())
-                if grad_norm is not None:
-                    # Add grad norm to the logger
-                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
-                                             runner.outputs['num_samples'])
-
-            runner.optimizer.step()
-            runner.optimizer.zero_grad()
-
-
-if (TORCH_VERSION != 'parrots'
-        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
-
-    @HOOKS.register_module()
-    class GradAccumFp16OptimizerHook(Fp16OptimizerHook):
-        """Fp16 optimizer hook (using PyTorch's implementation).
-
-        This hook can accumulate gradients every n intervals and freeze some
-        layers for some iters at the beginning.
-        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
-        to take care of the optimization procedure.
-
-        Args:
-            update_interval (int, optional): The update interval of the
-                weights, set > 1 to accumulate the grad. Defaults to 1.
-            frozen_layers_cfg (dict, optional): Dict to config frozen layers.
-                The key-value pair is layer name and its frozen iters. If
-                frozen, the layer gradient would be set to None.
-                Defaults to dict().
-        """
-
-        def __init__(self,
-                     update_interval=1,
-                     frozen_layers_cfg=dict(),
-                     **kwargs):
-            super(GradAccumFp16OptimizerHook, self).__init__(**kwargs)
-            self.update_interval = update_interval
-            self.frozen_layers_cfg = frozen_layers_cfg
-
-        def after_train_iter(self, runner):
-            runner.outputs['loss'] /= self.update_interval
-            self.loss_scaler.scale(runner.outputs['loss']).backward()
-
-            if self.every_n_iters(runner, self.update_interval):
-
-                # cancel gradient of certain layer for n iters
-                # according to frozen_layers_cfg dict
-                for layer, iters in self.frozen_layers_cfg.items():
-                    if runner.iter < iters:
-                        for name, p in runner.model.module.named_parameters():
-                            if layer in name:
-                                p.grad = None
-
-                # copy fp16 grads in the model to fp32 params in the optimizer
-                self.loss_scaler.unscale_(runner.optimizer)
-
-                if self.grad_clip is not None:
-                    grad_norm = self.clip_grads(runner.model.parameters())
-                    if grad_norm is not None:
-                        # Add grad norm to the logger
-                        runner.log_buffer.update(
-                            {'grad_norm': float(grad_norm)},
-                            runner.outputs['num_samples'])
-
-                # backward and update scaler
-                self.loss_scaler.step(runner.optimizer)
-                self.loss_scaler.update(self._scale_update_param)
-
-                # save state_dict of loss_scaler
-                runner.meta.setdefault(
-                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
-
-                # clear grads
-                runner.model.zero_grad()
-                runner.optimizer.zero_grad()
-
-else:
-
-    @HOOKS.register_module()
-    class GradAccumFp16OptimizerHook(Fp16OptimizerHook):
-        """Fp16 optimizer hook (using mmcv's implementation).
-
-        This hook can accumulate gradients every n intervals and freeze some
-        layers for some iters at the beginning.
-
-        Args:
-            update_interval (int, optional): The update interval of the
-                weights, set > 1 to accumulate the grad. Defaults to 1.
-            frozen_layers_cfg (dict, optional): Dict to config frozen layers.
-                The key-value pair is layer name and its frozen iters. If
-                frozen, the layer gradient would be set to None.
-                Defaults to dict().
-        """
-
-        def __init__(self,
-                     update_interval=1,
-                     frozen_layers_cfg=dict(),
-                     **kwargs):
-            super(GradAccumFp16OptimizerHook, self).__init__(**kwargs)
-            self.update_interval = update_interval
-            self.frozen_layers_cfg = frozen_layers_cfg
-
-        def after_train_iter(self, runner):
-            runner.outputs['loss'] /= self.update_interval
-
-            # scale the loss value
-            scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
-            scaled_loss.backward()
-
-            if self.every_n_iters(runner, self.update_interval):
-
-                # cancel gradient of certain layer for n iters
-                # according to frozen_layers_cfg dict
-                for layer, iters in self.frozen_layers_cfg.items():
-                    if runner.iter < iters:
-                        for name, p in runner.model.module.named_parameters():
-                            if layer in name:
-                                p.grad = None
-
-                # copy fp16 grads in the model to fp32 params in the optimizer
-                fp32_weights = []
-                for param_group in runner.optimizer.param_groups:
-                    fp32_weights += param_group['params']
-                self.copy_grads_to_fp32(runner.model, fp32_weights)
-                # allreduce grads
-                if self.distributed:
-                    allreduce_grads(fp32_weights, self.coalesce,
-                                    self.bucket_size_mb)
-
-                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
-                # if has overflow, skip this iteration
-                if not has_overflow:
-                    # scale the gradients back
-                    for param in fp32_weights:
-                        if param.grad is not None:
-                            param.grad.div_(self.loss_scaler.loss_scale)
-                    if self.grad_clip is not None:
-                        grad_norm = self.clip_grads(fp32_weights)
-                        if grad_norm is not None:
-                            # Add grad norm to the logger
-                            runner.log_buffer.update(
-                                {'grad_norm': float(grad_norm)},
-                                runner.outputs['num_samples'])
-                    # update fp32 params
-                    runner.optimizer.step()
-                    # copy fp32 params to the fp16 model
-                    self.copy_params_to_fp16(runner.model, fp32_weights)
-                else:
-                    runner.logger.warning(
-                        'Check overflow, downscale loss scale '
-                        f'to {self.loss_scaler.cur_scale}')
-
-                self.loss_scaler.update_scale(has_overflow)
-
-                # save state_dict of loss_scaler
-                runner.meta.setdefault(
-                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
-
-                # clear grads
-                runner.model.zero_grad()
-                runner.optimizer.zero_grad()
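The two removed classes implemented gradient accumulation with optional fp16
loss scaling on top of the old mmcv runner. A standalone sketch of the same idea
with plain PyTorch AMP, as an illustration only; the model, optimizer and data
below are placeholders and this is not the code that replaces the hooks in this
commit:

import torch
from torch.cuda.amp import GradScaler, autocast

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=(device == 'cuda'))
update_interval = 4  # accumulate gradients over 4 iterations

for it in range(16):
    x = torch.randn(8, 4, device=device)
    with autocast(enabled=(device == 'cuda')):
        loss = model(x).mean()
    loss = loss / update_interval          # keep the effective loss magnitude
    scaler.scale(loss).backward()          # grads accumulate across iterations
    if (it + 1) % update_interval == 0:
        scaler.unscale_(optimizer)         # so clipping sees unscaled grads
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()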
simsiam_hook.py:

@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import HOOKS, Hook
+from typing import Optional
+
+from mmengine.hooks import Hook
+
+from mmselfsup.registry import HOOKS


 @HOOKS.register_module()
@@ -15,12 +19,16 @@ class SimSiamHook(Hook):
            Defaults to True.
     """

-    def __init__(self, fix_pred_lr, lr, adjust_by_epoch=True, **kwargs):
+    def __init__(self,
+                 fix_pred_lr: bool,
+                 lr: float,
+                 adjust_by_epoch: Optional[bool] = True) -> None:
         self.fix_pred_lr = fix_pred_lr
         self.lr = lr
         self.adjust_by_epoch = adjust_by_epoch

-    def before_train_iter(self, runner):
+    def before_train_iter(self, runner) -> None:
+        """fix lr of predictor by iter."""
         if self.adjust_by_epoch:
             return
         else:
@@ -29,8 +37,8 @@ class SimSiamHook(Hook):
                 if 'fix_lr' in param_group and param_group['fix_lr']:
                     param_group['lr'] = self.lr

-    def before_train_epoch(self, runner):
-        """fix lr of predictor."""
+    def before_train_epoch(self, runner) -> None:
+        """fix lr of predictor by epoch."""
         if self.fix_pred_lr:
             for param_group in runner.optimizer.param_groups:
                 if 'fix_lr' in param_group and param_group['fix_lr']:
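The fix_lr mechanism shown above relies on tagging the predictor's parameter
group and resetting its lr after every scheduler update. A standalone sketch
with a toy model; names and values are placeholders, not taken from the repo:

import torch

backbone = torch.nn.Linear(8, 8)
predictor = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD([
    dict(params=backbone.parameters()),
    dict(params=predictor.parameters(), fix_lr=True),  # extra key kept by torch
], lr=0.05)

def fix_predictor_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        if 'fix_lr' in param_group and param_group['fix_lr']:
            param_group['lr'] = lr

for param_group in optimizer.param_groups:  # pretend a scheduler decayed the lr
    param_group['lr'] *= 0.1
fix_predictor_lr(optimizer, lr=0.05)        # predictor group restored to 0.05
print([g['lr'] for g in optimizer.param_groups])  # roughly [0.005, 0.05]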
swav_hook.py:

@@ -1,9 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+from typing import List, Optional

 import torch
 import torch.distributed as dist
-from mmcv.runner import HOOKS, Hook
+from mmengine.hooks import Hook

+from mmselfsup.registry import HOOKS
+

 @HOOKS.register_module()
@@ -29,13 +32,12 @@ class SwAVHook(Hook):
     """

     def __init__(self,
-                 batch_size,
-                 epoch_queue_starts=15,
-                 crops_for_assign=[0, 1],
-                 feat_dim=128,
-                 queue_length=0,
-                 interval=1,
-                 **kwargs):
+                 batch_size: int,
+                 epoch_queue_starts: Optional[int] = 15,
+                 crops_for_assign: Optional[List[int]] = [0, 1],
+                 feat_dim: Optional[int] = 128,
+                 queue_length: Optional[int] = 0,
+                 interval: Optional[int] = 1):
         self.batch_size = batch_size * dist.get_world_size()\
             if dist.is_initialized() else batch_size
         self.epoch_queue_starts = epoch_queue_starts
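The bookkeeping visible in this hunk scales the per-GPU batch size by the world
size and later truncates the queue length to a multiple of it. A standalone
arithmetic sketch with made-up numbers:

samples_per_gpu, world_size = 32, 8
queue_length = 3900

batch_size = samples_per_gpu * world_size   # 256 across all GPUs
queue_length -= queue_length % batch_size   # 3900 -> 3840, divisible by 256
print(batch_size, queue_length)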
@@ -45,7 +47,7 @@ class SwAVHook(Hook):
         self.interval = interval
         self.queue = None

-    def before_run(self, runner):
+    def before_run(self, runner) -> None:
         if dist.is_initialized():
             self.queue_path = osp.join(runner.work_dir,
                                        'queue' + str(dist.get_rank()) + '.pth')
@@ -58,7 +60,7 @@ class SwAVHook(Hook):
         # the queue needs to be divisible by the batch size
         self.queue_length -= self.queue_length % self.batch_size

-    def before_train_epoch(self, runner):
+    def before_train_epoch(self, runner) -> None:
         # optionally starts a queue
         if self.queue_length > 0 \
                 and runner.epoch >= self.epoch_queue_starts \
@@ -73,7 +75,7 @@ class SwAVHook(Hook):
             runner.model.module.head.queue = self.queue
             runner.model.module.head.use_queue = False

-    def after_train_epoch(self, runner):
+    def after_train_epoch(self, runner) -> None:
         self.queue = runner.model.module.head.queue

         if self.queue is not None and self.every_n_epochs(