mirror of https://github.com/JDAI-CV/fast-reid.git
feat($solver): change scheduler call methods
using name of lr scheduler in config to callpull/44/head
parent
9e3f2c1e7a
commit
325d9abb76
|
@ -34,21 +34,22 @@ def build_optimizer(cfg, model):
|
||||||
|
|
||||||
|
|
||||||
def build_lr_scheduler(cfg, optimizer):
|
def build_lr_scheduler(cfg, optimizer):
|
||||||
if cfg.SOLVER.SCHED == "warmup":
|
scheduler_args = {
|
||||||
return lr_scheduler.WarmupMultiStepLR(
|
"optimizer": optimizer,
|
||||||
optimizer,
|
|
||||||
cfg.SOLVER.STEPS,
|
# warmup options
|
||||||
cfg.SOLVER.GAMMA,
|
"warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
|
||||||
warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
|
"warmup_iters": cfg.SOLVER.WARMUP_ITERS,
|
||||||
warmup_iters=cfg.SOLVER.WARMUP_ITERS,
|
"warmup_method": cfg.SOLVER.WARMUP_METHOD,
|
||||||
warmup_method=cfg.SOLVER.WARMUP_METHOD
|
|
||||||
)
|
# multi-step lr scheduler options
|
||||||
elif cfg.SOLVER.SCHED == "delay":
|
"milestones": cfg.SOLVER.STEPS,
|
||||||
return lr_scheduler.DelayedCosineAnnealingLR(
|
"gamma": cfg.SOLVER.GAMMA,
|
||||||
optimizer,
|
|
||||||
cfg.SOLVER.DELAY_ITERS,
|
# cosine annealing lr scheduler options
|
||||||
cfg.SOLVER.COS_ANNEAL_ITERS,
|
"max_iters": cfg.SOLVER.MAX_ITER,
|
||||||
warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
|
"delay_iters": cfg.SOLVER.DELAY_ITERS,
|
||||||
warmup_iters=cfg.SOLVER.WARMUP_ITERS,
|
"eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
|
||||||
warmup_method=cfg.SOLVER.WARMUP_METHOD
|
|
||||||
)
|
}
|
||||||
|
return getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from typing import List
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
|
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
|
||||||
|
|
||||||
__all__ = ["WarmupMultiStepLR", "DelayerScheduler"]
|
__all__ = ["WarmupMultiStepLR", "DelayedScheduler"]
|
||||||
|
|
||||||
|
|
||||||
class WarmupMultiStepLR(_LRScheduler):
|
class WarmupMultiStepLR(_LRScheduler):
|
||||||
|
@ -23,6 +23,7 @@ class WarmupMultiStepLR(_LRScheduler):
|
||||||
warmup_iters: int = 1000,
|
warmup_iters: int = 1000,
|
||||||
warmup_method: str = "linear",
|
warmup_method: str = "linear",
|
||||||
last_epoch: int = -1,
|
last_epoch: int = -1,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not list(milestones) == sorted(milestones):
|
if not list(milestones) == sorted(milestones):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -76,16 +77,16 @@ def _get_warmup_factor_at_iter(
|
||||||
raise ValueError("Unknown warmup method: {}".format(method))
|
raise ValueError("Unknown warmup method: {}".format(method))
|
||||||
|
|
||||||
|
|
||||||
class DelayerScheduler(_LRScheduler):
|
class DelayedScheduler(_LRScheduler):
|
||||||
""" Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
|
""" Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): Wrapped optimizer.
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
delay_epochs: number of epochs to keep the initial lr until starting aplying the scheduler
|
delay_iters: number of epochs to keep the initial lr until starting applying the scheduler
|
||||||
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, delay_epochs, after_scheduler, warmup_factor, warmup_iters, warmup_method):
|
def __init__(self, optimizer, delay_iters, after_scheduler, warmup_factor, warmup_iters, warmup_method):
|
||||||
self.delay_epochs = delay_epochs
|
self.delay_epochs = delay_iters
|
||||||
self.after_scheduler = after_scheduler
|
self.after_scheduler = after_scheduler
|
||||||
self.finished = False
|
self.finished = False
|
||||||
self.warmup_factor = warmup_factor
|
self.warmup_factor = warmup_factor
|
||||||
|
@ -94,7 +95,6 @@ class DelayerScheduler(_LRScheduler):
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
|
|
||||||
if self.last_epoch >= self.delay_epochs:
|
if self.last_epoch >= self.delay_epochs:
|
||||||
if not self.finished:
|
if not self.finished:
|
||||||
self.after_scheduler.base_lrs = self.base_lrs
|
self.after_scheduler.base_lrs = self.base_lrs
|
||||||
|
@ -113,10 +113,11 @@ class DelayerScheduler(_LRScheduler):
|
||||||
else:
|
else:
|
||||||
self.after_scheduler.step(epoch - self.delay_epochs)
|
self.after_scheduler.step(epoch - self.delay_epochs)
|
||||||
else:
|
else:
|
||||||
return super(DelayerScheduler, self).step(epoch)
|
return super(DelayedScheduler, self).step(epoch)
|
||||||
|
|
||||||
|
|
||||||
def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs, warmup_factor,
|
def DelayedCosineAnnealingLR(optimizer, delay_iters, max_iters, eta_min_lr, warmup_factor,
|
||||||
warmup_iters, warmup_method):
|
warmup_iters, warmup_method, **kwargs, ):
|
||||||
base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs, eta_min=0)
|
cosine_annealing_iters = max_iters - delay_iters
|
||||||
return DelayerScheduler(optimizer, delay_epochs, base_scheduler, warmup_factor, warmup_iters, warmup_method)
|
base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_iters, eta_min_lr)
|
||||||
|
return DelayedScheduler(optimizer, delay_iters, base_scheduler, warmup_factor, warmup_iters, warmup_method)
|
||||||
|
|
|
@ -1,14 +1,177 @@
|
||||||
####
|
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
|
||||||
# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
|
|
||||||
# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
|
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
|
||||||
####
|
# and/or
|
||||||
|
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
|
||||||
|
|
||||||
|
# Ranger has now been used to capture 12 records on the FastAI leaderboard.
|
||||||
|
|
||||||
|
# This version = 20.4.11
|
||||||
|
|
||||||
|
# Credits:
|
||||||
|
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
|
||||||
|
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
|
||||||
|
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
|
||||||
|
# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
|
||||||
|
|
||||||
|
# summary of changes:
|
||||||
|
# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
|
||||||
|
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
|
||||||
|
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
|
||||||
|
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
|
||||||
|
# changed eps to 1e-5 as better default than 1e-8.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from .lookahead import Lookahead
|
from torch.optim.optimizer import Optimizer
|
||||||
from .radam import RAdam
|
|
||||||
|
|
||||||
|
|
||||||
def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):
|
class Ranger(Optimizer):
|
||||||
radam = RAdam(params, betas=betas, *args, **kwargs)
|
|
||||||
return Lookahead(radam, alpha, k)
|
def __init__(self, params, lr=1e-3, # lr
|
||||||
|
alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
|
||||||
|
betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
|
||||||
|
use_gc=True, gc_conv_only=False
|
||||||
|
# Gradient centralization on or off, applied to conv layers only or conv + fc layers
|
||||||
|
):
|
||||||
|
|
||||||
|
# parameter checks
|
||||||
|
if not 0.0 <= alpha <= 1.0:
|
||||||
|
raise ValueError(f'Invalid slow update rate: {alpha}')
|
||||||
|
if not 1 <= k:
|
||||||
|
raise ValueError(f'Invalid lookahead steps: {k}')
|
||||||
|
if not lr > 0:
|
||||||
|
raise ValueError(f'Invalid Learning Rate: {lr}')
|
||||||
|
if not eps > 0:
|
||||||
|
raise ValueError(f'Invalid eps: {eps}')
|
||||||
|
|
||||||
|
# parameter comments:
|
||||||
|
# beta1 (momentum) of .95 seems to work better than .90...
|
||||||
|
# N_sma_threshold of 5 seems better in testing than 4.
|
||||||
|
# In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
|
||||||
|
|
||||||
|
# prep defaults and init torch.optim base
|
||||||
|
defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
|
||||||
|
eps=eps, weight_decay=weight_decay)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
# adjustable threshold
|
||||||
|
self.N_sma_threshhold = N_sma_threshhold
|
||||||
|
|
||||||
|
# look ahead params
|
||||||
|
|
||||||
|
self.alpha = alpha
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
# radam buffer for state
|
||||||
|
self.radam_buffer = [[None, None, None] for ind in range(10)]
|
||||||
|
|
||||||
|
# gc on or off
|
||||||
|
self.use_gc = use_gc
|
||||||
|
|
||||||
|
# level of gradient centralization
|
||||||
|
self.gc_gradient_threshold = 3 if gc_conv_only else 1
|
||||||
|
|
||||||
|
print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
|
||||||
|
if (self.use_gc and self.gc_gradient_threshold == 1):
|
||||||
|
print(f"GC applied to both conv and fc layers")
|
||||||
|
elif (self.use_gc and self.gc_gradient_threshold == 3):
|
||||||
|
print(f"GC applied to conv layers only")
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
print("set state called")
|
||||||
|
super(Ranger, self).__setstate__(state)
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
loss = None
|
||||||
|
# note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
|
||||||
|
# Uncomment if you need to use the actual closure...
|
||||||
|
|
||||||
|
# if closure is not None:
|
||||||
|
# loss = closure()
|
||||||
|
|
||||||
|
# Evaluate averages and grad, update param tensors
|
||||||
|
for group in self.param_groups:
|
||||||
|
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad.data.float()
|
||||||
|
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('Ranger optimizer does not support sparse gradients')
|
||||||
|
|
||||||
|
p_data_fp32 = p.data.float()
|
||||||
|
|
||||||
|
state = self.state[p] # get state dict for this param
|
||||||
|
|
||||||
|
if len(state) == 0: # if first time to run...init dictionary with our desired entries
|
||||||
|
# if self.first_run_check==0:
|
||||||
|
# self.first_run_check=1
|
||||||
|
# print("Initializing slow buffer...should not see this at load from saved model!")
|
||||||
|
state['step'] = 0
|
||||||
|
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
||||||
|
|
||||||
|
# look ahead weight storage now in state dict
|
||||||
|
state['slow_buffer'] = torch.empty_like(p.data)
|
||||||
|
state['slow_buffer'].copy_(p.data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
||||||
|
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
||||||
|
|
||||||
|
# begin computations
|
||||||
|
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
# GC operation for Conv layers and FC layers
|
||||||
|
if grad.dim() > self.gc_gradient_threshold:
|
||||||
|
grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
# compute variance mov avg
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||||
|
# compute mean moving avg
|
||||||
|
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||||
|
|
||||||
|
buffered = self.radam_buffer[int(state['step'] % 10)]
|
||||||
|
|
||||||
|
if state['step'] == buffered[0]:
|
||||||
|
N_sma, step_size = buffered[1], buffered[2]
|
||||||
|
else:
|
||||||
|
buffered[0] = state['step']
|
||||||
|
beta2_t = beta2 ** state['step']
|
||||||
|
N_sma_max = 2 / (1 - beta2) - 1
|
||||||
|
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
||||||
|
buffered[1] = N_sma
|
||||||
|
if N_sma > self.N_sma_threshhold:
|
||||||
|
step_size = math.sqrt(
|
||||||
|
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
|
||||||
|
N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
||||||
|
else:
|
||||||
|
step_size = 1.0 / (1 - beta1 ** state['step'])
|
||||||
|
buffered[2] = step_size
|
||||||
|
|
||||||
|
if group['weight_decay'] != 0:
|
||||||
|
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
||||||
|
|
||||||
|
# apply lr
|
||||||
|
if N_sma > self.N_sma_threshhold:
|
||||||
|
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||||
|
p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
|
||||||
|
else:
|
||||||
|
p_data_fp32.add_(-step_size * group['lr'], exp_avg)
|
||||||
|
|
||||||
|
p.data.copy_(p_data_fp32)
|
||||||
|
|
||||||
|
# integrated look ahead...
|
||||||
|
# we do it at the param level instead of group level
|
||||||
|
if state['step'] % group['k'] == 0:
|
||||||
|
slow_p = state['slow_buffer'] # get access to slow param tensor
|
||||||
|
slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
|
||||||
|
p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
Loading…
Reference in New Issue