# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence

import torch
from mmengine.device import get_device
from mmengine.dist import get_rank, get_world_size, is_distributed
from mmengine.hooks import Hook
from mmengine.logging import MMLogger

from mmpretrain.registry import HOOKS
from mmpretrain.utils import get_ori_model


@HOOKS.register_module()
class SwAVHook(Hook):
    """Hook for SwAV.

    This hook builds the queue in SwAV according to ``epoch_queue_starts``.
    The queue is saved under ``runner.work_dir`` and is loaded back at the
    start epoch if that folder already contains saved queues.

    Args:
        batch_size (int): The training batch size per GPU.
        epoch_queue_starts (int, optional): The epoch from which the queue
            starts to be used. Defaults to 15.
        crops_for_assign (list[int], optional): The list of crop ids used for
            computing assignments. Defaults to [0, 1].
        feat_dim (int, optional): The feature dimension of the output
            vectors. Defaults to 128.
        queue_length (int, optional): The length of the queue (0 means no
            queue). Defaults to 0.
        interval (int, optional): The epoch interval at which to save the
            queue. Defaults to 1.
        frozen_layers_cfg (dict, optional): A dict that configures frozen
            layers, mapping each layer name to the number of iterations for
            which its parameters stay frozen. Frozen layers do not require
            gradients. Defaults to ``dict()``.
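
    Example:
        A minimal sketch of enabling the hook from a config file through
        MMEngine's ``custom_hooks`` mechanism; the values below are
        illustrative, not recommended settings::

            custom_hooks = [
                dict(
                    type='SwAVHook',
                    batch_size=32,
                    epoch_queue_starts=15,
                    crops_for_assign=[0, 1],
                    feat_dim=128,
                    queue_length=3840)
            ]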
    """

    def __init__(self,
                 batch_size: int,
                 epoch_queue_starts: int = 15,
                 crops_for_assign: List[int] = [0, 1],
                 feat_dim: int = 128,
                 queue_length: int = 0,
                 interval: int = 1,
                 frozen_layers_cfg: Dict = dict()) -> None:
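        # Scale the per-GPU batch size up to the effective global batch size;
        # the queue length is later truncated to a multiple of it.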
        self.batch_size = batch_size * get_world_size()
        self.epoch_queue_starts = epoch_queue_starts
        self.crops_for_assign = crops_for_assign
        self.feat_dim = feat_dim
        self.queue_length = queue_length
        self.interval = interval
        self.frozen_layers_cfg = frozen_layers_cfg
        self.requires_grad = True
        self.queue = None

    def before_run(self, runner) -> None:
        """Check whether queues exist locally and load them if found."""
        if is_distributed():
            self.queue_path = osp.join(runner.work_dir,
                                       'queue' + str(get_rank()) + '.pth')
        else:
            self.queue_path = osp.join(runner.work_dir, 'queue.pth')

        # load the queue if one was saved locally by a previous run
        if osp.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)['queue']
            get_ori_model(runner.model).head.loss_module.queue = self.queue
            MMLogger.get_current_instance().info(
                f'Loaded queue from file: {self.queue_path}')

        # the queue length needs to be divisible by the global batch size
        self.queue_length -= self.queue_length % self.batch_size
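        # e.g. queue_length=1000 with a global batch size of 64 is truncated
        # to 960 (1000 - 1000 % 64)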

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None
                          ) -> None:
        """Freeze the configured layers until their specified iterations."""
        for layer, frozen_iters in self.frozen_layers_cfg.items():
            if runner.iter < frozen_iters and self.requires_grad:
                self.requires_grad = False
                for name, p in get_ori_model(runner.model).named_parameters():
                    if layer in name:
                        p.requires_grad = False
            elif runner.iter >= frozen_iters and not self.requires_grad:
                self.requires_grad = True
                for name, p in get_ori_model(runner.model).named_parameters():
                    if layer in name:
                        p.requires_grad = True

    def before_train_epoch(self, runner) -> None:
        """Build the queue once the starting epoch is reached."""
        # optionally start the queue
        if self.queue_length > 0 \
                and runner.epoch >= self.epoch_queue_starts \
                and self.queue is None:
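
            # each rank keeps its own shard of the queue, hence the division
            # by the world size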
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                self.queue_length // runner.world_size,
                self.feat_dim,
                device=get_device(),
            )

        # hand the queue to the loss and reset its use_queue flag
        get_ori_model(runner.model).head.loss_module.queue = self.queue
        get_ori_model(runner.model).head.loss_module.use_queue = False

    def after_train_epoch(self, runner) -> None:
        """Save the queue locally at the configured epoch interval."""
        self.queue = get_ori_model(runner.model).head.loss_module.queue

        if self.queue is not None and self.every_n_epochs(
                runner, self.interval):
            torch.save({'queue': self.queue}, self.queue_path)