[Feature] Support LAMB optimizer. (#591)
* impl lamb * Add unit tests * Fix unit test * Fix unit tests * Use list instead of tuple in `__all__` according to PEP8 Co-authored-by: mzr1996 <mzr1996@163.com>pull/602/head
parent
188aa6ed5d
commit
851b438574
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .evaluation import * # noqa: F401, F403
|
||||
from .fp16 import * # noqa: F401, F403
|
||||
from .optimizers import * # noqa: F401, F403
|
||||
from .utils import * # noqa: F401, F403
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
from .lamb import Lamb
|
||||
|
||||
__all__ = [
|
||||
'Lamb',
|
||||
]
|
|
@ -0,0 +1,231 @@
|
|||
"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
|
||||
|
||||
This optimizer code was adapted from the following (starting with latest)
|
||||
* https://github.com/HabanaAI/Model-References/blob/
|
||||
2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
|
||||
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
|
||||
LanguageModeling/Transformer-XL/pytorch/lamb.py
|
||||
* https://github.com/cybertronai/pytorch-lamb
|
||||
|
||||
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb
|
||||
is to have a version that is
|
||||
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or
|
||||
cannot install/use APEX.
|
||||
|
||||
In addition to some cleanup, this Lamb impl has been modified to support
|
||||
PyTorch XLA and has been tested on TPU.
|
||||
|
||||
Original copyrights for above sources are below.
|
||||
|
||||
Modifications Copyright 2021 Ross Wightman
|
||||
"""
|
||||
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
|
||||
|
||||
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2019 cybertronai
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import math
|
||||
|
||||
import torch
|
||||
from mmcv.runner import OPTIMIZERS
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module()
|
||||
class Lamb(Optimizer):
|
||||
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer
|
||||
from apex.optimizers.FusedLAMB
|
||||
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/
|
||||
PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
||||
|
||||
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training
|
||||
BERT in 76 minutes`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups.
|
||||
lr (float, optional): learning rate. (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its norm. (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability. (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
|
||||
calculating running averages of gradient. (default: True)
|
||||
max_grad_norm (float, optional): value used to clip global grad norm
|
||||
(default: 1.0)
|
||||
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
|
||||
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
|
||||
weight decay parameter (default: False)
|
||||
|
||||
.. _Large Batch Optimization for Deep Learning - Training BERT in 76
|
||||
minutes:
|
||||
https://arxiv.org/abs/1904.00962
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
bias_correction=True,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-6,
|
||||
weight_decay=0.01,
|
||||
grad_averaging=True,
|
||||
max_grad_norm=1.0,
|
||||
trust_clip=False,
|
||||
always_adapt=False):
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
bias_correction=bias_correction,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
grad_averaging=grad_averaging,
|
||||
max_grad_norm=max_grad_norm,
|
||||
trust_clip=trust_clip,
|
||||
always_adapt=always_adapt)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
device = self.param_groups[0]['params'][0].device
|
||||
one_tensor = torch.tensor(
|
||||
1.0, device=device
|
||||
) # because torch.where doesn't handle scalars correctly
|
||||
global_grad_norm = torch.zeros(1, device=device)
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
'Lamb does not support sparse gradients, consider '
|
||||
'SparseAdam instead.')
|
||||
global_grad_norm.add_(grad.pow(2).sum())
|
||||
|
||||
global_grad_norm = torch.sqrt(global_grad_norm)
|
||||
# FIXME it'd be nice to remove explicit tensor conversion of scalars
|
||||
# when torch.where promotes
|
||||
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
|
||||
max_grad_norm = torch.tensor(
|
||||
self.defaults['max_grad_norm'], device=device)
|
||||
clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
|
||||
global_grad_norm / max_grad_norm,
|
||||
one_tensor)
|
||||
|
||||
for group in self.param_groups:
|
||||
bias_correction = 1 if group['bias_correction'] else 0
|
||||
beta1, beta2 = group['betas']
|
||||
grad_averaging = 1 if group['grad_averaging'] else 0
|
||||
beta3 = 1 - beta1 if grad_averaging else 1.0
|
||||
|
||||
# assume same step across group now to simplify things
|
||||
# per parameter step can be easily support by making it tensor, or
|
||||
# pass list into kernel
|
||||
if 'step' in group:
|
||||
group['step'] += 1
|
||||
else:
|
||||
group['step'] = 1
|
||||
|
||||
if bias_correction:
|
||||
bias_correction1 = 1 - beta1**group['step']
|
||||
bias_correction2 = 1 - beta2**group['step']
|
||||
else:
|
||||
bias_correction1, bias_correction2 = 1.0, 1.0
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.div_(clip_global_grad_norm)
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
# Exponential moving average of gradient valuesa
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
|
||||
exp_avg_sq.mul_(beta2).addcmul_(
|
||||
grad, grad, value=1 - beta2) # v_t
|
||||
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
|
||||
group['eps'])
|
||||
update = (exp_avg / bias_correction1).div_(denom)
|
||||
|
||||
weight_decay = group['weight_decay']
|
||||
if weight_decay != 0:
|
||||
update.add_(p, alpha=weight_decay)
|
||||
|
||||
if weight_decay != 0 or group['always_adapt']:
|
||||
# Layer-wise LR adaptation. By default, skip adaptation on
|
||||
# parameters that are
|
||||
# excluded from weight decay, unless always_adapt == True,
|
||||
# then always enabled.
|
||||
w_norm = p.norm(2.0)
|
||||
g_norm = update.norm(2.0)
|
||||
# FIXME nested where required since logical and/or not
|
||||
# working in PT XLA
|
||||
trust_ratio = torch.where(
|
||||
w_norm > 0,
|
||||
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
|
||||
one_tensor,
|
||||
)
|
||||
if group['trust_clip']:
|
||||
# LAMBC trust clipping, upper bound fixed at one
|
||||
trust_ratio = torch.minimum(trust_ratio, one_tensor)
|
||||
update.mul_(trust_ratio)
|
||||
|
||||
p.add_(update, alpha=-group['lr'])
|
||||
|
||||
return loss
|
|
@ -0,0 +1,308 @@
|
|||
import functools
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import build_optimizer
|
||||
from mmcv.runner.optimizer.builder import OPTIMIZERS
|
||||
from mmcv.utils.registry import build_from_cfg
|
||||
from torch.autograd import Variable
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
import mmcls.core # noqa: F401
|
||||
|
||||
base_lr = 0.01
|
||||
base_wd = 0.0001
|
||||
|
||||
|
||||
def assert_equal(x, y):
|
||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||
torch.testing.assert_allclose(x, y.to(x.device))
|
||||
elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
|
||||
for x_value, y_value in zip(x.values(), y.values()):
|
||||
assert_equal(x_value, y_value)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
assert x.keys() == y.keys()
|
||||
for key in x.keys():
|
||||
assert_equal(x[key], y[key])
|
||||
elif isinstance(x, str) and isinstance(y, str):
|
||||
assert x == y
|
||||
elif isinstance(x, Iterable) and isinstance(y, Iterable):
|
||||
assert len(x) == len(y)
|
||||
for x_item, y_item in zip(x, y):
|
||||
assert_equal(x_item, y_item)
|
||||
else:
|
||||
assert x == y
|
||||
|
||||
|
||||
class SubModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
|
||||
self.gn = nn.GroupNorm(2, 2)
|
||||
self.fc = nn.Linear(2, 2)
|
||||
self.param1 = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param1 = nn.Parameter(torch.ones(1))
|
||||
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
|
||||
self.bn = nn.BatchNorm2d(2)
|
||||
self.sub = SubModel()
|
||||
self.fc = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def check_lamb_optimizer(optimizer,
|
||||
model,
|
||||
bias_lr_mult=1,
|
||||
bias_decay_mult=1,
|
||||
norm_decay_mult=1,
|
||||
dwconv_decay_mult=1):
|
||||
param_groups = optimizer.param_groups
|
||||
assert isinstance(optimizer, Optimizer)
|
||||
assert optimizer.defaults['lr'] == base_lr
|
||||
assert optimizer.defaults['weight_decay'] == base_wd
|
||||
model_parameters = list(model.parameters())
|
||||
assert len(param_groups) == len(model_parameters)
|
||||
for i, param in enumerate(model_parameters):
|
||||
param_group = param_groups[i]
|
||||
assert torch.equal(param_group['params'][0], param)
|
||||
# param1
|
||||
param1 = param_groups[0]
|
||||
assert param1['lr'] == base_lr
|
||||
assert param1['weight_decay'] == base_wd
|
||||
# conv1.weight
|
||||
conv1_weight = param_groups[1]
|
||||
assert conv1_weight['lr'] == base_lr
|
||||
assert conv1_weight['weight_decay'] == base_wd
|
||||
# conv2.weight
|
||||
conv2_weight = param_groups[2]
|
||||
assert conv2_weight['lr'] == base_lr
|
||||
assert conv2_weight['weight_decay'] == base_wd
|
||||
# conv2.bias
|
||||
conv2_bias = param_groups[3]
|
||||
assert conv2_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
# bn.weight
|
||||
bn_weight = param_groups[4]
|
||||
assert bn_weight['lr'] == base_lr
|
||||
assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
|
||||
# bn.bias
|
||||
bn_bias = param_groups[5]
|
||||
assert bn_bias['lr'] == base_lr
|
||||
assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.param1
|
||||
sub_param1 = param_groups[6]
|
||||
assert sub_param1['lr'] == base_lr
|
||||
assert sub_param1['weight_decay'] == base_wd
|
||||
# sub.conv1.weight
|
||||
sub_conv1_weight = param_groups[7]
|
||||
assert sub_conv1_weight['lr'] == base_lr
|
||||
assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
|
||||
# sub.conv1.bias
|
||||
sub_conv1_bias = param_groups[8]
|
||||
assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
|
||||
# sub.gn.weight
|
||||
sub_gn_weight = param_groups[9]
|
||||
assert sub_gn_weight['lr'] == base_lr
|
||||
assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.gn.bias
|
||||
sub_gn_bias = param_groups[10]
|
||||
assert sub_gn_bias['lr'] == base_lr
|
||||
assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
|
||||
# sub.fc1.weight
|
||||
sub_fc_weight = param_groups[11]
|
||||
assert sub_fc_weight['lr'] == base_lr
|
||||
assert sub_fc_weight['weight_decay'] == base_wd
|
||||
# sub.fc1.bias
|
||||
sub_fc_bias = param_groups[12]
|
||||
assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
# fc1.weight
|
||||
fc_weight = param_groups[13]
|
||||
assert fc_weight['lr'] == base_lr
|
||||
assert fc_weight['weight_decay'] == base_wd
|
||||
# fc1.bias
|
||||
fc_bias = param_groups[14]
|
||||
assert fc_bias['lr'] == base_lr * bias_lr_mult
|
||||
assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
|
||||
|
||||
|
||||
def _test_state_dict(weight, bias, input, constructor):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
inputs = Variable(input)
|
||||
|
||||
def fn_base(optimizer, weight, bias):
|
||||
optimizer.zero_grad()
|
||||
i = input_cuda if weight.is_cuda else inputs
|
||||
loss = (weight.mv(i) + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
optimizer = constructor(weight, bias)
|
||||
fn = functools.partial(fn_base, optimizer, weight, bias)
|
||||
|
||||
# Prime the optimizer
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
# Clone the weights and construct new optimizer for them
|
||||
weight_c = Variable(weight.data.clone(), requires_grad=True)
|
||||
bias_c = Variable(bias.data.clone(), requires_grad=True)
|
||||
optimizer_c = constructor(weight_c, bias_c)
|
||||
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
||||
# Load state dict
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_c.load_state_dict(state_dict_c)
|
||||
# Run both optimizations in parallel
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_c.step(fn_c)
|
||||
assert_equal(weight, weight_c)
|
||||
assert_equal(bias, bias_c)
|
||||
# Make sure state dict wasn't modified
|
||||
assert_equal(state_dict, state_dict_c)
|
||||
# Make sure state dict is deterministic with equal
|
||||
# but not identical parameters
|
||||
# NOTE: The state_dict of optimizers in PyTorch 1.5 have random keys,
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer_c.state_dict())
|
||||
keys = state_dict['param_groups'][-1]['params']
|
||||
keys_c = state_dict_c['param_groups'][-1]['params']
|
||||
for key, key_c in zip(keys, keys_c):
|
||||
assert_equal(optimizer.state_dict()['state'][key],
|
||||
optimizer_c.state_dict()['state'][key_c])
|
||||
# Make sure repeated parameters have identical representation in state dict
|
||||
optimizer_c.param_groups.extend(optimizer_c.param_groups)
|
||||
assert_equal(optimizer_c.state_dict()['param_groups'][0],
|
||||
optimizer_c.state_dict()['param_groups'][1])
|
||||
|
||||
# Check that state dict can be loaded even when we cast parameters
|
||||
# to a different type and move to a different device.
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
input_cuda = Variable(inputs.data.float().cuda())
|
||||
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
|
||||
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
|
||||
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
||||
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
|
||||
bias_cuda)
|
||||
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_cuda.load_state_dict(state_dict_c)
|
||||
|
||||
# Make sure state dict wasn't modified
|
||||
assert_equal(state_dict, state_dict_c)
|
||||
|
||||
for _ in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_cuda.step(fn_cuda)
|
||||
assert_equal(weight, weight_cuda)
|
||||
assert_equal(bias, bias_cuda)
|
||||
|
||||
# validate deepcopy() copies all public attributes
|
||||
def getPublicAttr(obj):
|
||||
return set(k for k in obj.__dict__ if not k.startswith('_'))
|
||||
|
||||
assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
|
||||
|
||||
|
||||
def _test_basic_cases_template(weight, bias, inputs, constructor,
|
||||
scheduler_constructors):
|
||||
"""Copied from PyTorch."""
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
inputs = Variable(inputs)
|
||||
optimizer = constructor(weight, bias)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
# to check if the optimizer can be printed as a string
|
||||
optimizer.__repr__()
|
||||
|
||||
def fn():
|
||||
optimizer.zero_grad()
|
||||
y = weight.mv(inputs)
|
||||
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
|
||||
y = y.cuda(bias.get_device())
|
||||
loss = (y + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
initial_value = fn().item()
|
||||
for _ in range(200):
|
||||
for scheduler in schedulers:
|
||||
scheduler.step()
|
||||
optimizer.step(fn)
|
||||
|
||||
assert fn().item() < initial_value
|
||||
|
||||
|
||||
def _test_basic_cases(constructor,
|
||||
scheduler_constructors=None,
|
||||
ignore_multidevice=False):
|
||||
"""Copied from PyTorch."""
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
_test_state_dict(
|
||||
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
|
||||
scheduler_constructors)
|
||||
# non-contiguous parameters
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5, 2)[..., 0],
|
||||
torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
|
||||
scheduler_constructors)
|
||||
# CUDA
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(),
|
||||
torch.randn(10).cuda(),
|
||||
torch.randn(5).cuda(), constructor, scheduler_constructors)
|
||||
# Multi-GPU
|
||||
if not torch.cuda.device_count() > 1 or ignore_multidevice:
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(0),
|
||||
torch.randn(10).cuda(1),
|
||||
torch.randn(5).cuda(0), constructor, scheduler_constructors)
|
||||
|
||||
|
||||
def test_lamb_optimizer():
|
||||
model = ExampleModel()
|
||||
optimizer_cfg = dict(
|
||||
type='Lamb',
|
||||
lr=base_lr,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=base_wd,
|
||||
paramwise_cfg=dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
norm_decay_mult=0,
|
||||
dwconv_decay_mult=0.1))
|
||||
optimizer = build_optimizer(model, optimizer_cfg)
|
||||
check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])
|
||||
|
||||
_test_basic_cases(lambda weight, bias: build_from_cfg(
|
||||
dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))
|
Loading…
Reference in New Issue