feat(hooks&optim): update stochastic weight averaging hooks

Update the SWA method so that it runs after regular training when
this option is enabled.
pull/49/head
liaoxingyu 2020-05-08 12:20:04 +08:00
parent afac8aad5d
commit 0b15ac4e03
4 changed files with 50 additions and 47 deletions
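
For context, the new options can be switched on from a training script roughly as sketched below. This is only an illustration: the SOLVER.SWA.* keys come from the diff, the values are made up, and `get_cfg` is assumed to be the usual fastreid config entry point rather than something introduced here.

    from fastreid.config import get_cfg  # assumed entry point for the default config

    cfg = get_cfg()
    cfg.SOLVER.SWA.ENABLED = True       # turn the SWA phase on
    cfg.SOLVER.SWA.PERIOD = 10          # cycle length, passed to the hook as swa_freq / T_0
    cfg.SOLVER.SWA.LR_FACTOR = 10.      # SWA lr = LR_FACTOR * current lr (see reset_lr_to_swa)
    cfg.SOLVER.SWA.ETA_MIN_LR = 3.5e-6  # floor of the cosine annealing cycles
    cfg.SOLVER.SWA.LR_SCHED = True      # step CosineAnnealingWarmRestarts during the SWA phase

SOLVER.SWA.ITER is left at its default here; how it is consumed lies outside the hunks shown below.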


@@ -203,8 +203,9 @@ _C.SOLVER.SWA = CN()
 _C.SOLVER.SWA.ENABLED = False
 _C.SOLVER.SWA.ITER = 0
 _C.SOLVER.SWA.PERIOD = 10
-_C.SOLVER.SWA.LR = 3.5e-5
-_C.SOLVER.SWA.CYCLIC_LR = False
+_C.SOLVER.SWA.LR_FACTOR = 10.
+_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
+_C.SOLVER.SWA.LR_SCHED = False

 _C.SOLVER.CHECKPOINT_PERIOD = 5000


@@ -241,13 +241,16 @@ class DefaultTrainer(SimpleTrainer):
         """
-        # The checkpoint stores the training iteration that just finished, thus we start
-        # at the next iteration (or iter zero if there's no checkpoint).
-        self.start_iter = (
-            self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get(
-                "iteration", -1
-            )
-            + 1
-        )
-        self.data_loader = data_prefetcher(self.cfg, self.train_loader)
+        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+        self.start_iter = checkpoint.get("iteration", -1) if resume else -1
+        # The checkpoint stores the training iteration that just finished, thus we start
+        # at the next iteration (or iter zero if there's no checkpoint).
+        self.start_iter += 1
+        if resume:
+            # data prefetcher will preload a batch data, thus we need to reload data loader
+            # because we have updated dataset pid dictionary.
+            self.data_loader = data_prefetcher(self.cfg, self.train_loader)

     def build_hooks(self):
         """
@@ -272,8 +275,9 @@ class DefaultTrainer(SimpleTrainer):
                 hooks.SWA(
                     cfg.SOLVER.MAX_ITER,
                     cfg.SOLVER.SWA.PERIOD,
-                    cfg.SOLVER.SWA.LR,
-                    cfg.SOLVER.SWA.CYCLIC_LR,
+                    cfg.SOLVER.SWA.LR_FACTOR,
+                    cfg.SOLVER.SWA.ETA_MIN_LR,
+                    cfg.SOLVER.SWA.LR_SCHED,
                 )
             )


@@ -473,34 +473,34 @@ class FreezeLayer(HookBase):
 class SWA(HookBase):
-    def __init__(self, swa_start=None, swa_freq=None, swa_lr=None, cyclic_lr=False,):
+    def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False,):
         self.swa_start = swa_start
         self.swa_freq = swa_freq
-        self.swa_lr = swa_lr
-        self.cyclic_lr = cyclic_lr
+        self.swa_lr_factor = swa_lr_factor
+        self.eta_min = eta_min
+        self.lr_sched = lr_sched

-    def after_step(self):
-        # next_iter = self.trainer.iter + 1
-        next_iter = self.trainer.iter
-        is_swa = next_iter == self.swa_start
+    def before_step(self):
+        is_swa = self.trainer.iter == self.swa_start
         if is_swa:
             # Wrapper optimizer with SWA
-            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq,
-                                               None if self.cyclic_lr else self.swa_lr)
-            if self.cyclic_lr:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
-                    self.trainer.optimizer,
-                    base_lr=self.swa_lr,
-                    max_lr=10*self.swa_lr,
-                    step_size_up=1,
-                    step_size_down=self.swa_freq-1,
-                    cycle_momentum=False,
+            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, self.swa_lr_factor)
+            self.trainer.optimizer.reset_lr_to_swa()
+            if self.lr_sched:
+                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                    optimizer=self.trainer.optimizer,
+                    T_0=self.swa_freq,
+                    eta_min=self.eta_min,
                 )

     def after_step(self):
         next_iter = self.trainer.iter + 1
         # Use Cyclic learning rate scheduler
-        if next_iter > self.swa_start and self.cyclic_lr:
+        if next_iter > self.swa_start and self.lr_sched:
             self.scheduler.step()

         is_final = next_iter == self.trainer.max_iter
         if is_final:
-            self.trainer.optimizer.swap_swa_sgd()
+            self.trainer.optimizer.swap_swa_param()
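
Taken together (and given that build_hooks passes cfg.SOLVER.MAX_ITER as swa_start), the rewritten hook starts the SWA phase where regular training ends and, with LR_SCHED enabled, behaves roughly like the plain loop sketched below. This is an illustration only: `model`, `data`, and `loss_fn` are placeholders, the `optim` import path is an assumption, and only the renamed optimizer API (`optim.SWA`, `reset_lr_to_swa`, `swap_swa_param`) comes from this commit.

    import torch
    from fastreid.solver import optim  # assumed import; the hook simply calls it `optim`

    def swa_training_loop(model, optimizer, data, loss_fn,
                          swa_start, swa_freq, swa_lr_factor, eta_min, max_iter):
        """Sketch of what the trainer plus the SWA hook do end to end."""
        scheduler = None
        for it in range(max_iter):
            if it == swa_start:
                # before_step at swa_start: wrap the optimizer and raise the lr.
                optimizer = optim.SWA(optimizer, swa_freq, swa_lr_factor)
                optimizer.reset_lr_to_swa()
                scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer=optimizer, T_0=swa_freq, eta_min=eta_min)
            inputs, targets = data[it % len(data)]
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()             # auto mode averages the weights every swa_freq steps
            if it + 1 > swa_start and scheduler is not None:
                scheduler.step()         # after_step: advance the cosine schedule
        optimizer.swap_swa_param()       # finally swap the averaged weights into the model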


@@ -6,14 +6,15 @@
 # based on:
 # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
+import warnings
 from collections import defaultdict

 import torch
 from torch.optim.optimizer import Optimizer
-import warnings


 class SWA(Optimizer):
-    def __init__(self, optimizer, swa_freq=None, swa_lr=None):
+    def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
         r"""Implements Stochastic Weight Averaging (SWA).
         Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
         Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
@@ -47,7 +48,7 @@ class SWA(Optimizer):
         >>>     opt.zero_grad()
         >>>     loss_fn(model(input), target).backward()
         >>>     opt.step()
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         >>> # manual mode
         >>> opt = SWA(base_opt)
         >>> for i in range(100):
@@ -56,7 +57,7 @@ class SWA(Optimizer):
         >>>     opt.step()
         >>>     if i > 10 and i % 5 == 0:
         >>>         opt.update_swa()
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         .. note::
             SWA does not support parameter-specific values of :attr:`swa_start`,
             :attr:`swa_freq` or :attr:`swa_lr`. In automatic mode SWA uses the
@@ -87,21 +88,21 @@ class SWA(Optimizer):
             https://arxiv.org/abs/1806.05594
         """
         self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
-        self.swa_lr = swa_lr
+        self.swa_lr_factor = swa_lr_factor

         if self._auto_mode:
             if swa_freq < 1:
                 raise ValueError("Invalid swa_freq: {}".format(swa_freq))
         else:
-            if self.swa_lr is not None:
+            if self.swa_lr_factor is not None:
                 warnings.warn(
                     "Swa_freq is None, ignoring swa_lr")
             # If not in auto mode make all swa parameters None
-            self.swa_lr = None
+            self.swa_lr_factor = None
             self.swa_freq = None

-        if self.swa_lr is not None and self.swa_lr < 0:
-            raise ValueError("Invalid SWA learning rate: {}".format(swa_lr))
+        if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
+            raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor))

         self.optimizer = optimizer
@@ -126,11 +127,9 @@ class SWA(Optimizer):
                 warnings.warn("Casting swa_start, swa_freq to int")
         return not any(params_none), params

-    def _reset_lr_to_swa(self):
-        if self.swa_lr is None:
-            return
+    def reset_lr_to_swa(self):
         for param_group in self.param_groups:
-            param_group['lr'] = self.swa_lr
+            param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']

     def update_swa_group(self, group):
         r"""Updates the SWA running averages for the given parameter group.
@@ -149,7 +148,7 @@ class SWA(Optimizer):
         >>>     if i > 10 and i % 5 == 0:
         >>>         # Update SWA for the second parameter group
         >>>         opt.update_swa_group(opt.param_groups[1])
-        >>> opt.swap_swa_sgd()
+        >>> opt.swap_swa_param()
         """
         for p in group['params']:
             param_state = self.state[p]
@@ -167,7 +166,7 @@ class SWA(Optimizer):
         for group in self.param_groups:
             self.update_swa_group(group)

-    def swap_swa_sgd(self):
+    def swap_swa_param(self):
         r"""Swaps the values of the optimized variables and swa buffers.
         It's meant to be called in the end of training to use the collected
         swa running averages. It can also be used to evaluate the running
@@ -192,7 +191,6 @@ class SWA(Optimizer):
         r"""Performs a single optimization step.
         In automatic mode also updates SWA running averages.
         """
-        self._reset_lr_to_swa()
         loss = self.optimizer.step(closure)
         for group in self.param_groups:
             group["step_counter"] += 1
@@ -245,4 +243,4 @@ class SWA(Optimizer):
         """
         param_group['n_avg'] = 0
         param_group['step_counter'] = 0
-        self.optimizer.add_param_group(param_group)
+        self.optimizer.add_param_group(param_group)
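
A side note on the new reset_lr_to_swa(): it writes `initial_lr` rather than overwriting `lr`, which is also why step() no longer calls the old `_reset_lr_to_swa()`. PyTorch LR schedulers only setdefault `initial_lr` when constructed with last_epoch=-1, so the pre-seeded value survives and CosineAnnealingWarmRestarts anneals each cycle from `swa_lr_factor * lr` down to `eta_min`. A small standalone check of that behaviour (illustrative only, plain torch.optim.SGD rather than the SWA wrapper):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=3.5e-5)

    swa_lr_factor, eta_min, swa_freq = 10., 3.5e-6, 10
    for group in opt.param_groups:          # what reset_lr_to_swa() does
        group['initial_lr'] = swa_lr_factor * group['lr']

    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt, T_0=swa_freq, eta_min=eta_min)
    print(opt.param_groups[0]['lr'])        # 3.5e-04: the cycle starts from the scaled lr
    for _ in range(swa_freq):
        opt.step()
        sched.step()
    print(opt.param_groups[0]['lr'])        # back to 3.5e-04 after one full restart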