mirror of https://github.com/JDAI-CV/fast-reid.git

feat(hooks&optim): update stochastic weight averaging hooks

Update the SWA method so that it runs after regular training when this option is enabled.

pull/49/head
parent afac8aad5d
commit 0b15ac4e03

@@ -203,8 +203,9 @@ _C.SOLVER.SWA = CN()
 _C.SOLVER.SWA.ENABLED = False
 _C.SOLVER.SWA.ITER = 0
 _C.SOLVER.SWA.PERIOD = 10
-_C.SOLVER.SWA.LR = 3.5e-5
-_C.SOLVER.SWA.CYCLIC_LR = False
+_C.SOLVER.SWA.LR_FACTOR = 10.
+_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
+_C.SOLVER.SWA.LR_SCHED = False

 _C.SOLVER.CHECKPOINT_PERIOD = 5000

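To make the new options concrete, here is a minimal sketch of how SWA could be switched on from user code. The fastreid.config.get_cfg entry point and the particular values are assumptions for illustration; only the SWA keys themselves come from the hunk above.

from fastreid.config import get_cfg   # assumed config entry point

cfg = get_cfg()
cfg.SOLVER.SWA.ENABLED = True        # turn the SWA hook on
cfg.SOLVER.SWA.PERIOD = 10           # averaging (and warm-restart) period, in iterations
cfg.SOLVER.SWA.LR_FACTOR = 10.       # initial_lr becomes LR_FACTOR * current lr once SWA starts
cfg.SOLVER.SWA.ETA_MIN_LR = 3.5e-6   # floor of the cosine schedule used during SWA
cfg.SOLVER.SWA.LR_SCHED = True       # run CosineAnnealingWarmRestarts while averaging
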
@@ -241,13 +241,16 @@ class DefaultTrainer(SimpleTrainer):
         """
-        # The checkpoint stores the training iteration that just finished, thus we start
-        # at the next iteration (or iter zero if there's no checkpoint).
-        self.start_iter = (
-            self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get(
-                "iteration", -1
-            )
-            + 1
-        )
-        self.data_loader = data_prefetcher(self.cfg, self.train_loader)
+        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+        self.start_iter = checkpoint.get("iteration", -1) if resume else -1
+        # The checkpoint stores the training iteration that just finished, thus we start
+        # at the next iteration (or iter zero if there's no checkpoint).
+        self.start_iter += 1
+
+        if resume:
+            # data prefetcher will preload a batch data, thus we need to reload data loader
+            # because we have updated dataset pid dictionary.
+            self.data_loader = data_prefetcher(self.cfg, self.train_loader)

     def build_hooks(self):
         """

@@ -272,8 +275,9 @@ class DefaultTrainer(SimpleTrainer):
                 hooks.SWA(
                     cfg.SOLVER.MAX_ITER,
                     cfg.SOLVER.SWA.PERIOD,
-                    cfg.SOLVER.SWA.LR,
-                    cfg.SOLVER.SWA.CYCLIC_LR,
+                    cfg.SOLVER.SWA.LR_FACTOR,
+                    cfg.SOLVER.SWA.ETA_MIN_LR,
+                    cfg.SOLVER.SWA.LR_SCHED,
                 )
             )

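Read together with the updated hook signature later in this diff, the positional arguments above map onto the hook's parameters as sketched below; the keyword form is shown only for readability, the commit itself passes them positionally.

# hooks.SWA(swa_start, swa_freq, swa_lr_factor, eta_min, lr_sched=False)
swa_hook = hooks.SWA(
    swa_start=cfg.SOLVER.MAX_ITER,           # SWA kicks in once the regular schedule has finished
    swa_freq=cfg.SOLVER.SWA.PERIOD,          # how often weights are averaged (and the scheduler restarts)
    swa_lr_factor=cfg.SOLVER.SWA.LR_FACTOR,  # boosts the learning rate when SWA begins
    eta_min=cfg.SOLVER.SWA.ETA_MIN_LR,       # minimum lr of the cosine schedule
    lr_sched=cfg.SOLVER.SWA.LR_SCHED,        # whether to run the cosine warm-restart schedule
)
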
@@ -473,34 +473,34 @@ class FreezeLayer(HookBase):


 class SWA(HookBase):
-    def __init__(self, swa_start=None, swa_freq=None, swa_lr=None, cyclic_lr=False,):
+    def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False,):
         self.swa_start = swa_start
         self.swa_freq = swa_freq
-        self.swa_lr = swa_lr
-        self.cyclic_lr = cyclic_lr
+        self.swa_lr_factor = swa_lr_factor
+        self.eta_min = eta_min
+        self.lr_sched = lr_sched

-    def after_step(self):
-        # next_iter = self.trainer.iter + 1
-        next_iter = self.trainer.iter
-        is_swa = next_iter == self.swa_start
+    def before_step(self):
+        is_swa = self.trainer.iter == self.swa_start
         if is_swa:
             # Wrapper optimizer with SWA
-            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq,
-                                               None if self.cyclic_lr else self.swa_lr)
-            if self.cyclic_lr:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
-                    self.trainer.optimizer,
-                    base_lr=self.swa_lr,
-                    max_lr=10*self.swa_lr,
-                    step_size_up=1,
-                    step_size_down=self.swa_freq-1,
-                    cycle_momentum=False,
+            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, self.swa_lr_factor)
+            self.trainer.optimizer.reset_lr_to_swa()
+
+            if self.lr_sched:
+                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                    optimizer=self.trainer.optimizer,
+                    T_0=self.swa_freq,
+                    eta_min=self.eta_min,
                 )

+    def after_step(self):
+        next_iter = self.trainer.iter + 1
+
         # Use Cyclic learning rate scheduler
-        if next_iter > self.swa_start and self.cyclic_lr:
+        if next_iter > self.swa_start and self.lr_sched:
             self.scheduler.step()

         is_final = next_iter == self.trainer.max_iter
         if is_final:
-            self.trainer.optimizer.swap_swa_sgd()
+            self.trainer.optimizer.swap_swa_param()

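For intuition about the schedule the hook builds when lr_sched is enabled, here is a small self-contained sketch of CosineAnnealingWarmRestarts with T_0 playing the role of swa_freq and eta_min the role of ETA_MIN_LR; the toy model and numbers are illustrative only.

import torch

model = torch.nn.Linear(8, 2)                          # toy model, illustration only
opt = torch.optim.SGD(model.parameters(), lr=3.5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=opt,
    T_0=10,          # one cosine cycle per averaging period (swa_freq)
    eta_min=3.5e-6,  # lower bound of the learning rate during SWA
)

for it in range(25):
    opt.step()       # in the trainer this is the regular optimization step
    sched.step()     # mirrors after_step(): one scheduler step per iteration past swa_start
    print(it, opt.param_groups[0]["lr"])
# The printed lr decays from 3.5e-4 towards 3.5e-6 and jumps back up every 10 iterations.
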
@@ -6,14 +6,15 @@
 # based on:
 # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py

+import warnings
 from collections import defaultdict

 import torch
 from torch.optim.optimizer import Optimizer
-import warnings


 class SWA(Optimizer):
-    def __init__(self, optimizer, swa_freq=None, swa_lr=None):
+    def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
         r"""Implements Stochastic Weight Averaging (SWA).
         Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
         Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii

@@ -47,7 +48,7 @@ class SWA(Optimizer):
         >>> opt.zero_grad()
         >>> loss_fn(model(input), target).backward()
         >>> opt.step()
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         >>> # manual mode
         >>> opt = SWA(base_opt)
         >>> for i in range(100):

@@ -56,7 +57,7 @@ class SWA(Optimizer):
         >>>     opt.step()
         >>>     if i > 10 and i % 5 == 0:
         >>>         opt.update_swa()
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         .. note::
             SWA does not support parameter-specific values of :attr:`swa_start`,
             :attr:`swa_freq` or :attr:`swa_lr`. In automatic mode SWA uses the

@@ -87,21 +88,21 @@ class SWA(Optimizer):
         https://arxiv.org/abs/1806.05594
         """
         self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
-        self.swa_lr = swa_lr
+        self.swa_lr_factor = swa_lr_factor

         if self._auto_mode:
             if swa_freq < 1:
                 raise ValueError("Invalid swa_freq: {}".format(swa_freq))
         else:
-            if self.swa_lr is not None:
+            if self.swa_lr_factor is not None:
                 warnings.warn(
                     "Swa_freq is None, ignoring swa_lr")
             # If not in auto mode make all swa parameters None
-            self.swa_lr = None
+            self.swa_lr_factor = None
             self.swa_freq = None

-        if self.swa_lr is not None and self.swa_lr < 0:
-            raise ValueError("Invalid SWA learning rate: {}".format(swa_lr))
+        if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
+            raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor))

         self.optimizer = optimizer

@@ -126,11 +127,9 @@ class SWA(Optimizer):
             warnings.warn("Casting swa_start, swa_freq to int")
         return not any(params_none), params

-    def _reset_lr_to_swa(self):
-        if self.swa_lr is None:
-            return
+    def reset_lr_to_swa(self):
         for param_group in self.param_groups:
-            param_group['lr'] = self.swa_lr
+            param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']

     def update_swa_group(self, group):
         r"""Updates the SWA running averages for the given parameter group.

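A note on why the renamed method writes initial_lr rather than lr: PyTorch's LR scheduler base class only fills in param_group['initial_lr'] when it is not already present, so a value preset here is picked up as the scheduler's base learning rate. Presetting it to swa_lr_factor * lr therefore makes the warm-restart schedule set up by the hook cycle between that boosted rate and eta_min. A tiny sketch of that behaviour, with illustrative numbers:

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=3.5e-4)

# What reset_lr_to_swa() does for each param group, here with swa_lr_factor = 10.
for group in opt.param_groups:
    group["initial_lr"] = 10. * group["lr"]

sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, eta_min=3.5e-6)
print(sched.base_lrs)   # prints the boosted base lr (10 * 3.5e-4): the preset initial_lr was picked up
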
@@ -149,7 +148,7 @@ class SWA(Optimizer):
         >>>     if i > 10 and i % 5 == 0:
         >>>         # Update SWA for the second parameter group
         >>>         opt.update_swa_group(opt.param_groups[1])
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         """
         for p in group['params']:
             param_state = self.state[p]

@@ -167,7 +166,7 @@ class SWA(Optimizer):
         for group in self.param_groups:
             self.update_swa_group(group)

-    def swap_swa_sgd(self):
+    def swap_swa_param(self):
         r"""Swaps the values of the optimized variables and swa buffers.
         It's meant to be called in the end of training to use the collected
         swa running averages. It can also be used to evaluate the running

@@ -192,7 +191,6 @@ class SWA(Optimizer):
         r"""Performs a single optimization step.
         In automatic mode also updates SWA running averages.
         """
-        self._reset_lr_to_swa()
         loss = self.optimizer.step(closure)
         for group in self.param_groups:
             group["step_counter"] += 1

@@ -245,4 +243,4 @@
         """
         param_group['n_avg'] = 0
         param_group['step_counter'] = 0
-        self.optimizer.add_param_group(param_group)
+        self.optimizer.add_param_group(param_group)

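Finally, an end-to-end sketch of the renamed optimizer API as it stands after this commit: wrap a base optimizer with SWA in automatic mode, optionally call reset_lr_to_swa(), and swap in the averaged weights at the end of training. The import path, toy model and loop are assumptions for illustration; only the SWA calls themselves come from the code above.

import torch
from fastreid.solver.optim import SWA   # assumed import path within this repo

model = torch.nn.Linear(10, 2)
base_opt = torch.optim.SGD(model.parameters(), lr=3.5e-4)

# Automatic mode: with swa_freq given, running averages are updated automatically during opt.step().
opt = SWA(base_opt, swa_freq=2, swa_lr_factor=10.)
opt.reset_lr_to_swa()   # presets initial_lr = swa_lr_factor * lr for every param group

for _ in range(10):
    opt.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    opt.step()

# At the end of training, swap the model parameters with the SWA running averages.
opt.swap_swa_param()
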