# Copyright (c) OpenMMLab. All rights reserved.
import functools
from collections import OrderedDict
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn
from mmcv.runner import build_optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
from mmcv.utils.registry import build_from_cfg
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer

import mmcls.core  # noqa: F401

base_lr = 0.01
base_wd = 0.0001

def assert_equal(x, y):
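    """Recursively assert that two possibly nested values are equal."""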
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        torch.testing.assert_allclose(x, y.to(x.device))
    elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
        for x_value, y_value in zip(x.values(), y.values()):
            assert_equal(x_value, y_value)
    elif isinstance(x, dict) and isinstance(y, dict):
        assert x.keys() == y.keys()
        for key in x.keys():
            assert_equal(x[key], y[key])
    elif isinstance(x, str) and isinstance(y, str):
        assert x == y
    elif isinstance(x, Iterable) and isinstance(y, Iterable):
        assert len(x) == len(y)
        for x_item, y_item in zip(x, y):
            assert_equal(x_item, y_item)
    else:
        assert x == y


class SubModel(nn.Module):
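    """Nested submodule; ``conv1`` uses ``groups=2`` with 2 channels (a
    depthwise conv), so ``dwconv_decay_mult`` should apply to it."""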

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
        self.gn = nn.GroupNorm(2, 2)
        self.fc = nn.Linear(2, 2)
        self.param1 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x


class ExampleModel(nn.Module):
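    """Toy model whose parameter order matches the positional checks in
    ``check_lamb_optimizer``."""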

    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
        self.bn = nn.BatchNorm2d(2)
        self.sub = SubModel()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return x


def check_lamb_optimizer(optimizer,
                         model,
                         bias_lr_mult=1,
                         bias_decay_mult=1,
                         norm_decay_mult=1,
                         dwconv_decay_mult=1):
    """Check that the lr and weight-decay multipliers from ``paramwise_cfg``
    are applied to the expected parameter groups of ``ExampleModel``."""
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, Optimizer)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['weight_decay'] == base_wd
    model_parameters = list(model.parameters())
    assert len(param_groups) == len(model_parameters)
    for i, param in enumerate(model_parameters):
        param_group = param_groups[i]
        assert torch.equal(param_group['params'][0], param)
    # param1
    param1 = param_groups[0]
    assert param1['lr'] == base_lr
    assert param1['weight_decay'] == base_wd
    # conv1.weight
    conv1_weight = param_groups[1]
    assert conv1_weight['lr'] == base_lr
    assert conv1_weight['weight_decay'] == base_wd
    # conv2.weight
    conv2_weight = param_groups[2]
    assert conv2_weight['lr'] == base_lr
    assert conv2_weight['weight_decay'] == base_wd
    # conv2.bias
    conv2_bias = param_groups[3]
    assert conv2_bias['lr'] == base_lr * bias_lr_mult
    assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
    # bn.weight
    bn_weight = param_groups[4]
    assert bn_weight['lr'] == base_lr
    assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
    # bn.bias
    bn_bias = param_groups[5]
    assert bn_bias['lr'] == base_lr
    assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
    # sub.param1
    sub_param1 = param_groups[6]
    assert sub_param1['lr'] == base_lr
    assert sub_param1['weight_decay'] == base_wd
    # sub.conv1.weight
    sub_conv1_weight = param_groups[7]
    assert sub_conv1_weight['lr'] == base_lr
    assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
    # sub.conv1.bias
    sub_conv1_bias = param_groups[8]
    assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
    assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
    # sub.gn.weight
    sub_gn_weight = param_groups[9]
    assert sub_gn_weight['lr'] == base_lr
    assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
    # sub.gn.bias
    sub_gn_bias = param_groups[10]
    assert sub_gn_bias['lr'] == base_lr
    assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
    # sub.fc.weight
    sub_fc_weight = param_groups[11]
    assert sub_fc_weight['lr'] == base_lr
    assert sub_fc_weight['weight_decay'] == base_wd
    # sub.fc.bias
    sub_fc_bias = param_groups[12]
    assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
    assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
    # fc.weight
    fc_weight = param_groups[13]
    assert fc_weight['lr'] == base_lr
    assert fc_weight['weight_decay'] == base_wd
    # fc.bias
    fc_bias = param_groups[14]
    assert fc_bias['lr'] == base_lr * bias_lr_mult
    assert fc_bias['weight_decay'] == base_wd * bias_decay_mult


def _test_state_dict(weight, bias, input, constructor):
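    """Check ``state_dict``/``load_state_dict`` round-trips, deepcopy and,
    when CUDA is available, loading onto casted/moved parameters."""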
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(input)

    def fn_base(optimizer, weight, bias):
        optimizer.zero_grad()
        # ``input_cuda`` is defined later in this function and is only
        # reached once the CUDA copies of the parameters exist.
        i = input_cuda if weight.is_cuda else inputs
        loss = (weight.mv(i) + bias).pow(2).sum()
        loss.backward()
        return loss

    optimizer = constructor(weight, bias)
    fn = functools.partial(fn_base, optimizer, weight, bias)

    # Prime the optimizer
    for _ in range(20):
        optimizer.step(fn)
    # Clone the weights and construct a new optimizer for them
    weight_c = Variable(weight.data.clone(), requires_grad=True)
    bias_c = Variable(bias.data.clone(), requires_grad=True)
    optimizer_c = constructor(weight_c, bias_c)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict_c)
    # Run both optimizations in parallel
    for _ in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        assert_equal(weight, weight_c)
        assert_equal(bias, bias_c)
    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)
    # Make sure state dict is deterministic with equal
    # but not identical parameters
    # NOTE: The state_dict of optimizers in PyTorch 1.5 have random keys,
    # so compare the per-parameter states entry by entry instead of the
    # whole ``state`` dict.
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer_c.state_dict())
    keys = state_dict['param_groups'][-1]['params']
    keys_c = state_dict_c['param_groups'][-1]['params']
    for key, key_c in zip(keys, keys_c):
        assert_equal(optimizer.state_dict()['state'][key],
                     optimizer_c.state_dict()['state'][key_c])
    # Make sure repeated parameters have identical representation
    # in the state dict
    optimizer_c.param_groups.extend(optimizer_c.param_groups)
    assert_equal(optimizer_c.state_dict()['param_groups'][0],
                 optimizer_c.state_dict()['param_groups'][1])

    # Check that state dict can be loaded even when we cast parameters
    # to a different type and move to a different device.
    if not torch.cuda.is_available():
        return

    input_cuda = Variable(inputs.data.float().cuda())
    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
    optimizer_cuda = constructor(weight_cuda, bias_cuda)
    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
                                bias_cuda)

    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_cuda.load_state_dict(state_dict_c)

    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)

    for _ in range(20):
        optimizer.step(fn)
        optimizer_cuda.step(fn_cuda)
        assert_equal(weight, weight_cuda)
        assert_equal(bias, bias_cuda)

    # validate deepcopy() copies all public attributes
    def getPublicAttr(obj):
        return set(k for k in obj.__dict__ if not k.startswith('_'))

    assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))


def _test_basic_cases_template(weight, bias, inputs, constructor,
                               scheduler_constructors):
    """Copied from PyTorch."""
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(inputs)
    optimizer = constructor(weight, bias)
    schedulers = []
    for scheduler_constructor in scheduler_constructors:
        schedulers.append(scheduler_constructor(optimizer))

    # to check if the optimizer can be printed as a string
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(inputs)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _ in range(200):
        for scheduler in schedulers:
            scheduler.step()
        optimizer.step(fn)

    assert fn().item() < initial_value


def _test_basic_cases(constructor,
                      scheduler_constructors=None,
                      ignore_multidevice=False):
    """Copied from PyTorch."""
    if scheduler_constructors is None:
        scheduler_constructors = []
    _test_state_dict(
        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
    _test_basic_cases_template(
        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
        scheduler_constructors)
    # non-contiguous parameters
    _test_basic_cases_template(
        torch.randn(10, 5, 2)[..., 0],
        torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
        scheduler_constructors)
    # CUDA
    if not torch.cuda.is_available():
        return
    _test_basic_cases_template(
        torch.randn(10, 5).cuda(),
        torch.randn(10).cuda(),
        torch.randn(5).cuda(), constructor, scheduler_constructors)
    # Multi-GPU
    if not torch.cuda.device_count() > 1 or ignore_multidevice:
        return
    _test_basic_cases_template(
        torch.randn(10, 5).cuda(0),
        torch.randn(10).cuda(1),
        torch.randn(5).cuda(0), constructor, scheduler_constructors)


def test_lamb_optimizer():
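    """Build a Lamb optimizer from a config with ``paramwise_cfg`` and run
    the generic optimizer sanity checks on it."""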
    model = ExampleModel()
    optimizer_cfg = dict(
        type='Lamb',
        lr=base_lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=base_wd,
        paramwise_cfg=dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1))
    optimizer = build_optimizer(model, optimizer_cfg)
    check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])

    _test_basic_cases(lambda weight, bias: build_from_cfg(
        dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))