From 0b15ac4e03a37ac5b8982a531ebd13503ffd9de6 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Fri, 8 May 2020 12:20:04 +0800 Subject: [PATCH] feat(hooks&optim): update stochastic weight averging hooks Update swa method which will do after regular training if you set this option enabled. --- fastreid/config/defaults.py | 5 +++-- fastreid/engine/defaults.py | 22 ++++++++++++--------- fastreid/engine/hooks.py | 38 ++++++++++++++++++------------------ fastreid/solver/optim/swa.py | 32 ++++++++++++++---------------- 4 files changed, 50 insertions(+), 47 deletions(-) diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index 6cb76bc..1a28cda 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -203,8 +203,9 @@ _C.SOLVER.SWA = CN() _C.SOLVER.SWA.ENABLED = False _C.SOLVER.SWA.ITER = 0 _C.SOLVER.SWA.PERIOD = 10 -_C.SOLVER.SWA.LR = 3.5e-5 -_C.SOLVER.SWA.CYCLIC_LR = False +_C.SOLVER.SWA.LR_FACTOR = 10. +_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6 +_C.SOLVER.SWA.LR_SCHED = False _C.SOLVER.CHECKPOINT_PERIOD = 5000 diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index 06cf5ed..d736f19 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -241,13 +241,16 @@ class DefaultTrainer(SimpleTrainer): """ # The checkpoint stores the training iteration that just finished, thus we start # at the next iteration (or iter zero if there's no checkpoint). - self.start_iter = ( - self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get( - "iteration", -1 - ) - + 1 - ) - self.data_loader = data_prefetcher(self.cfg, self.train_loader) + checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) + self.start_iter = checkpoint.get("iteration", -1) if resume else -1 + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration (or iter zero if there's no checkpoint). + self.start_iter += 1 + + if resume: + # data prefetcher will preload a batch data, thus we need to reload data loader + # because we have updated dataset pid dictionary. + self.data_loader = data_prefetcher(self.cfg, self.train_loader) def build_hooks(self): """ @@ -272,8 +275,9 @@ class DefaultTrainer(SimpleTrainer): hooks.SWA( cfg.SOLVER.MAX_ITER, cfg.SOLVER.SWA.PERIOD, - cfg.SOLVER.SWA.LR, - cfg.SOLVER.SWA.CYCLIC_LR, + cfg.SOLVER.SWA.LR_FACTOR, + cfg.SOLVER.SWA.ETA_MIN_LR, + cfg.SOLVER.SWA.LR_SCHED, ) ) diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py index 209f78b..85edce2 100644 --- a/fastreid/engine/hooks.py +++ b/fastreid/engine/hooks.py @@ -473,34 +473,34 @@ class FreezeLayer(HookBase): class SWA(HookBase): - def __init__(self, swa_start=None, swa_freq=None, swa_lr=None, cyclic_lr=False,): + def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False,): self.swa_start = swa_start self.swa_freq = swa_freq - self.swa_lr = swa_lr - self.cyclic_lr = cyclic_lr + self.swa_lr_factor = swa_lr_factor + self.eta_min = eta_min + self.lr_sched = lr_sched - def after_step(self): - # next_iter = self.trainer.iter + 1 - next_iter = self.trainer.iter - is_swa = next_iter == self.swa_start + def before_step(self): + is_swa = self.trainer.iter == self.swa_start if is_swa: # Wrapper optimizer with SWA - self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, - None if self.cyclic_lr else self.swa_lr) - if self.cyclic_lr: - self.scheduler = torch.optim.lr_scheduler.CyclicLR( - self.trainer.optimizer, - base_lr=self.swa_lr, - max_lr=10*self.swa_lr, - step_size_up=1, - step_size_down=self.swa_freq-1, - cycle_momentum=False, + self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, self.swa_lr_factor) + self.trainer.optimizer.reset_lr_to_swa() + + if self.lr_sched: + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=self.trainer.optimizer, + T_0=self.swa_freq, + eta_min=self.eta_min, ) + def after_step(self): + next_iter = self.trainer.iter + 1 + # Use Cyclic learning rate scheduler - if next_iter > self.swa_start and self.cyclic_lr: + if next_iter > self.swa_start and self.lr_sched: self.scheduler.step() is_final = next_iter == self.trainer.max_iter if is_final: - self.trainer.optimizer.swap_swa_sgd() + self.trainer.optimizer.swap_swa_param() diff --git a/fastreid/solver/optim/swa.py b/fastreid/solver/optim/swa.py index 2200038..889239a 100644 --- a/fastreid/solver/optim/swa.py +++ b/fastreid/solver/optim/swa.py @@ -6,14 +6,15 @@ # based on: # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py +import warnings from collections import defaultdict + import torch from torch.optim.optimizer import Optimizer -import warnings class SWA(Optimizer): - def __init__(self, optimizer, swa_freq=None, swa_lr=None): + def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None): r"""Implements Stochastic Weight Averaging (SWA). Stochastic Weight Averaging was proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii @@ -47,7 +48,7 @@ class SWA(Optimizer): >>> opt.zero_grad() >>> loss_fn(model(input), target).backward() >>> opt.step() - >>> opt.swap_swa_sgd() + >>> opt.swap_swa_param() >>> # manual mode >>> opt = SWA(base_opt) >>> for i in range(100): @@ -56,7 +57,7 @@ class SWA(Optimizer): >>> opt.step() >>> if i > 10 and i % 5 == 0: >>> opt.update_swa() - >>> opt.swap_swa_sgd() + >>> opt.swap_swa_param() .. note:: SWA does not support parameter-specific values of :attr:`swa_start`, :attr:`swa_freq` or :attr:`swa_lr`. In automatic mode SWA uses the @@ -87,21 +88,21 @@ class SWA(Optimizer): https://arxiv.org/abs/1806.05594 """ self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq) - self.swa_lr = swa_lr + self.swa_lr_factor = swa_lr_factor if self._auto_mode: if swa_freq < 1: raise ValueError("Invalid swa_freq: {}".format(swa_freq)) else: - if self.swa_lr is not None: + if self.swa_lr_factor is not None: warnings.warn( "Swa_freq is None, ignoring swa_lr") # If not in auto mode make all swa parameters None - self.swa_lr = None + self.swa_lr_factor = None self.swa_freq = None - if self.swa_lr is not None and self.swa_lr < 0: - raise ValueError("Invalid SWA learning rate: {}".format(swa_lr)) + if self.swa_lr_factor is not None and self.swa_lr_factor < 0: + raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor)) self.optimizer = optimizer @@ -126,11 +127,9 @@ class SWA(Optimizer): warnings.warn("Casting swa_start, swa_freq to int") return not any(params_none), params - def _reset_lr_to_swa(self): - if self.swa_lr is None: - return + def reset_lr_to_swa(self): for param_group in self.param_groups: - param_group['lr'] = self.swa_lr + param_group['initial_lr'] = self.swa_lr_factor * param_group['lr'] def update_swa_group(self, group): r"""Updates the SWA running averages for the given parameter group. @@ -149,7 +148,7 @@ class SWA(Optimizer): >>> if i > 10 and i % 5 == 0: >>> # Update SWA for the second parameter group >>> opt.update_swa_group(opt.param_groups[1]) - >>> opt.swap_swa_sgd() + >>> opt.swap_swa_param() """ for p in group['params']: param_state = self.state[p] @@ -167,7 +166,7 @@ class SWA(Optimizer): for group in self.param_groups: self.update_swa_group(group) - def swap_swa_sgd(self): + def swap_swa_param(self): r"""Swaps the values of the optimized variables and swa buffers. It's meant to be called in the end of training to use the collected swa running averages. It can also be used to evaluate the running @@ -192,7 +191,6 @@ class SWA(Optimizer): r"""Performs a single optimization step. In automatic mode also updates SWA running averages. """ - self._reset_lr_to_swa() loss = self.optimizer.step(closure) for group in self.param_groups: group["step_counter"] += 1 @@ -245,4 +243,4 @@ class SWA(Optimizer): """ param_group['n_avg'] = 0 param_group['step_counter'] = 0 - self.optimizer.add_param_group(param_group) \ No newline at end of file + self.optimizer.add_param_group(param_group)