From 325d9abb764a33982412e19cffa76075ee40d70c Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Mon, 27 Apr 2020 15:12:01 +0800
Subject: [PATCH] feat($solver): change scheduler call methods

Use the lr scheduler name given in the config to look up and call the
scheduler class directly.

---
 fastreid/solver/build.py        |  37 +++----
 fastreid/solver/lr_scheduler.py |  23 ++--
 fastreid/solver/optim/ranger.py | 181 ++++++++++++++++++++++++++++++--
 3 files changed, 203 insertions(+), 38 deletions(-)

diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py
index b58beaa..92ea616 100644
--- a/fastreid/solver/build.py
+++ b/fastreid/solver/build.py
@@ -34,21 +34,22 @@ def build_optimizer(cfg, model):


 def build_lr_scheduler(cfg, optimizer):
-    if cfg.SOLVER.SCHED == "warmup":
-        return lr_scheduler.WarmupMultiStepLR(
-            optimizer,
-            cfg.SOLVER.STEPS,
-            cfg.SOLVER.GAMMA,
-            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
-            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
-            warmup_method=cfg.SOLVER.WARMUP_METHOD
-        )
-    elif cfg.SOLVER.SCHED == "delay":
-        return lr_scheduler.DelayedCosineAnnealingLR(
-            optimizer,
-            cfg.SOLVER.DELAY_ITERS,
-            cfg.SOLVER.COS_ANNEAL_ITERS,
-            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
-            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
-            warmup_method=cfg.SOLVER.WARMUP_METHOD
-        )
+    scheduler_args = {
+        "optimizer": optimizer,
+
+        # warmup options
+        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
+        "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
+        "warmup_method": cfg.SOLVER.WARMUP_METHOD,
+
+        # multi-step lr scheduler options
+        "milestones": cfg.SOLVER.STEPS,
+        "gamma": cfg.SOLVER.GAMMA,
+
+        # cosine annealing lr scheduler options
+        "max_iters": cfg.SOLVER.MAX_ITER,
+        "delay_iters": cfg.SOLVER.DELAY_ITERS,
+        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
+
+    }
+    return getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
diff --git a/fastreid/solver/lr_scheduler.py b/fastreid/solver/lr_scheduler.py
index 7e66275..d109a71 100644
--- a/fastreid/solver/lr_scheduler.py
+++ b/fastreid/solver/lr_scheduler.py
@@ -10,7 +10,7 @@ from typing import List
 import torch
 from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR

-__all__ = ["WarmupMultiStepLR", "DelayerScheduler"]
+__all__ = ["WarmupMultiStepLR", "DelayedScheduler"]


 class WarmupMultiStepLR(_LRScheduler):
@@ -23,6 +23,7 @@ class WarmupMultiStepLR(_LRScheduler):
         warmup_iters: int = 1000,
         warmup_method: str = "linear",
         last_epoch: int = -1,
+        **kwargs,
     ):
         if not list(milestones) == sorted(milestones):
             raise ValueError(
@@ -76,16 +77,16 @@ def _get_warmup_factor_at_iter(
     raise ValueError("Unknown warmup method: {}".format(method))


-class DelayerScheduler(_LRScheduler):
+class DelayedScheduler(_LRScheduler):
     """ Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
     Args:
         optimizer (Optimizer): Wrapped optimizer.
-        delay_epochs: number of epochs to keep the initial lr until starting applying the scheduler
+        delay_iters: number of iterations to keep the initial lr before starting to apply the scheduler
         after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
     """

-    def __init__(self, optimizer, delay_epochs, after_scheduler, warmup_factor, warmup_iters, warmup_method):
-        self.delay_epochs = delay_epochs
+    def __init__(self, optimizer, delay_iters, after_scheduler, warmup_factor, warmup_iters, warmup_method):
+        self.delay_epochs = delay_iters
         self.after_scheduler = after_scheduler
         self.finished = False
         self.warmup_factor = warmup_factor
@@ -94,7 +95,6 @@ class DelayerScheduler(_LRScheduler):
         super().__init__(optimizer)

     def get_lr(self):
-
         if self.last_epoch >= self.delay_epochs:
             if not self.finished:
                 self.after_scheduler.base_lrs = self.base_lrs
@@ -113,10 +113,11 @@ class DelayerScheduler(_LRScheduler):
             else:
                 self.after_scheduler.step(epoch - self.delay_epochs)
         else:
-            return super(DelayerScheduler, self).step(epoch)
+            return super(DelayedScheduler, self).step(epoch)


-def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs, warmup_factor,
-                             warmup_iters, warmup_method):
-    base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs, eta_min=0)
-    return DelayerScheduler(optimizer, delay_epochs, base_scheduler, warmup_factor, warmup_iters, warmup_method)
+def DelayedCosineAnnealingLR(optimizer, delay_iters, max_iters, eta_min_lr, warmup_factor,
+                             warmup_iters, warmup_method, **kwargs, ):
+    cosine_annealing_iters = max_iters - delay_iters
+    base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_iters, eta_min_lr)
+    return DelayedScheduler(optimizer, delay_iters, base_scheduler, warmup_factor, warmup_iters, warmup_method)
diff --git a/fastreid/solver/optim/ranger.py b/fastreid/solver/optim/ranger.py
index e6fd9f1..62ddf6e 100644
--- a/fastreid/solver/optim/ranger.py
+++ b/fastreid/solver/optim/ranger.py
@@ -1,14 +1,177 @@
-####
-# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
-# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
-####
+# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# and/or
+# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+# Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+# This version = 20.4.11
+
+# Credits:
+# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+# summary of changes:
+# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+# changes 8/31/19 - fix references to *self*.N_sma_threshold;
+# changed eps to 1e-5 as better default than 1e-8.
import math + import torch -from .lookahead import Lookahead -from .radam import RAdam +from torch.optim.optimizer import Optimizer -def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs): - radam = RAdam(params, betas=betas, *args, **kwargs) - return Lookahead(radam, alpha, k) +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. + + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") + if (self.use_gc and self.gc_gradient_threshold == 1): + print(f"GC applied to both conv and fc layers") + elif (self.use_gc and self.gc_gradient_threshold == 3): + print(f"GC applied to conv layers only") + + def __setstate__(self, state): + print("set state called") + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. + # Uncomment if you need to use the actual closure... 
+ + # if closure is not None: + # loss = closure() + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... + # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss
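
Usage note (not part of the patch): with the build.py change, cfg.SOLVER.SCHED must now name a scheduler exported by fastreid.solver.lr_scheduler (e.g. "WarmupMultiStepLR" or "DelayedCosineAnnealingLR") instead of the old "warmup"/"delay" strings, and each scheduler now takes **kwargs so the shared argument dict can be passed to any of them. Below is a minimal sketch of the new name-based dispatch, assuming fastreid with this patch applied is importable; the numeric values are illustrative stand-ins for the cfg.SOLVER.* entries, not project defaults:

    import torch
    from fastreid.solver import lr_scheduler

    def build_lr_scheduler_demo(sched_name, optimizer):
        # Mirrors the rewritten build_lr_scheduler: one argument dict, dispatch by name.
        scheduler_args = {
            "optimizer": optimizer,
            # warmup options
            "warmup_factor": 0.01,        # cfg.SOLVER.WARMUP_FACTOR
            "warmup_iters": 1000,         # cfg.SOLVER.WARMUP_ITERS
            "warmup_method": "linear",    # cfg.SOLVER.WARMUP_METHOD
            # multi-step lr scheduler options
            "milestones": [4000, 9000],   # cfg.SOLVER.STEPS
            "gamma": 0.1,                 # cfg.SOLVER.GAMMA
            # delayed cosine annealing options
            "max_iters": 12000,           # cfg.SOLVER.MAX_ITER
            "delay_iters": 3000,          # cfg.SOLVER.DELAY_ITERS
            "eta_min_lr": 1e-7,           # cfg.SOLVER.ETA_MIN_LR
        }
        # Keys a given scheduler does not use are swallowed by the **kwargs
        # added to each scheduler's signature in this patch.
        return getattr(lr_scheduler, sched_name)(**scheduler_args)

    model = torch.nn.Linear(10, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    sched = build_lr_scheduler_demo("DelayedCosineAnnealingLR", opt)  # or "WarmupMultiStepLR"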
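For the ranger.py rewrite, the file now defines Ranger as a self-contained torch.optim.Optimizer (RAdam update, optional gradient centralization, and a lookahead interpolation every k steps, with the slow weights kept in the per-parameter state dict so they survive state_dict()/load_state_dict()). A minimal usage sketch, again assuming fastreid is importable and a PyTorch version contemporary with the patch (step() uses the legacy add_/addcmul_ argument order); the model and hyper-parameters are illustrative only:

    import torch
    import torch.nn.functional as F
    from fastreid.solver.optim.ranger import Ranger

    model = torch.nn.Linear(128, 10)
    opt = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6,
                 betas=(0.95, 0.999), eps=1e-5, weight_decay=5e-4,
                 use_gc=True, gc_conv_only=False)  # GC applied to conv and fc gradients

    x = torch.randn(32, 128)
    y = torch.randint(0, 10, (32,))
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()  # RAdam update; every k-th step also pulls the weights toward the slow (lookahead) buffer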