[Feature] Support LAMB optimizer. (#591)

* impl lamb

* Add unit tests

* Fix unit test

* Fix unit tests

* Use list instead of tuple in `__all__` according to PEP8

Co-authored-by: mzr1996 <mzr1996@163.com>
Zhicheng Chen 2021-12-13 17:24:44 +08:00 committed by GitHub
parent 188aa6ed5d
commit 851b438574
4 changed files with 545 additions and 0 deletions
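Before the diff itself, here is a minimal sketch of how the newly registered optimizer is meant to be used. It mirrors the configuration exercised by the unit test at the end of this commit; the toy model is made up purely for illustration.

import torch.nn as nn
from mmcv.runner import build_optimizer

import mmcls.core  # noqa: F401  # importing mmcls.core registers Lamb in OPTIMIZERS

model = nn.Conv2d(3, 8, kernel_size=3)  # any nn.Module works for the sketch
optimizer_cfg = dict(
    type='Lamb',
    lr=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0001)
optimizer = build_optimizer(model, optimizer_cfg)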


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
from .fp16 import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .utils import * # noqa: F401, F403


@@ -0,0 +1,5 @@
from .lamb import Lamb
__all__ = [
'Lamb',
]


@@ -0,0 +1,231 @@
"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/
2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). This variant of Lamb is included to provide
behaviour similar to APEX FusedLamb for users who aren't on NVIDIA GPUs or
cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support
PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import torch
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module()
class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer
from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/
PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training
BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether to apply (1 - beta1) to the grad
when calculating running averages of the gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (bool, optional): whether to always apply the layer-wise LR
adaptation, even to parameters with a weight decay of 0.0. (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76
minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
grad_averaging=True,
max_grad_norm=1.0,
trust_clip=False,
always_adapt=False):
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm,
trust_clip=trust_clip,
always_adapt=always_adapt)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]['params'][0].device
one_tensor = torch.tensor(
1.0, device=device
) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
'Lamb does not support sparse gradients, consider '
'SparseAdam instead.')
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
# FIXME it'd be nice to remove explicit tensor conversion of scalars
# when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm = torch.tensor(
self.defaults['max_grad_norm'], device=device)
clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
global_grad_norm / max_grad_norm,
one_tensor)
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
beta3 = 1 - beta1 if grad_averaging else 1.0
# assume same step across group now to simplify things
# per-parameter step can easily be supported by making it a tensor, or
# by passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
if bias_correction:
bias_correction1 = 1 - beta1**group['step']
bias_correction2 = 1 - beta2**group['step']
else:
bias_correction1, bias_correction2 = 1.0, 1.0
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group['always_adapt']:
# Layer-wise LR adaptation. By default, adaptation is skipped
# for parameters that are excluded from weight decay, unless
# always_adapt == True, in which case it is always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not
# working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group['lr'])
return loss
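
For readers skimming the diff: the per-parameter update performed by step() above boils down to an Adam-style update rescaled by a layer-wise trust ratio. The following hypothetical helper is a simplified, self-contained restatement of that core rule; it deliberately ignores the global gradient-norm clipping and the grad_averaging, trust_clip and always_adapt options handled by the real implementation.

import torch


def lamb_update(p, grad, exp_avg, exp_avg_sq, *, lr, beta1, beta2, eps,
                weight_decay, step):
    # Adam-style first and second moment estimates (m_t, v_t)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # Bias-corrected Adam update; weight decay is added to the update,
    # not to the gradient
    update = (exp_avg / bias_correction1) / (
        (exp_avg_sq / bias_correction2).sqrt() + eps)
    if weight_decay != 0:
        update = update + weight_decay * p
        # Layer-wise trust ratio ||w|| / ||update||, guarded against zero norms
        w_norm, u_norm = p.norm(2.0), update.norm(2.0)
        if w_norm > 0 and u_norm > 0:
            update = update * (w_norm / u_norm)
    p.add_(update, alpha=-lr)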


@@ -0,0 +1,308 @@
import functools
from collections import OrderedDict
from copy import deepcopy
from typing import Iterable
import torch
import torch.nn as nn
from mmcv.runner import build_optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
from mmcv.utils.registry import build_from_cfg
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer
import mmcls.core # noqa: F401
base_lr = 0.01
base_wd = 0.0001
def assert_equal(x, y):
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
torch.testing.assert_allclose(x, y.to(x.device))
elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
for x_value, y_value in zip(x.values(), y.values()):
assert_equal(x_value, y_value)
elif isinstance(x, dict) and isinstance(y, dict):
assert x.keys() == y.keys()
for key in x.keys():
assert_equal(x[key], y[key])
elif isinstance(x, str) and isinstance(y, str):
assert x == y
elif isinstance(x, Iterable) and isinstance(y, Iterable):
assert len(x) == len(y)
for x_item, y_item in zip(x, y):
assert_equal(x_item, y_item)
else:
assert x == y
class SubModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
self.gn = nn.GroupNorm(2, 2)
self.fc = nn.Linear(2, 2)
self.param1 = nn.Parameter(torch.ones(1))
def forward(self, x):
return x
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
self.fc = nn.Linear(2, 1)
def forward(self, x):
return x
def check_lamb_optimizer(optimizer,
model,
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1):
param_groups = optimizer.param_groups
assert isinstance(optimizer, Optimizer)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['weight_decay'] == base_wd
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
# param1
param1 = param_groups[0]
assert param1['lr'] == base_lr
assert param1['weight_decay'] == base_wd
# conv1.weight
conv1_weight = param_groups[1]
assert conv1_weight['lr'] == base_lr
assert conv1_weight['weight_decay'] == base_wd
# conv2.weight
conv2_weight = param_groups[2]
assert conv2_weight['lr'] == base_lr
assert conv2_weight['weight_decay'] == base_wd
# conv2.bias
conv2_bias = param_groups[3]
assert conv2_bias['lr'] == base_lr * bias_lr_mult
assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
# bn.weight
bn_weight = param_groups[4]
assert bn_weight['lr'] == base_lr
assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
# bn.bias
bn_bias = param_groups[5]
assert bn_bias['lr'] == base_lr
assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
# sub.param1
sub_param1 = param_groups[6]
assert sub_param1['lr'] == base_lr
assert sub_param1['weight_decay'] == base_wd
# sub.conv1.weight
sub_conv1_weight = param_groups[7]
assert sub_conv1_weight['lr'] == base_lr
assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
# sub.conv1.bias
sub_conv1_bias = param_groups[8]
assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
# sub.gn.weight
sub_gn_weight = param_groups[9]
assert sub_gn_weight['lr'] == base_lr
assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
# sub.gn.bias
sub_gn_bias = param_groups[10]
assert sub_gn_bias['lr'] == base_lr
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
# sub.fc.weight
sub_fc_weight = param_groups[11]
assert sub_fc_weight['lr'] == base_lr
assert sub_fc_weight['weight_decay'] == base_wd
# sub.fc.bias
sub_fc_bias = param_groups[12]
assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
# fc.weight
fc_weight = param_groups[13]
assert fc_weight['lr'] == base_lr
assert fc_weight['weight_decay'] == base_wd
# fc.bias
fc_bias = param_groups[14]
assert fc_bias['lr'] == base_lr * bias_lr_mult
assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
def _test_state_dict(weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
inputs = Variable(input)
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
i = input_cuda if weight.is_cuda else inputs
loss = (weight.mv(i) + bias).pow(2).sum()
loss.backward()
return loss
optimizer = constructor(weight, bias)
fn = functools.partial(fn_base, optimizer, weight, bias)
# Prime the optimizer
for _ in range(20):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = Variable(weight.data.clone(), requires_grad=True)
bias_c = Variable(bias.data.clone(), requires_grad=True)
optimizer_c = constructor(weight_c, bias_c)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for _ in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
assert_equal(weight, weight_c)
assert_equal(bias, bias_c)
# Make sure state dict wasn't modified
assert_equal(state_dict, state_dict_c)
# Make sure state dict is deterministic with equal
# (but not identical) parameters.
# NOTE: The state_dict of optimizers in PyTorch 1.5 have random keys, so we
# match states through the parameter ids stored in param_groups.
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer_c.state_dict())
keys = state_dict['param_groups'][-1]['params']
keys_c = state_dict_c['param_groups'][-1]['params']
for key, key_c in zip(keys, keys_c):
assert_equal(optimizer.state_dict()['state'][key],
optimizer_c.state_dict()['state'][key_c])
# Make sure repeated parameters have identical representation in state dict
optimizer_c.param_groups.extend(optimizer_c.param_groups)
assert_equal(optimizer_c.state_dict()['param_groups'][0],
optimizer_c.state_dict()['param_groups'][1])
# Check that state dict can be loaded even when we cast parameters
# to a different type and move to a different device.
if not torch.cuda.is_available():
return
input_cuda = Variable(inputs.data.float().cuda())
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
optimizer_cuda = constructor(weight_cuda, bias_cuda)
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
bias_cuda)
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_cuda.load_state_dict(state_dict_c)
# Make sure state dict wasn't modified
assert_equal(state_dict, state_dict_c)
for _ in range(20):
optimizer.step(fn)
optimizer_cuda.step(fn_cuda)
assert_equal(weight, weight_cuda)
assert_equal(bias, bias_cuda)
# validate deepcopy() copies all public attributes
def getPublicAttr(obj):
return set(k for k in obj.__dict__ if not k.startswith('_'))
assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
def _test_basic_cases_template(weight, bias, inputs, constructor,
scheduler_constructors):
"""Copied from PyTorch."""
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
inputs = Variable(inputs)
optimizer = constructor(weight, bias)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(inputs)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _ in range(200):
for scheduler in schedulers:
scheduler.step()
optimizer.step(fn)
assert fn().item() < initial_value
def _test_basic_cases(constructor,
scheduler_constructors=None,
ignore_multidevice=False):
"""Copied from PyTorch."""
if scheduler_constructors is None:
scheduler_constructors = []
_test_state_dict(
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
_test_basic_cases_template(
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
scheduler_constructors)
# non-contiguous parameters
_test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
scheduler_constructors)
# CUDA
if not torch.cuda.is_available():
return
_test_basic_cases_template(
torch.randn(10, 5).cuda(),
torch.randn(10).cuda(),
torch.randn(5).cuda(), constructor, scheduler_constructors)
# Multi-GPU
if not torch.cuda.device_count() > 1 or ignore_multidevice:
return
_test_basic_cases_template(
torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(1),
torch.randn(5).cuda(0), constructor, scheduler_constructors)
def test_lamb_optimizer():
model = ExampleModel()
optimizer_cfg = dict(
type='Lamb',
lr=base_lr,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=base_wd,
paramwise_cfg=dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1))
optimizer = build_optimizer(model, optimizer_cfg)
check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
_test_basic_cases(lambda weight, bias: build_from_cfg(
dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))
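
Finally, since Lamb subclasses torch.optim.Optimizer, it can also be used directly without the registry. A minimal sketch, assuming the package path implied by the new __init__.py files above (mmcls.core.optimizers) and using made-up toy data:

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcls.core.optimizers import Lamb  # path inferred from the __init__.py above

model = nn.Linear(16, 4)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

x, target = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()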