mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor] refactor hooks
This commit is contained in:
parent
9565da23a4
commit
063288240a
@ -1,14 +1,10 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .deepcluster_hook import DeepClusterHook
|
||||
from .densecl_hook import DenseCLHook
|
||||
from .momentum_update_hook import MomentumUpdateHook
|
||||
from .odc_hook import ODCHook
|
||||
from .optimizer_hook import DistOptimizerHook, GradAccumFp16OptimizerHook
|
||||
from .simsiam_hook import SimSiamHook
|
||||
from .swav_hook import SwAVHook
|
||||
|
||||
__all__ = [
|
||||
'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook',
|
||||
'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook',
|
||||
'SwAVHook'
|
||||
'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook'
|
||||
]
|
||||
|
@ -1,13 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
from mmselfsup.utils import Extractor
|
||||
from mmselfsup.utils import clustering as _clustering
|
||||
from mmselfsup.utils import get_root_logger
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -32,33 +35,16 @@ class DeepClusterHook(Hook):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
extractor,
|
||||
clustering,
|
||||
unif_sampling,
|
||||
reweight,
|
||||
reweight_pow,
|
||||
init_memory=False, # for ODC
|
||||
initial=True,
|
||||
interval=1,
|
||||
dist_mode=True,
|
||||
data_loaders=None):
|
||||
|
||||
logger = get_root_logger()
|
||||
if 'imgs_per_gpu' in extractor:
|
||||
logger.warning('"imgs_per_gpu" is deprecated. '
|
||||
'Please use "samples_per_gpu" instead')
|
||||
if 'samples_per_gpu' in extractor:
|
||||
logger.warning(
|
||||
f'Got "imgs_per_gpu"={extractor["imgs_per_gpu"]} and '
|
||||
f'"samples_per_gpu"={extractor["samples_per_gpu"]}, '
|
||||
f'"imgs_per_gpu"={extractor["imgs_per_gpu"]} is used in '
|
||||
f'this experiments')
|
||||
else:
|
||||
logger.warning(
|
||||
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
|
||||
f'{extractor["imgs_per_gpu"]} in this experiments')
|
||||
extractor['samples_per_gpu'] = extractor['imgs_per_gpu']
|
||||
|
||||
extractor: Dict,
|
||||
clustering: Dict,
|
||||
unif_sampling: bool,
|
||||
reweight: bool,
|
||||
reweight_pow: float,
|
||||
init_memory: Optional[bool] = False, # for ODC
|
||||
initial: Optional[bool] = True,
|
||||
interval: Optional[int] = 1,
|
||||
dist_mode: Optional[bool] = True,
|
||||
data_loaders: Optional[DataLoader] = None) -> None:
|
||||
self.extractor = Extractor(dist_mode=dist_mode, **extractor)
|
||||
self.clustering_type = clustering.pop('type')
|
||||
self.clustering_cfg = clustering
|
||||
@ -71,16 +57,16 @@ class DeepClusterHook(Hook):
|
||||
self.dist_mode = dist_mode
|
||||
self.data_loaders = data_loaders
|
||||
|
||||
def before_run(self, runner):
|
||||
def before_run(self, runner) -> None:
|
||||
if self.initial:
|
||||
self.deepcluster(runner)
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
if not self.every_n_epochs(runner, self.interval):
|
||||
return
|
||||
self.deepcluster(runner)
|
||||
|
||||
def deepcluster(self, runner):
|
||||
def deepcluster(self, runner) -> None:
|
||||
# step 1: get features
|
||||
runner.model.eval()
|
||||
features = self.extractor(runner)
|
||||
@ -130,7 +116,7 @@ class DeepClusterHook(Hook):
|
||||
if self.init_memory:
|
||||
runner.model.module.memory_bank.init_memory(features, new_labels)
|
||||
|
||||
def evaluate(self, runner, new_labels):
|
||||
def evaluate(self, runner, new_labels: np.ndarray) -> None:
|
||||
histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k)
|
||||
empty_cls = (histogram == 0).sum()
|
||||
minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
|
||||
|
@ -1,5 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -14,15 +18,15 @@ class DenseCLHook(Hook):
|
||||
``loss_lambda=0``. Defaults to 1000.
|
||||
"""
|
||||
|
||||
def __init__(self, start_iters=1000, **kwargs):
|
||||
def __init__(self, start_iters: Optional[int] = 1000) -> None:
|
||||
self.start_iters = start_iters
|
||||
|
||||
def before_run(self, runner):
|
||||
def before_run(self, runner) -> None:
|
||||
assert hasattr(runner.model.module, 'loss_lambda'), \
|
||||
"The runner must have attribute \"loss_lambda\" in DenseCL."
|
||||
self.loss_lambda = runner.model.module.loss_lambda
|
||||
|
||||
def before_train_iter(self, runner):
|
||||
def before_train_iter(self, runner) -> None:
|
||||
assert hasattr(runner.model.module, 'loss_lambda'), \
|
||||
"The runner must have attribute \"loss_lambda\" in DenseCL."
|
||||
cur_iter = runner.iter
|
||||
|
@ -1,48 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from math import cos, pi
|
||||
|
||||
from mmcv.parallel import is_module_wrapper
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
|
||||
|
||||
@HOOKS.register_module(name=['BYOLHook', 'MomentumUpdateHook'])
class MomentumUpdateHook(Hook):
    """Hook for updating momentum parameter, used by BYOL, MoCoV3, etc.

    Before every ``update_interval``-th training iteration the model's
    ``momentum`` attribute is recomputed with a cosine schedule:

    .. math::
        m = m_{end} - (m_{end} - m_0) * (cos(pi * k / K) + 1) / 2

    where :math:`k` is the current step (``runner.iter``), :math:`K` is the
    total number of steps (``runner.max_iters``), :math:`m_0` is the model's
    ``base_momentum`` and :math:`m_{end}` is ``end_momentum``. After the same
    iterations, the model's ``momentum_update()`` method is called.

    Args:
        end_momentum (float): The final momentum coefficient
            for the target network. Defaults to 1.
        update_interval (int, optional): The momentum update interval of the
            weights. Defaults to 1.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, end_momentum=1., update_interval=1, **kwargs):
        self.end_momentum = end_momentum
        self.update_interval = update_interval

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` for a module wrapper, else ``model``."""
        return model.module if is_module_wrapper(model) else model

    def before_train_iter(self, runner):
        """Set the model's ``momentum`` from the cosine schedule."""
        # Unwrap the same way ``after_train_iter`` does, so a model that is
        # not behind a module wrapper is handled in both callbacks.
        model = self._unwrap(runner.model)
        assert hasattr(model, 'momentum'), \
            "The runner must have attribute \"momentum\" in algorithms."
        assert hasattr(model, 'base_momentum'), \
            "The runner must have attribute \"base_momentum\" in algorithms."
        if self.every_n_iters(runner, self.update_interval):
            cur_iter = runner.iter
            max_iter = runner.max_iters
            base_m = model.base_momentum
            m = self.end_momentum - (self.end_momentum - base_m) * (
                cos(pi * cur_iter / float(max_iter)) + 1) / 2
            model.momentum = m

    def after_train_iter(self, runner):
        """Call the model's ``momentum_update()`` on update iterations."""
        if self.every_n_iters(runner, self.update_interval):
            self._unwrap(runner.model).momentum_update()
|
@ -1,8 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ODCHook(Hook):
|
||||
@ -22,12 +26,12 @@ class ODCHook(Hook):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
centroids_update_interval,
|
||||
deal_with_small_clusters_interval,
|
||||
evaluate_interval,
|
||||
reweight,
|
||||
reweight_pow,
|
||||
dist_mode=True):
|
||||
centroids_update_interval: int,
|
||||
deal_with_small_clusters_interval: int,
|
||||
evaluate_interval: int,
|
||||
reweight: bool,
|
||||
reweight_pow: float,
|
||||
dist_mode: Optional[bool] = True) -> None:
|
||||
assert dist_mode, 'non-dist mode is not implemented'
|
||||
self.centroids_update_interval = centroids_update_interval
|
||||
self.deal_with_small_clusters_interval = \
|
||||
@ -36,7 +40,7 @@ class ODCHook(Hook):
|
||||
self.reweight = reweight
|
||||
self.reweight_pow = reweight_pow
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
def after_train_iter(self, runner) -> None:
|
||||
# centroids update
|
||||
if self.every_n_iters(runner, self.centroids_update_interval):
|
||||
runner.model.module.memory_bank.update_centroids_memory()
|
||||
@ -55,7 +59,7 @@ class ODCHook(Hook):
|
||||
new_labels = new_labels.cpu()
|
||||
self.evaluate(runner, new_labels.numpy())
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
# save cluster
|
||||
if self.every_n_epochs(runner, 10) and runner.rank == 0:
|
||||
new_labels = runner.model.module.memory_bank.label_bank
|
||||
@ -64,7 +68,7 @@ class ODCHook(Hook):
|
||||
np.save(f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy',
|
||||
new_labels.numpy())
|
||||
|
||||
def evaluate(self, runner, new_labels):
|
||||
def evaluate(self, runner, new_labels: np.ndarray) -> None:
|
||||
histogram = np.bincount(
|
||||
new_labels, minlength=runner.model.module.memory_bank.num_classes)
|
||||
empty_cls = (histogram == 0).sum()
|
||||
|
@ -1,261 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
|
||||
allreduce_grads)
|
||||
from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
|
||||
|
||||
|
||||
@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook for distributed training.

    This hook can accumulate gradients every n intervals and freeze some
    layers for some iters at the beginning.

    Args:
        update_interval (int, optional): The update interval of the weights,
            set > 1 to accumulate the grad. Defaults to 1.
        grad_clip (dict, optional): Dict to config the value of grad clip.
            E.g., grad_clip = dict(max_norm=10). Defaults to None.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
        frozen_layers_cfg (dict, optional): Dict to config frozen layers.
            The key-value pair is layer name and its frozen iters. If frozen,
            the layer gradient would be set to None. Defaults to dict().
    """

    def __init__(self,
                 update_interval=1,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 frozen_layers_cfg=dict()):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        # Store a private copy: the default ``dict()`` in the signature is a
        # single object shared by every call, so keeping a reference to it
        # would alias the config across all hook instances.
        self.frozen_layers_cfg = dict(frozen_layers_cfg)
        self.initialized = False

    def has_batch_norm(self, module):
        """Return True if ``module`` or any descendant is a ``_BatchNorm``."""
        if isinstance(module, _BatchNorm):
            return True
        return any(self.has_batch_norm(child) for child in module.children())

    def _init(self, runner):
        """Emit accumulation warnings and precompute the iteration split.

        ``divisible_iters`` is the largest multiple of ``update_interval``
        not exceeding ``runner.max_iters``; ``remainder_iters`` is what is
        left after it.
        """
        if runner.iter % self.update_interval != 0:
            runner.logger.warning(
                'Resume iter number is not divisible by update_interval in '
                'GradientCumulativeOptimizerHook, which means the gradient of '
                'some iters is lost and the result may be influenced slightly.'
            )

        if self.has_batch_norm(runner.model) and self.update_interval > 1:
            runner.logger.warning(
                'GradientCumulativeOptimizerHook may slightly decrease '
                'performance if the model has BatchNorm layers.')

        residual_iters = runner.max_iters

        self.divisible_iters = (
            residual_iters // self.update_interval * self.update_interval)
        self.remainder_iters = residual_iters - self.divisible_iters

        self.initialized = True

    def before_run(self, runner):
        """Clear any stale gradients before training starts."""
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        """Back-propagate the scaled loss and step on update iterations."""
        # In some cases, MMCV's GradientCumulativeOptimizerHook will
        # cause the loss_factor to be zero and we fix this bug in our
        # implementation.

        if not self.initialized:
            self._init(runner)

        # Full accumulation windows divide by ``update_interval``; the
        # trailing partial window divides by its own length.
        if runner.iter < self.divisible_iters:
            loss_factor = self.update_interval
        else:
            loss_factor = self.remainder_iters

        runner.outputs['loss'] /= loss_factor
        runner.outputs['loss'].backward()

        if (self.every_n_iters(runner, self.update_interval)
                or self.is_last_iter(runner)):

            # cancel gradient of certain layer for n iters
            # according to frozen_layers_cfg dict
            for layer, iters in self.frozen_layers_cfg.items():
                if runner.iter < iters:
                    for name, p in runner.model.module.named_parameters():
                        if layer in name:
                            p.grad = None

            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])

            runner.optimizer.step()
            runner.optimizer.zero_grad()
|
||||
|
||||
|
||||
# Two implementations share one registered name; which one is defined depends
# on the backend reported by ``TORCH_VERSION``.
if (TORCH_VERSION != 'parrots'
        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):

    @HOOKS.register_module()
    class GradAccumFp16OptimizerHook(Fp16OptimizerHook):
        """Fp16 optimizer hook (using PyTorch's implementation).

        This hook can accumulate gradients every n intervals and freeze some
        layers for some iters at the beginning.
        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
        to take care of the optimization procedure.

        Args:
            update_interval (int, optional): The update interval of the
                weights, set > 1 to accumulate the grad. Defaults to 1.
            frozen_layers_cfg (dict, optional): Dict to config frozen layers.
                The key-value pair is layer name and its frozen iters. If
                frozen, the layer gradient would be set to None.
                Defaults to dict().
            **kwargs: Forwarded to ``Fp16OptimizerHook.__init__``.
        """

        def __init__(self,
                     update_interval=1,
                     frozen_layers_cfg=dict(),
                     **kwargs):
            super(GradAccumFp16OptimizerHook, self).__init__(**kwargs)
            self.update_interval = update_interval
            # Private copy so instances never alias the shared default dict.
            self.frozen_layers_cfg = dict(frozen_layers_cfg)

        def after_train_iter(self, runner):
            """Back-propagate the scaled loss; step on update iterations."""
            runner.outputs['loss'] /= self.update_interval
            self.loss_scaler.scale(runner.outputs['loss']).backward()

            if self.every_n_iters(runner, self.update_interval):

                # cancel gradient of certain layer for n iters
                # according to frozen_layers_cfg dict
                for layer, iters in self.frozen_layers_cfg.items():
                    if runner.iter < iters:
                        for name, p in runner.model.module.named_parameters():
                            if layer in name:
                                p.grad = None

                # copy fp16 grads in the model to fp32 params in the optimizer
                self.loss_scaler.unscale_(runner.optimizer)

                if self.grad_clip is not None:
                    grad_norm = self.clip_grads(runner.model.parameters())
                    if grad_norm is not None:
                        # Add grad norm to the logger
                        runner.log_buffer.update(
                            {'grad_norm': float(grad_norm)},
                            runner.outputs['num_samples'])

                # backward and update scaler
                self.loss_scaler.step(runner.optimizer)
                self.loss_scaler.update(self._scale_update_param)

                # save state_dict of loss_scaler
                runner.meta.setdefault(
                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()

                # clear grads
                runner.model.zero_grad()
                runner.optimizer.zero_grad()

else:

    @HOOKS.register_module()
    class GradAccumFp16OptimizerHook(Fp16OptimizerHook):
        """Fp16 optimizer hook (using mmcv's implementation).

        This hook can accumulate gradients every n intervals and freeze some
        layers for some iters at the beginning.

        Args:
            update_interval (int, optional): The update interval of the
                weights, set > 1 to accumulate the grad. Defaults to 1.
            frozen_layers_cfg (dict, optional): Dict to config frozen layers.
                The key-value pair is layer name and its frozen iters. If
                frozen, the layer gradient would be set to None.
                Defaults to dict().
            **kwargs: Forwarded to ``Fp16OptimizerHook.__init__``.
        """

        def __init__(self,
                     update_interval=1,
                     frozen_layers_cfg=dict(),
                     **kwargs):
            super(GradAccumFp16OptimizerHook, self).__init__(**kwargs)
            self.update_interval = update_interval
            # Private copy so instances never alias the shared default dict.
            self.frozen_layers_cfg = dict(frozen_layers_cfg)

        def after_train_iter(self, runner):
            """Back-propagate the scaled loss; step on update iterations."""
            runner.outputs['loss'] /= self.update_interval

            # scale the loss value
            scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
            scaled_loss.backward()

            if self.every_n_iters(runner, self.update_interval):

                # cancel gradient of certain layer for n iters
                # according to frozen_layers_cfg dict
                for layer, iters in self.frozen_layers_cfg.items():
                    if runner.iter < iters:
                        for name, p in runner.model.module.named_parameters():
                            if layer in name:
                                p.grad = None

                # copy fp16 grads in the model to fp32 params in the optimizer
                fp32_weights = []
                for param_group in runner.optimizer.param_groups:
                    fp32_weights += param_group['params']
                self.copy_grads_to_fp32(runner.model, fp32_weights)
                # allreduce grads
                if self.distributed:
                    allreduce_grads(fp32_weights, self.coalesce,
                                    self.bucket_size_mb)

                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
                # if has overflow, skip this iteration
                if not has_overflow:
                    # scale the gradients back
                    for param in fp32_weights:
                        if param.grad is not None:
                            param.grad.div_(self.loss_scaler.loss_scale)
                    if self.grad_clip is not None:
                        grad_norm = self.clip_grads(fp32_weights)
                        if grad_norm is not None:
                            # Add grad norm to the logger
                            runner.log_buffer.update(
                                {'grad_norm': float(grad_norm)},
                                runner.outputs['num_samples'])
                    # update fp32 params
                    runner.optimizer.step()
                    # copy fp32 params to the fp16 model
                    self.copy_params_to_fp16(runner.model, fp32_weights)

                # Update the scale only after the gradients above were
                # unscaled with the old value, and before logging, so the
                # warning reports the scale after the update.
                self.loss_scaler.update_scale(has_overflow)
                if has_overflow:
                    runner.logger.warning(
                        'Check overflow, downscale loss scale '
                        f'to {self.loss_scaler.cur_scale}')

                # save state_dict of loss_scaler
                runner.meta.setdefault(
                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()

                # clear grads
                runner.model.zero_grad()
                runner.optimizer.zero_grad()
|
@ -1,5 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -15,12 +19,16 @@ class SimSiamHook(Hook):
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, fix_pred_lr, lr, adjust_by_epoch=True, **kwargs):
|
||||
def __init__(self,
|
||||
fix_pred_lr: bool,
|
||||
lr: float,
|
||||
adjust_by_epoch: Optional[bool] = True) -> None:
|
||||
self.fix_pred_lr = fix_pred_lr
|
||||
self.lr = lr
|
||||
self.adjust_by_epoch = adjust_by_epoch
|
||||
|
||||
def before_train_iter(self, runner):
|
||||
def before_train_iter(self, runner) -> None:
|
||||
"""fix lr of predictor by iter."""
|
||||
if self.adjust_by_epoch:
|
||||
return
|
||||
else:
|
||||
@ -29,8 +37,8 @@ class SimSiamHook(Hook):
|
||||
if 'fix_lr' in param_group and param_group['fix_lr']:
|
||||
param_group['lr'] = self.lr
|
||||
|
||||
def before_train_epoch(self, runner):
|
||||
"""fix lr of predictor."""
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""fix lr of predictor by epoch."""
|
||||
if self.fix_pred_lr:
|
||||
for param_group in runner.optimizer.param_groups:
|
||||
if 'fix_lr' in param_group and param_group['fix_lr']:
|
||||
|
@ -1,9 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import HOOKS, Hook
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -29,13 +32,12 @@ class SwAVHook(Hook):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
epoch_queue_starts=15,
|
||||
crops_for_assign=[0, 1],
|
||||
feat_dim=128,
|
||||
queue_length=0,
|
||||
interval=1,
|
||||
**kwargs):
|
||||
batch_size: int,
|
||||
epoch_queue_starts: Optional[int] = 15,
|
||||
crops_for_assign: Optional[List[int]] = [0, 1],
|
||||
feat_dim: Optional[int] = 128,
|
||||
queue_length: Optional[int] = 0,
|
||||
interval: Optional[int] = 1):
|
||||
self.batch_size = batch_size * dist.get_world_size()\
|
||||
if dist.is_initialized() else batch_size
|
||||
self.epoch_queue_starts = epoch_queue_starts
|
||||
@ -45,7 +47,7 @@ class SwAVHook(Hook):
|
||||
self.interval = interval
|
||||
self.queue = None
|
||||
|
||||
def before_run(self, runner):
|
||||
def before_run(self, runner) -> None:
|
||||
if dist.is_initialized():
|
||||
self.queue_path = osp.join(runner.work_dir,
|
||||
'queue' + str(dist.get_rank()) + '.pth')
|
||||
@ -58,7 +60,7 @@ class SwAVHook(Hook):
|
||||
# the queue needs to be divisible by the batch size
|
||||
self.queue_length -= self.queue_length % self.batch_size
|
||||
|
||||
def before_train_epoch(self, runner):
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
# optionally starts a queue
|
||||
if self.queue_length > 0 \
|
||||
and runner.epoch >= self.epoch_queue_starts \
|
||||
@ -73,7 +75,7 @@ class SwAVHook(Hook):
|
||||
runner.model.module.head.queue = self.queue
|
||||
runner.model.module.head.use_queue = False
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
self.queue = runner.model.module.head.queue
|
||||
|
||||
if self.queue is not None and self.every_n_epochs(
|
||||
|
Loading…
x
Reference in New Issue
Block a user