# mmpretrain/tests/test_runtime/test_optimizer.py

import functools
from collections import OrderedDict
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn
from mmcv.runner import build_optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
from mmcv.utils.registry import build_from_cfg
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer

import mmcls.core  # noqa: F401

base_lr = 0.01
base_wd = 0.0001


def assert_equal(x, y):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        torch.testing.assert_allclose(x, y.to(x.device))
    elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
        for x_value, y_value in zip(x.values(), y.values()):
            assert_equal(x_value, y_value)
    elif isinstance(x, dict) and isinstance(y, dict):
        assert x.keys() == y.keys()
        for key in x.keys():
            assert_equal(x[key], y[key])
    elif isinstance(x, str) and isinstance(y, str):
        assert x == y
    elif isinstance(x, Iterable) and isinstance(y, Iterable):
        assert len(x) == len(y)
        for x_item, y_item in zip(x, y):
            assert_equal(x_item, y_item)
    else:
        assert x == y
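

# assert_equal recursively compares nested containers (dicts, OrderedDicts,
# sequences) and compares tensors by value on a common device, which is what
# the state-dict checks below rely on. A tiny illustration (the values here
# are only for demonstration):
#     assert_equal({'step': torch.tensor(1.)}, {'step': torch.tensor(1.)})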


class SubModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
        self.gn = nn.GroupNorm(2, 2)
        self.fc = nn.Linear(2, 2)
        self.param1 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x


class ExampleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
        self.bn = nn.BatchNorm2d(2)
        self.sub = SubModel()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return x
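

# The index-based checks in check_lamb_optimizer below rely on the traversal
# order of ExampleModel.parameters() (a module's own parameters first, then
# its children in registration order):
#     param1, conv1.weight, conv2.weight, conv2.bias, bn.weight, bn.bias,
#     sub.param1, sub.conv1.weight, sub.conv1.bias, sub.gn.weight,
#     sub.gn.bias, sub.fc.weight, sub.fc.bias, fc.weight, fc.bias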


def check_lamb_optimizer(optimizer,
                         model,
                         bias_lr_mult=1,
                         bias_decay_mult=1,
                         norm_decay_mult=1,
                         dwconv_decay_mult=1):
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, Optimizer)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['weight_decay'] == base_wd
    model_parameters = list(model.parameters())
    assert len(param_groups) == len(model_parameters)
    for i, param in enumerate(model_parameters):
        param_group = param_groups[i]
        assert torch.equal(param_group['params'][0], param)
    # param1
    param1 = param_groups[0]
    assert param1['lr'] == base_lr
    assert param1['weight_decay'] == base_wd
    # conv1.weight
    conv1_weight = param_groups[1]
    assert conv1_weight['lr'] == base_lr
    assert conv1_weight['weight_decay'] == base_wd
    # conv2.weight
    conv2_weight = param_groups[2]
    assert conv2_weight['lr'] == base_lr
    assert conv2_weight['weight_decay'] == base_wd
    # conv2.bias
    conv2_bias = param_groups[3]
    assert conv2_bias['lr'] == base_lr * bias_lr_mult
    assert conv2_bias['weight_decay'] == base_wd * bias_decay_mult
    # bn.weight
    bn_weight = param_groups[4]
    assert bn_weight['lr'] == base_lr
    assert bn_weight['weight_decay'] == base_wd * norm_decay_mult
    # bn.bias
    bn_bias = param_groups[5]
    assert bn_bias['lr'] == base_lr
    assert bn_bias['weight_decay'] == base_wd * norm_decay_mult
    # sub.param1
    sub_param1 = param_groups[6]
    assert sub_param1['lr'] == base_lr
    assert sub_param1['weight_decay'] == base_wd
    # sub.conv1.weight (depthwise conv: groups == in_channels)
    sub_conv1_weight = param_groups[7]
    assert sub_conv1_weight['lr'] == base_lr
    assert sub_conv1_weight['weight_decay'] == base_wd * dwconv_decay_mult
    # sub.conv1.bias
    sub_conv1_bias = param_groups[8]
    assert sub_conv1_bias['lr'] == base_lr * bias_lr_mult
    assert sub_conv1_bias['weight_decay'] == base_wd * dwconv_decay_mult
    # sub.gn.weight
    sub_gn_weight = param_groups[9]
    assert sub_gn_weight['lr'] == base_lr
    assert sub_gn_weight['weight_decay'] == base_wd * norm_decay_mult
    # sub.gn.bias
    sub_gn_bias = param_groups[10]
    assert sub_gn_bias['lr'] == base_lr
    assert sub_gn_bias['weight_decay'] == base_wd * norm_decay_mult
    # sub.fc.weight
    sub_fc_weight = param_groups[11]
    assert sub_fc_weight['lr'] == base_lr
    assert sub_fc_weight['weight_decay'] == base_wd
    # sub.fc.bias
    sub_fc_bias = param_groups[12]
    assert sub_fc_bias['lr'] == base_lr * bias_lr_mult
    assert sub_fc_bias['weight_decay'] == base_wd * bias_decay_mult
    # fc.weight
    fc_weight = param_groups[13]
    assert fc_weight['lr'] == base_lr
    assert fc_weight['weight_decay'] == base_wd
    # fc.bias
    fc_bias = param_groups[14]
    assert fc_bias['lr'] == base_lr * bias_lr_mult
    assert fc_bias['weight_decay'] == base_wd * bias_decay_mult
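

# For intuition, with the values used in test_lamb_optimizer below
# (base_lr=0.01, base_wd=0.0001, bias_lr_mult=2, bias_decay_mult=0.5,
# norm_decay_mult=0, dwconv_decay_mult=0.1), a few representative groups
# would be expected to get:
#     conv2.bias       -> lr = 0.01 * 2 = 0.02, weight_decay = 0.0001 * 0.5 = 5e-05
#     bn.weight        -> lr = 0.01,            weight_decay = 0.0001 * 0 = 0.0
#     sub.conv1.weight -> lr = 0.01,            weight_decay = 0.0001 * 0.1 = 1e-05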


def _test_state_dict(weight, bias, input, constructor):
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(input)

    def fn_base(optimizer, weight, bias):
        optimizer.zero_grad()
        i = input_cuda if weight.is_cuda else inputs
        loss = (weight.mv(i) + bias).pow(2).sum()
        loss.backward()
        return loss

    optimizer = constructor(weight, bias)
    fn = functools.partial(fn_base, optimizer, weight, bias)

    # Prime the optimizer
    for _ in range(20):
        optimizer.step(fn)
    # Clone the weights and construct a new optimizer for them
    weight_c = Variable(weight.data.clone(), requires_grad=True)
    bias_c = Variable(bias.data.clone(), requires_grad=True)
    optimizer_c = constructor(weight_c, bias_c)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict_c)
    # Run both optimizations in parallel
    for _ in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        assert_equal(weight, weight_c)
        assert_equal(bias, bias_c)
    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)
    # Make sure the state dict is deterministic with equal
    # but not identical parameters.
    # NOTE: In PyTorch 1.5 the state_dict of optimizers keys the per-parameter
    # state by arbitrary parameter ids, so look the states up through each
    # param group's 'params' list instead of comparing the whole 'state'
    # dicts directly.
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer_c.state_dict())
    keys = state_dict['param_groups'][-1]['params']
    keys_c = state_dict_c['param_groups'][-1]['params']
    for key, key_c in zip(keys, keys_c):
        assert_equal(optimizer.state_dict()['state'][key],
                     optimizer_c.state_dict()['state'][key_c])
    # Make sure repeated parameters have identical representation in state dict
    optimizer_c.param_groups.extend(optimizer_c.param_groups)
    assert_equal(optimizer_c.state_dict()['param_groups'][0],
                 optimizer_c.state_dict()['param_groups'][1])

    # Check that state dict can be loaded even when we cast parameters
    # to a different type and move to a different device.
    if not torch.cuda.is_available():
        return

    input_cuda = Variable(inputs.data.float().cuda())
    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
    optimizer_cuda = constructor(weight_cuda, bias_cuda)
    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
                                bias_cuda)

    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_cuda.load_state_dict(state_dict_c)

    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)

    for _ in range(20):
        optimizer.step(fn)
        optimizer_cuda.step(fn_cuda)
        assert_equal(weight, weight_cuda)
        assert_equal(bias, bias_cuda)

    # validate deepcopy() copies all public attributes
    def getPublicAttr(obj):
        return set(k for k in obj.__dict__ if not k.startswith('_'))

    assert_equal(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
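

# The two helpers below are adapted from PyTorch's own optimizer tests; they
# build a small quadratic problem and check that repeated optimizer steps
# reduce the loss on CPU, on CUDA and (when available) across two GPUs.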


def _test_basic_cases_template(weight, bias, inputs, constructor,
                               scheduler_constructors):
    """Copied from PyTorch."""
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(inputs)
    optimizer = constructor(weight, bias)
    schedulers = []
    for scheduler_constructor in scheduler_constructors:
        schedulers.append(scheduler_constructor(optimizer))

    # to check if the optimizer can be printed as a string
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(inputs)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _ in range(200):
        for scheduler in schedulers:
            scheduler.step()
        optimizer.step(fn)
    assert fn().item() < initial_value
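

# Note that the assertion above is a smoke test of descent (the loss after
# 200 steps must be lower than the initial loss), not a check of convergence
# to the optimum.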


def _test_basic_cases(constructor,
                      scheduler_constructors=None,
                      ignore_multidevice=False):
    """Copied from PyTorch."""
    if scheduler_constructors is None:
        scheduler_constructors = []
    _test_state_dict(
        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
    _test_basic_cases_template(
        torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor,
        scheduler_constructors)
    # non-contiguous parameters
    _test_basic_cases_template(
        torch.randn(10, 5, 2)[..., 0],
        torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
        scheduler_constructors)
    # CUDA
    if not torch.cuda.is_available():
        return
    _test_basic_cases_template(
        torch.randn(10, 5).cuda(),
        torch.randn(10).cuda(),
        torch.randn(5).cuda(), constructor, scheduler_constructors)
    # Multi-GPU
    if not torch.cuda.device_count() > 1 or ignore_multidevice:
        return
    _test_basic_cases_template(
        torch.randn(10, 5).cuda(0),
        torch.randn(10).cuda(1),
        torch.randn(5).cuda(0), constructor, scheduler_constructors)


def test_lamb_optimizer():
    model = ExampleModel()
    optimizer_cfg = dict(
        type='Lamb',
        lr=base_lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=base_wd,
        paramwise_cfg=dict(
            bias_lr_mult=2,
            bias_decay_mult=0.5,
            norm_decay_mult=0,
            dwconv_decay_mult=0.1))
    optimizer = build_optimizer(model, optimizer_cfg)
    check_lamb_optimizer(optimizer, model, **optimizer_cfg['paramwise_cfg'])

    _test_basic_cases(lambda weight, bias: build_from_cfg(
        dict(type='Lamb', params=[weight, bias], lr=base_lr), OPTIMIZERS))
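

# A minimal sketch for running this test directly rather than through pytest;
# it assumes mmcls (which registers the Lamb optimizer) and mmcv are
# importable in the current environment.
if __name__ == '__main__':
    test_lamb_optimizer()
    print('test_lamb_optimizer passed')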