# mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import time

import numpy as np
import torch
from mmcv.runner import Hook, get_dist_info

from .registry import HOOKS


def cosine_scheduler(base_value,
                     final_value,
                     epochs,
                     niter_per_ep,
                     warmup_epochs=0,
                     start_warmup_value=0):
    """Build a per-iteration schedule: an optional linear warmup from
    ``start_warmup_value`` to ``base_value`` over ``warmup_epochs`` epochs,
    followed by a cosine decay from ``base_value`` to ``final_value``.

    Returns a numpy array with one value per iteration, of length
    ``epochs * niter_per_ep``.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value,
                                      warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
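

# A minimal usage sketch (the values are illustrative, not from this file):
# DINO ramps its teacher EMA momentum from 0.996 toward 1.0 over the run.
#
#   sched = cosine_scheduler(0.996, 1.0, epochs=100, niter_per_ep=500)
#   assert len(sched) == 100 * 500  # one value per optimizer step
#   sched[0]   # 0.996 at the first iteration
#   sched[-1]  # ~1.0 by the end of training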


@HOOKS.register_module
class DINOHook(Hook):
    """Hook for DINO training.

    Schedules weight decay and the teacher EMA momentum with cosine curves,
    triggers the teacher update each iteration, and logs the running average
    loss aggregated across workers.
    """

    def __init__(self,
                 momentum_teacher=0.996,
                 weight_decay=0.04,
                 weight_decay_end=0.4,
                 **kwargs):
        self.momentum_teacher = momentum_teacher
        self.weight_decay = weight_decay
        self.weight_decay_end = weight_decay_end
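
    # Typical registration in an EasyCV-style config (a sketch; everything
    # except ``type`` and the constructor arguments above is an assumption):
    #   custom_hooks = [
    #       dict(type='DINOHook',
    #            momentum_teacher=0.996,
    #            weight_decay=0.04,
    #            weight_decay_end=0.4)
    #   ]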

    def before_run(self, runner):
        # Let the model finish its own setup (e.g. building the teacher
        # branch) before training starts.
        runner.model.module.init_before_train()

        try:
            self.rank, self.world_size = get_dist_info()
        except Exception:
            self.rank = 0
            self.world_size = 1

        max_progress = runner.max_epochs
        self.epoch_length = len(runner.data_loader[0])
        # Teacher EMA momentum ramps from ``momentum_teacher`` up to 1.
        self.momentum_schedule = cosine_scheduler(self.momentum_teacher, 1,
                                                  max_progress,
                                                  self.epoch_length)
        # Weight decay anneals from ``weight_decay`` to ``weight_decay_end``.
        self.wd_schedule = cosine_scheduler(self.weight_decay,
                                            self.weight_decay_end,
                                            max_progress, self.epoch_length)
        self.optimizer = runner.optimizer
        runner.model.module.this_loss = 0
        runner.model.module.count = 0

        self.epoch_total_loss = 0
        self.count = 0
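
    # After ``before_run`` both schedules hold exactly
    # ``runner.max_epochs * self.epoch_length`` entries, so the per-iteration
    # hooks below can index them directly with ``runner.iter``.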

    def before_train_iter(self, runner):
        cur_iter = runner.iter

        # Apply the scheduled weight decay.
        for i, param_group in enumerate(self.optimizer.param_groups):
            if i == 0:  # only the first group is regularized
                param_group['weight_decay'] = self.wd_schedule[cur_iter]

        # Update the teacher via EMA with the scheduled momentum.
        if cur_iter > 0:
            runner.model.module.momentum_update_key_encoder(
                self.momentum_schedule[cur_iter])
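
    # Sketch of the EMA update behind ``momentum_update_key_encoder`` (the
    # real implementation lives on the model; this is the standard DINO form):
    #   teacher_param = m * teacher_param + (1 - m) * student_param
    # As m approaches 1, the teacher weights gradually stop changing.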

    def after_train_iter(self, runner):
        if self.world_size > 1:
            # Sum sample counts and losses over all workers so every rank
            # sees the same global statistics.
            t = torch.tensor(
                [runner.model.module.count, runner.model.module.this_loss],
                dtype=torch.float64,
                device='cuda')
            torch.distributed.barrier()
            torch.distributed.all_reduce(t)
            t = t.tolist()
            self.count += int(t[0])
            self.epoch_total_loss += t[1]
        else:
            self.count += int(runner.model.module.count)
            self.epoch_total_loss += runner.model.module.this_loss

        if runner.iter % 10 == 0 and self.rank == 0:
            print(' wd : %.4f momentum : %.4f total_avg_loss : %.4f' %
                  (self.wd_schedule[runner.iter],
                   self.momentum_schedule[runner.iter],
                   self.epoch_total_loss / self.count))
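
    # Illustration of the aggregation above on two workers (made-up values):
    #   rank 0 holds t = [count_0, loss_0]; rank 1 holds t = [count_1, loss_1]
    #   all_reduce sums elementwise: t = [count_0 + count_1, loss_0 + loss_1]
    # so ``epoch_total_loss / count`` is the global per-sample average loss.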

    def before_train_epoch(self, runner):
        # Reset the running loss statistics at the start of each epoch.
        self.epoch_total_loss = 0
        self.count = 0
        torch.cuda.empty_cache()
        # Give ``torch.cuda.empty_cache`` time to finish releasing memory.
        time.sleep(3)
        runner.model.module.cur_epoch = runner.epoch