# Copyright (c) Alibaba, Inc. and its affiliates.
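"""Unit tests for the LARS and Ranger optimizers in ``easycv.core.optimizer``.

These checks follow the structure of torch's generic optimizer tests: each
optimizer is run on a small least-squares problem, its ``state_dict``
round-trip is verified, and invalid hyper-parameters are expected to raise
``ValueError``.
"""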
import functools
import unittest
from copy import deepcopy
from distutils.version import LooseVersion

import torch
from torch.autograd import Variable
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, StepLR
from torch.testing._internal.common_utils import TestCase

from easycv.framework.errors import ValueError


@unittest.skipIf(
    LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
    'skip some test bugs below torch 1.6.0')
class TestOptim(TestCase):
    exact_dtype = True

    def _test_basic_cases_template(self, weight, bias, input, constructor,
                                   scheduler_constructors):
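        """Build an optimizer via ``constructor`` (plus any LR schedulers), run
        200 optimization steps on a small least-squares problem and assert that
        the final loss is lower than the initial loss."""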
        weight = Variable(weight, requires_grad=True)
        bias = Variable(bias, requires_grad=True)
        input = Variable(input)
        optimizer = constructor(weight, bias)
        schedulers = []
        for scheduler_constructor in scheduler_constructors:
            schedulers.append(scheduler_constructor(optimizer))

        # to check if the optimizer can be printed as a string
        optimizer.__repr__()

        def fn():
            optimizer.zero_grad()
            y = weight.mv(input)
            if y.is_cuda and bias.is_cuda and y.get_device(
            ) != bias.get_device():
                y = y.cuda(bias.get_device())
            loss = (y + bias).pow(2).sum()
            loss.backward()
            return loss

        initial_value = fn().item()
        for _i in range(200):
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    val_loss = fn()
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
            optimizer.step(fn)
        self.assertLess(fn().item(), initial_value)

    def _test_state_dict(self, weight, bias, input, constructor):
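        """Check ``state_dict()`` / ``load_state_dict()`` round-trips: an
        optimizer reconstructed from the state dict must track the original
        step for step, including after casting parameters to float and moving
        them to CUDA (when CUDA is available)."""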
        weight = Variable(weight, requires_grad=True)
        bias = Variable(bias, requires_grad=True)
        input = Variable(input)

        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            i = input_cuda if weight.is_cuda else input
            loss = (weight.mv(i) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for _i in range(20):
            optimizer.step(fn)
        # Clone the weights and construct new optimizer for them
        weight_c = Variable(weight.data.clone(), requires_grad=True)
        bias_c = Variable(bias.data.clone(), requires_grad=True)
        optimizer_c = constructor(weight_c, bias_c)
        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
        # Load state dict
        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_c.load_state_dict(state_dict_c)
        # Run both optimizations in parallel
        for _i in range(20):
            optimizer.step(fn)
            optimizer_c.step(fn_c)
            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)
        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)
        # Make sure state dict is deterministic with equal but not identical parameters
        self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
        # Make sure repeated parameters have identical representation in state dict
        optimizer_c.param_groups.extend(optimizer_c.param_groups)
        self.assertEqual(optimizer.state_dict()['param_groups'][-1],
                         optimizer_c.state_dict()['param_groups'][-1])

        # Check that state dict can be loaded even when we cast parameters
        # to a different type and move to a different device.
        if not torch.cuda.is_available():
            return

        input_cuda = Variable(input.data.float().cuda())
        weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
        bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
        optimizer_cuda = constructor(weight_cuda, bias_cuda)
        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
                                    bias_cuda)

        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_cuda.load_state_dict(state_dict_c)

        # Make sure state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)

        for _i in range(20):
            optimizer.step(fn)
            optimizer_cuda.step(fn_cuda)
            self.assertEqual(weight, weight_cuda)
            self.assertEqual(bias, bias_cuda)

        # validate deepcopy() copies all public attributes
        def getPublicAttr(obj):
            return set(k for k in obj.__dict__ if not k.startswith('_'))

        self.assertEqual(
            getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))

    def _test_basic_cases(self,
                          constructor,
                          scheduler_constructors=None,
                          ignore_multidevice=False):
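        """Run the state_dict round-trip test, then the optimization template
        on non-contiguous parameters and, when available, on CUDA and
        multi-GPU parameters."""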
        if scheduler_constructors is None:
            scheduler_constructors = []
        self._test_state_dict(
            torch.randn(10, 5), torch.randn(10), torch.randn(5), constructor)
        # self._test_basic_cases_template(
        #     torch.randn(10, 5),
        #     torch.randn(10),
        #     torch.randn(5),
        #     constructor,
        #     scheduler_constructors
        # )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0], torch.randn(5), constructor,
            scheduler_constructors)
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(), constructor, scheduler_constructors)
        # Multi-GPU
        if not torch.cuda.device_count() > 1 or ignore_multidevice:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0), constructor, scheduler_constructors)

    def _build_params_dict(self, weight, bias, **kwargs):
        return [{'params': [weight]}, dict(params=[bias], **kwargs)]

    def _build_params_dict_single(self, weight, bias, **kwargs):
        return [dict(params=bias, **kwargs)]

    def test_lars(self):
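        """LARS should optimize plain parameter lists and param-group dicts,
        work with StepLR/ExponentialLR/ReduceLROnPlateau schedulers and the
        momentum / weight_decay / nesterov options, and reject a negative
        momentum value."""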
        from easycv.core.optimizer import LARS
        optimizer = LARS

        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict_single(weight, bias, lr=1e-2), lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict_single(weight, bias, lr=1e-2)))
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)])
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3), [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt)
            ])
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3), [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt)
            ])
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias], lr=1e-3, momentum=1))
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias], lr=1e-3, momentum=1, weight_decay=1))
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias], nesterov=True, lr=1e-3, momentum=1, weight_decay=1)
        )
        with self.assertRaisesRegex(ValueError,
                                    'Invalid momentum value: -0.5'):
            optimizer(None, lr=1e-2, momentum=-0.5)

    def test_ranger(self):
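        """Ranger should pass the same basic cases as LARS, support the
        ``alpha`` (slow update rate) option, and reject a negative ``alpha``."""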
        from easycv.core.optimizer import Ranger
        optimizer = Ranger

        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict_single(weight, bias, lr=1e-2), lr=1e-3))
        self._test_basic_cases(lambda weight, bias: optimizer(
            self._build_params_dict_single(weight, bias, lr=1e-2)))
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3),
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)])
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3), [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt)
            ])
        self._test_basic_cases(
            lambda weight, bias: optimizer([weight, bias], lr=1e-3), [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt)
            ])
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias],
            lr=1e-3,
            alpha=0.5,
        ))
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias], lr=1e-3, alpha=0.5, weight_decay=1))
        self._test_basic_cases(lambda weight, bias: optimizer(
            [weight, bias], lr=1e-3, alpha=0.5, weight_decay=1))
        with self.assertRaisesRegex(ValueError,
                                    'Invalid slow update rate: -0.5'):
            optimizer(None, lr=1e-2, alpha=-0.5)


if __name__ == '__main__':
    unittest.main()