Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support.
parent
fcb6258877
commit
80078c47bb
|
@ -7,7 +7,8 @@ pip install -r requirements-sotabench.txt
|
|||
apt-get update
|
||||
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
|
||||
pip uninstall -y pillow
|
||||
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
CFLAGS="${CFLAGS} -mavx2" pip install -U --no-cache-dir --force-reinstall --no-binary :all:--compile https://github.com/mrT23/pillow-simd/zipball/simd/7.0.x
|
||||
#CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
|
||||
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
|
||||
apt-get install wget
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
from .nadam import Nadam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .adamp import AdamP
|
||||
from .adamw import AdamW
|
||||
from .radam import RAdam
|
||||
from .adafactor import Adafactor
|
||||
from .adahessian import Adahessian
|
||||
from .lookahead import Lookahead
|
||||
from .nadam import Nadam
|
||||
from .novograd import NovoGrad
|
||||
from .nvnovograd import NvNovoGrad
|
||||
from .lookahead import Lookahead
|
||||
from .adamp import AdamP
|
||||
from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .sgdp import SGDP
|
||||
from .optim_factory import create_optimizer
|
||||
|
||||
from .optim_factory import create_optimizer
|
|
@ -0,0 +1,174 @@
|
|||
""" Adafactor Optimizer
|
||||
|
||||
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
||||
|
||||
Original header/copyright below.
|
||||
|
||||
"""
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
class Adafactor(torch.optim.Optimizer):
|
||||
"""Implements Adafactor algorithm.
|
||||
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
|
||||
(see https://arxiv.org/abs/1804.04235)
|
||||
|
||||
Note that this optimizer internally adjusts the learning rate depending on the
|
||||
*scale_parameter*, *relative_step* and *warmup_init* options.
|
||||
|
||||
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
||||
`relative_step=False`.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||
lr (float, optional): external learning rate (default: None)
|
||||
eps (tuple[float, float]): regularization constants for square gradient
|
||||
and parameter scale respectively (default: (1e-30, 1e-3))
|
||||
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
|
||||
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
|
||||
beta1 (float): coefficient used for computing running averages of gradient (default: None)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
|
||||
relative_step (bool): if True, time-dependent learning rate is computed
|
||||
instead of external learning rate (default: True)
|
||||
warmup_init (bool): time-dependent learning rate computation depends on
|
||||
whether warm-up initialization is being used (default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
||||
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
||||
relative_step = lr is None
|
||||
if warmup_init and not relative_step:
|
||||
raise ValueError('warmup_init requires relative_step=True')
|
||||
|
||||
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
||||
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
|
||||
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
|
||||
relative_step=relative_step, warmup_init=warmup_init)
|
||||
super(Adafactor, self).__init__(params, defaults)
|
||||
|
||||
@staticmethod
|
||||
def _get_lr(param_group, param_state):
|
||||
if param_group['relative_step']:
|
||||
min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
|
||||
lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
|
||||
param_scale = 1.0
|
||||
if param_group['scale_parameter']:
|
||||
param_scale = max(param_group['eps_scale'], param_state['RMS'])
|
||||
param_group['lr'] = lr_t * param_scale
|
||||
return param_group['lr']
|
||||
|
||||
@staticmethod
|
||||
def _get_options(param_group, param_shape):
|
||||
factored = len(param_shape) >= 2
|
||||
use_first_moment = param_group['beta1'] is not None
|
||||
return factored, use_first_moment
|
||||
|
||||
@staticmethod
|
||||
def _rms(tensor):
|
||||
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||
|
||||
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
|
||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
||||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
||||
return torch.mul(r_factor, c_factor)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.dtype in {torch.float16, torch.bfloat16}:
|
||||
grad = grad.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adafactor does not support sparse gradients.')
|
||||
|
||||
state = self.state[p]
|
||||
grad_shape = grad.shape
|
||||
|
||||
factored, use_first_moment = self._get_options(group, grad_shape)
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
|
||||
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||
else:
|
||||
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||
|
||||
state['RMS'] = 0
|
||||
else:
|
||||
if use_first_moment:
|
||||
state['exp_avg'] = state['exp_avg'].to(grad)
|
||||
if factored:
|
||||
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
||||
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
||||
else:
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
|
||||
|
||||
p_data_fp32 = p.data
|
||||
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
||||
p_data_fp32 = p_data_fp32.float()
|
||||
|
||||
state['step'] += 1
|
||||
state['RMS'] = self._rms(p_data_fp32)
|
||||
lr_t = self._get_lr(group, state)
|
||||
|
||||
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
||||
update = grad ** 2 + group['eps']
|
||||
if factored:
|
||||
exp_avg_sq_row = state['exp_avg_sq_row']
|
||||
exp_avg_sq_col = state['exp_avg_sq_col']
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
|
||||
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
|
||||
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
|
||||
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
|
||||
update.mul_(lr_t)
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
|
||||
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
|
||||
update = exp_avg
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
|
||||
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
||||
p.data.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,156 @@
|
|||
""" AdaHessian Optimizer
|
||||
|
||||
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
|
||||
Originally licensed MIT, Copyright 2020, David Samuel
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
class Adahessian(torch.optim.Optimizer):
|
||||
"""
|
||||
Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||
lr (float, optional): learning rate (default: 0.1)
|
||||
betas ((float, float), optional): coefficients used for computing running averages of gradient and the
|
||||
squared hessian trace (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
|
||||
hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
|
||||
update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
|
||||
(to save time) (default: 1)
|
||||
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
|
||||
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||
if not 0.0 <= hessian_power <= 1.0:
|
||||
raise ValueError(f"Invalid Hessian power value: {hessian_power}")
|
||||
|
||||
self.n_samples = n_samples
|
||||
self.update_each = update_each
|
||||
self.avg_conv_kernel = avg_conv_kernel
|
||||
|
||||
# use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
|
||||
self.seed = 2147483647
|
||||
self.generator = torch.Generator().manual_seed(self.seed)
|
||||
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
|
||||
super(Adahessian, self).__init__(params, defaults)
|
||||
|
||||
for p in self.get_params():
|
||||
p.hess = 0.0
|
||||
self.state[p]["hessian step"] = 0
|
||||
|
||||
@property
|
||||
def is_second_order(self):
|
||||
return True
|
||||
|
||||
def get_params(self):
|
||||
"""
|
||||
Gets all parameters in all param_groups with gradients
|
||||
"""
|
||||
|
||||
return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
|
||||
|
||||
def zero_hessian(self):
|
||||
"""
|
||||
Zeros out the accumalated hessian traces.
|
||||
"""
|
||||
|
||||
for p in self.get_params():
|
||||
if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
|
||||
p.hess.zero_()
|
||||
|
||||
@torch.no_grad()
|
||||
def set_hessian(self):
|
||||
"""
|
||||
Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
|
||||
"""
|
||||
|
||||
params = []
|
||||
for p in filter(lambda p: p.grad is not None, self.get_params()):
|
||||
if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
|
||||
params.append(p)
|
||||
self.state[p]["hessian step"] += 1
|
||||
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
|
||||
self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
|
||||
|
||||
grads = [p.grad for p in params]
|
||||
|
||||
for i in range(self.n_samples):
|
||||
# Rademacher distribution {-1.0, 1.0}
|
||||
zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
|
||||
h_zs = torch.autograd.grad(
|
||||
grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
|
||||
for h_z, z, p in zip(h_zs, zs, params):
|
||||
p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
Arguments:
|
||||
closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
|
||||
"""
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
self.zero_hessian()
|
||||
self.set_hessian()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None or p.hess is None:
|
||||
continue
|
||||
|
||||
if self.avg_conv_kernel and p.dim() == 4:
|
||||
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
|
||||
|
||||
# Perform correct stepweight decay as in AdamW
|
||||
p.mul_(1 - group['lr'] * group['weight_decay'])
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 1:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
# Exponential moving average of Hessian diagonal square values
|
||||
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
|
||||
|
||||
exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
state['step'] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
|
||||
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
k = group['hessian_power']
|
||||
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
|
||||
|
||||
# make update
|
||||
step_size = group['lr'] / bias_correction1
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
return loss
|
|
@ -3,7 +3,18 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
"""
|
||||
import torch
|
||||
from torch import optim as optim
|
||||
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP
|
||||
|
||||
from .adafactor import Adafactor
|
||||
from .adahessian import Adahessian
|
||||
from .adamp import AdamP
|
||||
from .lookahead import Lookahead
|
||||
from .nadam import Nadam
|
||||
from .novograd import NovoGrad
|
||||
from .nvnovograd import NvNovoGrad
|
||||
from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .sgdp import SGDP
|
||||
|
||||
try:
|
||||
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
||||
has_apex = True
|
||||
|
@ -29,11 +40,6 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
|||
def create_optimizer(args, model, filter_bias_and_bn=True):
|
||||
opt_lower = args.opt.lower()
|
||||
weight_decay = args.weight_decay
|
||||
if 'adamw' in opt_lower or 'radam' in opt_lower:
|
||||
# Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
|
||||
# I don't believe they follow the paper or original Torch7 impl which schedules weight
|
||||
# decay based on the ratio of current_lr/initial_lr
|
||||
weight_decay /= args.lr
|
||||
if weight_decay and filter_bias_and_bn:
|
||||
parameters = add_weight_decay(model, weight_decay)
|
||||
weight_decay = 0.
|
||||
|
@ -43,66 +49,59 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|||
if 'fused' in opt_lower:
|
||||
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
||||
|
||||
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
|
||||
if args.opt_eps is not None:
|
||||
opt_args['eps'] = args.opt_eps
|
||||
if args.opt_betas is not None:
|
||||
opt_args['betas'] = args.opt_betas
|
||||
|
||||
opt_split = opt_lower.split('_')
|
||||
opt_lower = opt_split[-1]
|
||||
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
||||
optimizer = optim.SGD(
|
||||
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
||||
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'momentum':
|
||||
optimizer = optim.SGD(
|
||||
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
|
||||
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
||||
elif opt_lower == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = optim.Adam(parameters, **opt_args)
|
||||
elif opt_lower == 'adamw':
|
||||
optimizer = AdamW(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = optim.AdamW(parameters, **opt_args)
|
||||
elif opt_lower == 'nadam':
|
||||
optimizer = Nadam(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = Nadam(parameters, **opt_args)
|
||||
elif opt_lower == 'radam':
|
||||
optimizer = RAdam(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = RAdam(parameters, **opt_args)
|
||||
elif opt_lower == 'adamp':
|
||||
optimizer = AdamP(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
|
||||
delta=0.1, wd_ratio=0.01, nesterov=True)
|
||||
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'sgdp':
|
||||
optimizer = SGDP(
|
||||
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay,
|
||||
eps=args.opt_eps, nesterov=True)
|
||||
optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'adadelta':
|
||||
optimizer = optim.Adadelta(
|
||||
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = optim.Adadelta(parameters, **opt_args)
|
||||
elif opt_lower == 'adafactor':
|
||||
if not args.lr:
|
||||
opt_args['lr'] = None
|
||||
optimizer = Adafactor(parameters, **opt_args)
|
||||
elif opt_lower == 'adahessian':
|
||||
optimizer = Adahessian(parameters, **opt_args)
|
||||
elif opt_lower == 'rmsprop':
|
||||
optimizer = optim.RMSprop(
|
||||
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
|
||||
momentum=args.momentum, weight_decay=weight_decay)
|
||||
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
||||
elif opt_lower == 'rmsproptf':
|
||||
optimizer = RMSpropTF(
|
||||
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
|
||||
momentum=args.momentum, weight_decay=weight_decay)
|
||||
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
||||
elif opt_lower == 'novograd':
|
||||
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = NovoGrad(parameters, **opt_args)
|
||||
elif opt_lower == 'nvnovograd':
|
||||
optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = NvNovoGrad(parameters, **opt_args)
|
||||
elif opt_lower == 'fusedsgd':
|
||||
optimizer = FusedSGD(
|
||||
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
||||
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'fusedmomentum':
|
||||
optimizer = FusedSGD(
|
||||
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
|
||||
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
||||
elif opt_lower == 'fusedadam':
|
||||
optimizer = FusedAdam(
|
||||
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
||||
elif opt_lower == 'fusedadamw':
|
||||
optimizer = FusedAdam(
|
||||
parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
|
||||
elif opt_lower == 'fusedlamb':
|
||||
optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
||||
optimizer = FusedLAMB(parameters, **opt_args)
|
||||
elif opt_lower == 'fusednovograd':
|
||||
optimizer = FusedNovoGrad(
|
||||
parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
|
||||
opt_args.setdefault('betas', (0.95, 0.98))
|
||||
optimizer = FusedNovoGrad(parameters, **opt_args)
|
||||
else:
|
||||
assert False and "Invalid optimizer"
|
||||
raise ValueError
|
||||
|
|
|
@ -15,10 +15,10 @@ except ImportError:
|
|||
class ApexScaler:
|
||||
state_dict_key = "amp"
|
||||
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
if clip_grad:
|
||||
scaled_loss.backward(create_graph=create_graph)
|
||||
if clip_grad is not None:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
|
||||
optimizer.step()
|
||||
|
||||
|
@ -37,9 +37,9 @@ class NativeScaler:
|
|||
def __init__(self):
|
||||
self._scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None):
|
||||
self._scaler.scale(loss).backward()
|
||||
if clip_grad:
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
|
||||
self._scaler.scale(loss).backward(create_graph=create_graph)
|
||||
if clip_grad is not None:
|
||||
assert parameters is not None
|
||||
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
||||
torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
||||
|
|
20
train.py
20
train.py
|
@ -98,12 +98,18 @@ parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, defau
|
|||
# Optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "sgd"')
|
||||
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
|
||||
help='Optimizer Epsilon (default: 1e-8)')
|
||||
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
|
||||
help='Optimizer Epsilon (default: None, use opt default)')
|
||||
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
|
||||
help='Optimizer Betas (default: None, use opt default)')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
||||
help='SGD momentum (default: 0.9)')
|
||||
help='Optimizer momentum (default: 0.9)')
|
||||
parser.add_argument('--weight-decay', type=float, default=0.0001,
|
||||
help='weight decay (default: 0.0001)')
|
||||
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
|
||||
help='Clip gradient norm (default: None, no clipping)')
|
||||
|
||||
|
||||
|
||||
# Learning rate schedule parameters
|
||||
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
||||
|
@ -595,6 +601,7 @@ def train_epoch(
|
|||
elif mixup_fn is not None:
|
||||
mixup_fn.mixup_enabled = False
|
||||
|
||||
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
|
||||
batch_time_m = AverageMeter()
|
||||
data_time_m = AverageMeter()
|
||||
losses_m = AverageMeter()
|
||||
|
@ -623,9 +630,12 @@ def train_epoch(
|
|||
|
||||
optimizer.zero_grad()
|
||||
if loss_scaler is not None:
|
||||
loss_scaler(loss, optimizer)
|
||||
loss_scaler(
|
||||
loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
|
||||
else:
|
||||
loss.backward()
|
||||
loss.backward(create_graph=second_order)
|
||||
if args.clip_grad is not None:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
|
||||
optimizer.step()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
|
Loading…
Reference in New Issue