Mirror of https://github.com/JDAI-CV/fast-reid.git (synced 2025-06-03 14:50:47 +08:00)
feat: add swa algorithm
Add SWA and the related config options; when enabled, the model runs SWA after regular training finishes.
parent 9d9a1f4f2d
commit 948af64fd1
fastreid/config/defaults.py
@@ -182,8 +182,8 @@ _C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.

_C.SOLVER.SCHED = "WarmupMultiStepLR"
# Multi-step learning rate options
_C.SOLVER.SCHED = "WarmupMultiStepLR"
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)
@@ -198,6 +198,14 @@ _C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.FREEZE_ITERS = 0

# SWA options
_C.SOLVER.SWA = CN()
_C.SOLVER.SWA.ENABLED = False
_C.SOLVER.SWA.ITER = 0
_C.SOLVER.SWA.PERIOD = 10
_C.SOLVER.SWA.LR = 3.5e-5
_C.SOLVER.SWA.CYCLIC_LR = False

_C.SOLVER.CHECKPOINT_PERIOD = 5000

_C.SOLVER.LOG_PERIOD = 30
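
For reference, a minimal sketch of how these new options might be enabled from a training script; the get_cfg() entry point and the specific values are assumptions for illustration, not taken from this commit:

from fastreid.config import get_cfg   # assumed detectron2-style config entry point

cfg = get_cfg()
cfg.SOLVER.SWA.ENABLED = True     # run SWA after the regular training schedule
cfg.SOLVER.SWA.ITER = 10          # hypothetical value; presumably fed to the SWA hook's swa_start by the trainer (not shown here)
cfg.SOLVER.SWA.PERIOD = 10        # iterations between running-average updates
cfg.SOLVER.SWA.LR = 3.5e-5        # (base) learning rate used during the SWA phase
cfg.SOLVER.SWA.CYCLIC_LR = False  # set True to cycle the LR between LR and 10*LR each period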
fastreid/engine/hooks.py
@@ -4,6 +4,7 @@
import datetime
import itertools
import logging
import warnings
import os
import tempfile
import time
@@ -12,14 +13,15 @@ from collections import Counter
import torch
from torch import nn

from ..evaluation.testing import flatten_results_dict
from ..utils import comm
from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from ..utils.events import EventStorage, EventWriter
from ..utils.file_io import PathManager
from ..utils.precision_bn import update_bn_stats, get_bn_modules
from ..utils.timer import Timer
from .train_loop import HookBase
from fastreid.solver import optim
from fastreid.evaluation.testing import flatten_results_dict
from fastreid.utils import comm
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fastreid.utils.events import EventStorage, EventWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
from fastreid.utils.timer import Timer

__all__ = [
    "CallbackHook",
@@ -468,3 +470,37 @@ class FreezeLayer(HookBase):
        self.model.train()
        for name, param in self.model.named_parameters():
            param.requires_grad = self.param_grad[name]


class SWA(HookBase):
    def __init__(self, swa_start=None, swa_freq=None, swa_lr=None, cyclic_lr=False):
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr
        self.cyclic_lr = cyclic_lr

    def after_step(self):
        # next_iter = self.trainer.iter + 1
        next_iter = self.trainer.iter
        is_swa = next_iter == self.swa_start
        if is_swa:
            # Wrap the optimizer with SWA
            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq,
                                               None if self.cyclic_lr else self.swa_lr)
            if self.cyclic_lr:
                # With step_size_up=1 and step_size_down=swa_freq-1, one cycle spans exactly
                # one SWA period: the LR rises to 10*swa_lr for a step, then decays back to
                # swa_lr before the next running-average update. cycle_momentum is disabled
                # because the wrapped optimizer may not expose a momentum parameter.
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.trainer.optimizer,
                    base_lr=self.swa_lr,
                    max_lr=10 * self.swa_lr,
                    step_size_up=1,
                    step_size_down=self.swa_freq - 1,
                    cycle_momentum=False,
                )

        # Use cyclic learning rate scheduler
        if next_iter > self.swa_start and self.cyclic_lr:
            self.scheduler.step()

        is_final = next_iter == self.trainer.max_iter
        if is_final:
            self.trainer.optimizer.swap_swa_sgd()
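
Below is a hedged sketch of how this hook might be registered from a trainer. build_hooks, self.cfg, and the mapping of SOLVER.SWA.ITER to swa_start are assumptions about the trainer side, which this commit does not show.

# Hypothetical wiring inside a trainer's build_hooks(); names are illustrative only.
def build_hooks(self):
    ret = []
    # ... other hooks (timers, LR scheduler, checkpointer, ...) ...
    if self.cfg.SOLVER.SWA.ENABLED:
        ret.append(hooks.SWA(
            swa_start=self.cfg.SOLVER.SWA.ITER,      # iteration at which the optimizer gets wrapped
            swa_freq=self.cfg.SOLVER.SWA.PERIOD,     # steps between running-average updates
            swa_lr=self.cfg.SOLVER.SWA.LR,           # constant LR, or cyclic base LR, during SWA
            cyclic_lr=self.cfg.SOLVER.SWA.CYCLIC_LR,
        ))
    return ret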
fastreid/solver/optim/__init__.py
@@ -5,5 +5,6 @@ from .over9000 import Over9000, RangerLars
from .radam import RAdam, PlainRAdam, AdamW
from .ralamb import Ralamb
from .ranger import Ranger
from .swa import SWA

from torch.optim import *
fastreid/solver/optim/swa.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py

from collections import defaultdict
import torch
from torch.optim.optimizer import Optimizer
import warnings


class SWA(Optimizer):
    def __init__(self, optimizer, swa_freq=None, swa_lr=None):
        r"""Implements Stochastic Weight Averaging (SWA).
        Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
        Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
        Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
        (UAI 2018).
        SWA is implemented as a wrapper class taking an optimizer instance as
        input and applying SWA on top of that optimizer.
        SWA can be used in two modes: automatic and manual. In the automatic
        mode SWA running averages are automatically updated every
        :attr:`swa_freq` steps of optimization. If :attr:`swa_lr` is provided,
        the learning rate of the optimizer is reset to :attr:`swa_lr` at every
        step. To use SWA in automatic mode provide a value for the
        :attr:`swa_freq` argument.
        Alternatively, in the manual mode, use :meth:`update_swa` or
        :meth:`update_swa_group` methods to update the SWA running averages.
        At the end of training use the :meth:`swap_swa_sgd` method to set the
        optimized variables to the computed averages.
        Args:
            optimizer (torch.optim.Optimizer): optimizer to wrap with SWA
            swa_freq (int): number of steps between subsequent updates of
                SWA running averages in automatic mode; if None, manual mode is
                selected (default: None)
            swa_lr (float): learning rate to use in automatic mode; if None,
                the learning rate is not changed (default: None)
        Examples:
            >>> # automatic mode
            >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
            >>> opt = SWA(base_opt, swa_freq=5, swa_lr=0.05)
            >>> for _ in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>> opt.swap_swa_sgd()
            >>> # manual mode
            >>> opt = SWA(base_opt)
            >>> for i in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>>     if i > 10 and i % 5 == 0:
            >>>         opt.update_swa()
            >>> opt.swap_swa_sgd()
        .. note::
            SWA does not support parameter-specific values of :attr:`swa_freq`
            or :attr:`swa_lr`. In automatic mode SWA uses the same
            :attr:`swa_freq` and :attr:`swa_lr` for all parameter groups. If
            needed, use manual mode with :meth:`update_swa_group` to use
            different update schedules for different parameter groups.
        .. note::
            Call :meth:`swap_swa_sgd` at the end of training to use the
            computed running averages.
        .. note::
            If you are using SWA to optimize the parameters of a Neural Network
            containing Batch Normalization layers, you need to update the
            :attr:`running_mean` and :attr:`running_var` statistics of the
            Batch Normalization module. You can do so by using
            `torchcontrib.optim.swa.bn_update` or the `update_bn_stats` utility
            in `fastreid.utils.precision_bn`.
        .. note::
            See the blogpost
            https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
            for an extended description of this SWA implementation.
        .. note::
            The repo https://github.com/izmailovpavel/contrib_swa_examples
            contains examples of using this SWA implementation.
        .. _Averaging Weights Leads to Wider Optima and Better Generalization:
            https://arxiv.org/abs/1803.05407
        .. _Improving Consistency-Based Semi-Supervised Learning with Weight
            Averaging:
            https://arxiv.org/abs/1806.05594
        """
        self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
        self.swa_lr = swa_lr

        if self._auto_mode:
            if swa_freq < 1:
                raise ValueError("Invalid swa_freq: {}".format(swa_freq))
        else:
            if self.swa_lr is not None:
                warnings.warn(
                    "Swa_freq is None, ignoring swa_lr")
            # If not in auto mode make all swa parameters None
            self.swa_lr = None
            self.swa_freq = None

        if self.swa_lr is not None and self.swa_lr < 0:
            raise ValueError("Invalid SWA learning rate: {}".format(swa_lr))

        self.optimizer = optimizer

        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state
        for group in self.param_groups:
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_freq):
        params = [swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn(
                "Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_freq to int")
        return not any(params_none), params

    def _reset_lr_to_swa(self):
        if self.swa_lr is None:
            return
        for param_group in self.param_groups:
            param_group['lr'] = self.swa_lr

    def update_swa_group(self, group):
        r"""Updates the SWA running averages for the given parameter group.
        Arguments:
            group (dict): Specifies for what parameter group SWA running
                averages should be updated
        Examples:
            >>> # manual mode
            >>> base_opt = torch.optim.SGD([{'params': [x]},
            >>>             {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9)
            >>> opt = SWA(base_opt)
            >>> for i in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>>     if i > 10 and i % 5 == 0:
            >>>         # Update SWA for the second parameter group
            >>>         opt.update_swa_group(opt.param_groups[1])
            >>> opt.swap_swa_sgd()
        """
        for p in group['params']:
            param_state = self.state[p]
            if 'swa_buffer' not in param_state:
                param_state['swa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['swa_buffer']
            # Incremental mean: after n updates the buffer equals the average of the
            # n parameter snapshots seen so far, since buf_n = buf_{n-1} + (p_n - buf_{n-1}) / n.
            virtual_decay = 1 / float(group["n_avg"] + 1)
            diff = (p.data - buf) * virtual_decay
            buf.add_(diff)
        group["n_avg"] += 1

    def update_swa(self):
        r"""Updates the SWA running averages of all optimized parameters.
        """
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_sgd(self):
        r"""Swaps the values of the optimized variables and swa buffers.
        It's meant to be called at the end of training to use the collected
        swa running averages. It can also be used to evaluate the running
        averages during training; to continue training `swap_swa_sgd`
        should be called again.
        """
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'swa_buffer' not in param_state:
                    # If swa wasn't applied we don't swap params
                    warnings.warn(
                        "SWA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['swa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

    def step(self, closure=None):
        r"""Performs a single optimization step.
        In automatic mode also updates SWA running averages.
        """
        self._reset_lr_to_swa()
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps % self.swa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        r"""Returns the state of SWA as a :class:`dict`.
        It contains three entries:
            * opt_state - a dict holding current optimization state of the base
                optimizer. Its content differs between optimizer classes.
            * swa_state - a dict containing current state of SWA. For each
                optimized variable it contains swa_buffer keeping the running
                average of the variable
            * param_groups - a dict containing all parameter groups
        """
        opt_state_dict = self.optimizer.state_dict()
        swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                     for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "swa_state": swa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.
        Args:
            state_dict (dict): SWA optimizer state. Should be an object returned
                from a call to `state_dict`.
        """
        swa_state_dict = {"state": state_dict["swa_state"],
                          "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(SWA, self).load_state_dict(swa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.
        Args:
            param_group (dict): Specifies what Tensors should be optimized along
                with group specific optimization options.
        """
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)
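
Finally, a self-contained usage sketch (not part of this commit) that exercises the wrapper in automatic mode on a toy problem; the model, data, and hyperparameters are made up for illustration.

import torch
from torch import nn
from fastreid.solver.optim import SWA

model = nn.Linear(10, 1)
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = SWA(base_opt, swa_freq=5, swa_lr=0.05)   # average every 5 steps, reset LR to 0.05

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(100):
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()                                 # automatic mode: averages are updated inside step()

opt.swap_swa_sgd()                             # copy the SWA averages into the model parameters
# If the model contains BatchNorm layers, re-estimate their running statistics afterwards.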