commit 4d284017b8
|
@@ -17,8 +17,8 @@ jobs:
|
|||
matrix:
|
||||
os: [ubuntu-latest, macOS-latest]
|
||||
python: ['3.8']
|
||||
torch: ['1.8.1']
|
||||
torchvision: ['0.9.1']
|
||||
torch: ['1.9.0']
|
||||
torchvision: ['0.10.0']
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
|
@@ -43,7 +43,7 @@ jobs:
|
|||
- name: Install requirements
|
||||
run: |
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.0.12
|
||||
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.1.0
|
||||
- name: Run tests
|
||||
env:
|
||||
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
|
||||
|
|
|
@@ -255,8 +255,8 @@ class TrainBenchmarkRunner(BenchmarkRunner):
|
|||
|
||||
self.optimizer = create_optimizer_v2(
|
||||
self.model,
|
||||
optimizer_name=kwargs.pop('opt', 'sgd'),
|
||||
learning_rate=kwargs.pop('lr', 1e-4))
|
||||
opt=kwargs.pop('opt', 'sgd'),
|
||||
lr=kwargs.pop('lr', 1e-4))
|
||||
|
||||
def _gen_target(self, batch_size):
|
||||
return torch.empty(
|
||||
|
|
|
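The benchmark.py hunk above tracks the factory's keyword rename from optimizer_name/learning_rate to opt/lr. A minimal sketch of the updated call, assuming only a stand-in nn.Linear model (the factory import matches the one used in the new test file):

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(10, 2)  # stand-in model for illustration
    # keyword names used by TrainBenchmarkRunner after this change
    optimizer = create_optimizer_v2(model, opt='sgd', lr=1e-4)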
@@ -0,0 +1,704 @@
|
|||
""" Optimzier Tests
|
||||
|
||||
These tests were adapted from PyTorch's optimizer tests.
|
||||
|
||||
"""
|
||||
import math
|
||||
import pytest
|
||||
import functools
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.autograd import Variable
|
||||
from timm.scheduler import PlateauLRScheduler
|
||||
|
||||
from timm.optim import create_optimizer_v2
|
||||
|
||||
|
||||
# HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
|
||||
torch_tc = TestCase()
|
||||
|
||||
|
||||
def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
optimizer = constructor(weight, bias)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
# to check if the optimizer can be printed as a string
|
||||
optimizer.__repr__()
|
||||
|
||||
def fn():
|
||||
optimizer.zero_grad()
|
||||
y = weight.mv(input)
|
||||
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
|
||||
y = y.cuda(bias.get_device())
|
||||
loss = (y + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
initial_value = fn().item()
|
||||
for _i in range(200):
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, PlateauLRScheduler):
|
||||
val_loss = fn()
|
||||
scheduler.step(val_loss)
|
||||
else:
|
||||
scheduler.step()
|
||||
optimizer.step(fn)
|
||||
|
||||
assert fn().item() < initial_value
|
||||
|
||||
|
||||
def _test_state_dict(weight, bias, input, constructor):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
|
||||
def fn_base(optimizer, weight, bias):
|
||||
optimizer.zero_grad()
|
||||
i = input_cuda if weight.is_cuda else input
|
||||
loss = (weight.mv(i) + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
optimizer = constructor(weight, bias)
|
||||
fn = functools.partial(fn_base, optimizer, weight, bias)
|
||||
|
||||
# Prime the optimizer
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
# Clone the weights and construct new optimizer for them
|
||||
weight_c = Variable(weight.data.clone(), requires_grad=True)
|
||||
bias_c = Variable(bias.data.clone(), requires_grad=True)
|
||||
optimizer_c = constructor(weight_c, bias_c)
|
||||
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
||||
# Load state dict
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_c.load_state_dict(state_dict_c)
|
||||
|
||||
# Run both optimizations in parallel
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_c.step(fn_c)
|
||||
#assert torch.equal(weight, weight_c)
|
||||
#assert torch.equal(bias, bias_c)
|
||||
torch_tc.assertEqual(weight, weight_c)
|
||||
torch_tc.assertEqual(bias, bias_c)
|
||||
# Make sure state dict wasn't modified
|
||||
torch_tc.assertEqual(state_dict, state_dict_c)
|
||||
# Make sure state dict is deterministic with equal but not identical parameters
|
||||
torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
|
||||
# Make sure repeated parameters have identical representation in state dict
|
||||
optimizer_c.param_groups.extend(optimizer_c.param_groups)
|
||||
torch_tc.assertEqual(optimizer.state_dict()['param_groups'][-1], optimizer_c.state_dict()['param_groups'][-1])
|
||||
|
||||
# Check that state dict can be loaded even when we cast parameters
|
||||
# to a different type and move to a different device.
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
input_cuda = Variable(input.data.float().cuda())
|
||||
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
|
||||
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
|
||||
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
||||
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
|
||||
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_cuda.load_state_dict(state_dict_c)
|
||||
|
||||
# Make sure state dict wasn't modified
|
||||
torch_tc.assertEqual(state_dict, state_dict_c)
|
||||
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_cuda.step(fn_cuda)
|
||||
torch_tc.assertEqual(weight, weight_cuda)
|
||||
torch_tc.assertEqual(bias, bias_cuda)
|
||||
|
||||
# validate deepcopy() copies all public attributes
|
||||
def getPublicAttr(obj):
|
||||
return set(k for k in obj.__dict__ if not k.startswith('_'))
|
||||
|
||||
assert getPublicAttr(optimizer) == getPublicAttr(deepcopy(optimizer))
|
||||
|
||||
|
||||
def _test_basic_cases(constructor, scheduler_constructors=None):
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
_test_state_dict(
|
||||
torch.randn(10, 5),
|
||||
torch.randn(10),
|
||||
torch.randn(5),
|
||||
constructor
|
||||
)
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5),
|
||||
torch.randn(10),
|
||||
torch.randn(5),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
# non-contiguous parameters
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5, 2)[..., 0],
|
||||
torch.randn(10, 2)[..., 0],
|
||||
torch.randn(5),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
# CUDA
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(),
|
||||
torch.randn(10).cuda(),
|
||||
torch.randn(5).cuda(),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
|
||||
|
||||
def _test_model(optimizer, params, device=torch.device('cpu')):
|
||||
weight = torch.tensor(
|
||||
[[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]],
|
||||
device=device, requires_grad=True)
|
||||
bias = torch.tensor([-0.1085, -0.2979, 0.6892], device=device, requires_grad=True)
|
||||
weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], device=device, requires_grad=True)
|
||||
bias2 = torch.tensor([-0.0711], device=device, requires_grad=True)
|
||||
input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2)
|
||||
|
||||
model = torch.nn.Sequential(torch.nn.Linear(2, 3),
|
||||
torch.nn.Sigmoid(),
|
||||
torch.nn.Linear(3, 1),
|
||||
torch.nn.Sigmoid())
|
||||
model.to(device)
|
||||
|
||||
pretrained_dict = model.state_dict()
|
||||
pretrained_dict['0.weight'] = weight
|
||||
pretrained_dict['0.bias'] = bias
|
||||
pretrained_dict['2.weight'] = weight2
|
||||
pretrained_dict['2.bias'] = bias2
|
||||
model.load_state_dict(pretrained_dict)
|
||||
|
||||
optimizer = create_optimizer_v2(model, opt=optimizer, **params)
|
||||
|
||||
prev_loss = float('inf')
|
||||
for i in range(20):
|
||||
optimizer.zero_grad()
|
||||
output = model(input)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
assert loss < prev_loss
|
||||
prev_loss = loss
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def rosenbrock(tensor):
|
||||
x, y = tensor
|
||||
return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
|
||||
|
||||
|
||||
def drosenbrock(tensor):
|
||||
x, y = tensor
|
||||
return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
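# Added note (not part of the upstream tests): drosenbrock is the analytic gradient of
# rosenbrock, so it should agree with autograd. A minimal sanity check, using a
# hypothetical helper name for illustration only:
def _check_drosenbrock_gradient():
    p = torch.tensor([1.5, 1.5], requires_grad=True)
    rosenbrock(p).backward()
    assert torch.allclose(p.grad, drosenbrock(p.detach()))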
|
||||
|
||||
|
||||
def _test_rosenbrock(constructor, scheduler_constructors=None):
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
params = Variable(params_t, requires_grad=True)
|
||||
optimizer = constructor([params])
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
solution = torch.tensor([1, 1])
|
||||
initial_dist = params.data.dist(solution)
|
||||
|
||||
def eval(params, w):
|
||||
# Depending on w, provide only the x or y gradient
|
||||
optimizer.zero_grad()
|
||||
loss = rosenbrock(params)
|
||||
loss.backward()
|
||||
grad = drosenbrock(params.data)
|
||||
# NB: We torture test the optimizer by returning an
|
||||
# uncoalesced sparse tensor
|
||||
if w:
|
||||
i = torch.LongTensor([[0, 0]])
|
||||
x = grad[0]
|
||||
v = torch.tensor([x / 4., x - x / 4.])
|
||||
else:
|
||||
i = torch.LongTensor([[1, 1]])
|
||||
y = grad[1]
|
||||
v = torch.tensor([y - y / 4., y / 4.])
|
||||
x = torch.sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype)
|
||||
with torch.no_grad():
|
||||
params.grad = x.to_dense()
|
||||
return loss
|
||||
|
||||
for i in range(2000):
|
||||
# Do cyclic coordinate descent
|
||||
w = i % 2
|
||||
optimizer.step(functools.partial(eval, params, w))
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, PlateauLRScheduler):
|
||||
scheduler.step(rosenbrock(params))
|
||||
else:
|
||||
scheduler.step()
|
||||
|
||||
torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
|
||||
|
||||
|
||||
def _build_params_dict(weight, bias, **kwargs):
|
||||
return [{'params': [weight]}, dict(params=[bias], **kwargs)]
|
||||
|
||||
|
||||
def _build_params_dict_single(weight, bias, **kwargs):
|
||||
return [dict(params=bias, **kwargs)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
|
||||
def test_sgd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
|
||||
)
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
|
||||
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
||||
# lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
||||
# lambda opt: ReduceLROnPlateau(opt)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
|
||||
# lambda opt: ExponentialLR(opt, gamma=0.99),
|
||||
# lambda opt: ReduceLROnPlateau(opt)]
|
||||
# )
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax'])
|
||||
def test_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adabelief'])
|
||||
def test_adabelief(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
|
||||
def test_rectified(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
|
||||
def test_adaother(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-1)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adafactor'])
|
||||
def test_adafactor(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lamb'])
|
||||
def test_lamb(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
|
||||
def test_madgrad(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['novograd'])
|
||||
def test_novograd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
|
||||
def test_rmsprop(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adamp'])
|
||||
def test_adamp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['sgdp'])
|
||||
def test_sgdp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
|
||||
def test_lookahead_sgd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
|
||||
def test_lookahead_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
|
||||
def test_lookahead_radam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
|
||||
)
|
||||
|
|
@@ -3,8 +3,8 @@ from .adamw import AdamW
|
|||
from .adafactor import Adafactor
|
||||
from .adahessian import Adahessian
|
||||
from .lookahead import Lookahead
|
||||
from .madgrad import MADGRAD
|
||||
from .nadam import Nadam
|
||||
from .novograd import NovoGrad
|
||||
from .nvnovograd import NvNovoGrad
|
||||
from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
|
|
|
@@ -18,7 +18,7 @@ class AdaBelief(Optimizer):
|
|||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
weight_decouple (boolean, optional): ( default: True) If set as True, then
|
||||
decoupled_decay (boolean, optional): (default: True) If set as True, then
|
||||
the optimizer uses decoupled weight decay as in AdamW
|
||||
fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
|
||||
is set as True.
|
||||
|
@@ -39,9 +39,9 @@ class AdaBelief(Optimizer):
|
|||
- link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
|
||||
weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
|
||||
degenerated_to_sgd=True):
|
||||
def __init__(
|
||||
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
|
||||
decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
|
@@ -52,21 +52,17 @@ class AdaBelief(Optimizer):
|
|||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
|
||||
self.degenerated_to_sgd = degenerated_to_sgd
|
||||
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
|
||||
for param in params:
|
||||
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
|
||||
param['buffer'] = [[None, None, None] for _ in range(10)]
|
||||
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)])
|
||||
defaults = dict(
|
||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
|
||||
degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
|
||||
fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
|
||||
super(AdaBelief, self).__init__(params, defaults)
|
||||
|
||||
self.degenerated_to_sgd = degenerated_to_sgd
|
||||
self.weight_decouple = weight_decouple
|
||||
self.rectify = rectify
|
||||
self.fixed_decay = fixed_decay
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(AdaBelief, self).__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
|
@@ -133,8 +129,8 @@ class AdaBelief(Optimizer):
|
|||
state['max_exp_avg_var'] = torch.zeros_like(p.data)
|
||||
|
||||
# perform weight decay, check if decoupled weight decay
|
||||
if self.weight_decouple:
|
||||
if not self.fixed_decay:
|
||||
if group['decoupled_decay']:
|
||||
if not group['fixed_decay']:
|
||||
p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
|
||||
else:
|
||||
p.data.mul_(1.0 - group['weight_decay'])
|
||||
|
@@ -152,7 +148,7 @@ class AdaBelief(Optimizer):
|
|||
# Update first and second moment running average
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
grad_residual = grad - exp_avg
|
||||
exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2)
|
||||
exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_var = state['max_exp_avg_var']
|
||||
|
@@ -165,38 +161,40 @@ class AdaBelief(Optimizer):
|
|||
denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||
|
||||
# update
|
||||
if not self.rectify:
|
||||
if not group['rectify']:
|
||||
# Default update
|
||||
step_size = group['lr'] / bias_correction1
|
||||
p.data.addcdiv_( exp_avg, denom, value=-step_size)
|
||||
|
||||
else: # Rectified update, forked from RAdam
|
||||
p.data.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
else:
|
||||
# Rectified update, forked from RAdam
|
||||
buffered = group['buffer'][int(state['step'] % 10)]
|
||||
if state['step'] == buffered[0]:
|
||||
N_sma, step_size = buffered[1], buffered[2]
|
||||
num_sma, step_size = buffered[1], buffered[2]
|
||||
else:
|
||||
buffered[0] = state['step']
|
||||
beta2_t = beta2 ** state['step']
|
||||
N_sma_max = 2 / (1 - beta2) - 1
|
||||
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
||||
buffered[1] = N_sma
|
||||
num_sma_max = 2 / (1 - beta2) - 1
|
||||
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
||||
buffered[1] = num_sma
|
||||
|
||||
# more conservative since it's an approximated value
|
||||
if N_sma >= 5:
|
||||
if num_sma >= 5:
|
||||
step_size = math.sqrt(
|
||||
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
|
||||
N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
||||
elif self.degenerated_to_sgd:
|
||||
(1 - beta2_t) *
|
||||
(num_sma - 4) / (num_sma_max - 4) *
|
||||
(num_sma - 2) / num_sma *
|
||||
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
|
||||
elif group['degenerated_to_sgd']:
|
||||
step_size = 1.0 / (1 - beta1 ** state['step'])
|
||||
else:
|
||||
step_size = -1
|
||||
buffered[2] = step_size
|
||||
|
||||
if N_sma >= 5:
|
||||
if num_sma >= 5:
|
||||
denom = exp_avg_var.sqrt().add_(group['eps'])
|
||||
p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
|
||||
elif step_size > 0:
|
||||
p.data.add_( exp_avg, alpha=-step_size * group['lr'])
|
||||
p.data.add_(exp_avg, alpha=-step_size * group['lr'])
|
||||
|
||||
if half_precision:
|
||||
p.data = p.data.half()
|
||||
|
|
|
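The AdaBelief hunks above fold the per-instance flags (weight_decouple, renamed to decoupled_decay, plus fixed_decay, rectify, degenerated_to_sgd) into the per-group defaults. A minimal sketch, assuming the factory forwards extra keyword arguments to the optimizer constructor the way the new tests rely on:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(4, 2)
    # extra kwargs land in the optimizer defaults, so every param group carries its own copy
    opt = create_optimizer_v2(model, 'adabelief', lr=1e-3, decoupled_decay=False)
    assert all(group['decoupled_decay'] is False for group in opt.param_groups)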
@@ -34,15 +34,13 @@ class Adafactor(torch.optim.Optimizer):
|
|||
beta1 (float): coefficient used for computing running averages of gradient (default: None)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
|
||||
relative_step (bool): if True, time-dependent learning rate is computed
|
||||
instead of external learning rate (default: True)
|
||||
warmup_init (bool): time-dependent learning rate computation depends on
|
||||
whether warm-up initialization is being used (default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
||||
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
||||
relative_step = lr is None
|
||||
relative_step = not lr
|
||||
if warmup_init and not relative_step:
|
||||
raise ValueError('warmup_init requires relative_step=True')
|
||||
|
||||
|
@@ -138,10 +136,8 @@ class Adafactor(torch.optim.Optimizer):
|
|||
exp_avg_sq_row = state['exp_avg_sq_row']
|
||||
exp_avg_sq_col = state['exp_avg_sq_col']
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
|
||||
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
|
||||
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
|
@@ -149,8 +145,7 @@ class Adafactor(torch.optim.Optimizer):
|
|||
else:
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
|
||||
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
|
||||
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
|
||||
|
@@ -158,17 +153,15 @@ class Adafactor(torch.optim.Optimizer):
|
|||
|
||||
if use_first_moment:
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
|
||||
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
|
||||
exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
|
||||
update = exp_avg
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
|
||||
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
|
||||
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
||||
p.data.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
||||
return loss
|
||||
|
|
|
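With the change above, Adafactor derives relative_step from the learning rate itself (not lr), so omitting lr selects the time-dependent step size and keeps warmup_init legal, while an explicit lr turns both off. A minimal sketch against the constructor signature shown in this hunk:

    import torch.nn as nn
    from timm.optim import Adafactor

    model = nn.Linear(4, 2)
    # lr=None (or 0) -> relative_step=True: the step size is computed per update
    opt_relative = Adafactor(model.parameters(), lr=None, warmup_init=True)
    # explicit lr -> relative_step=False: warmup_init=True here would raise ValueError
    opt_absolute = Adafactor(model.parameters(), lr=1e-3)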
@@ -9,49 +9,44 @@ MIT license
|
|||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim.optimizer import Optimizer, required
|
||||
import torch.nn.functional as F
|
||||
from torch.optim.optimizer import Optimizer
|
||||
import math
|
||||
|
||||
|
||||
def _channel_view(x) -> torch.Tensor:
|
||||
return x.reshape(x.size(0), -1)
|
||||
|
||||
|
||||
def _layer_view(x) -> torch.Tensor:
|
||||
return x.reshape(1, -1)
|
||||
|
||||
|
||||
def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
|
||||
wd = 1.
|
||||
expand_size = (-1,) + (1,) * (len(p.shape) - 1)
|
||||
for view_func in [_channel_view, _layer_view]:
|
||||
param_view = view_func(p.data)
|
||||
grad_view = view_func(grad)
|
||||
cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
|
||||
|
||||
if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
|
||||
p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
|
||||
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
|
||||
wd = wd_ratio
|
||||
return perturb, wd
|
||||
|
||||
return perturb, wd
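# Added note (not in the original source): projection() implements AdamP's handling of
# scale-invariant weights. For the channel-wise and then layer-wise views it measures the
# cosine similarity between gradient and parameter; when the maximum falls below
# delta / sqrt(dim), it removes the component of `perturb` parallel to the normalized
# parameter (the radial direction) and scales weight decay by wd_ratio.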
|
||||
|
||||
|
||||
class AdamP(Optimizer):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
|
||||
defaults = dict(
|
||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
|
||||
super(AdamP, self).__init__(params, defaults)
|
||||
|
||||
def _channel_view(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
def _layer_view(self, x):
|
||||
return x.view(1, -1)
|
||||
|
||||
def _cosine_similarity(self, x, y, eps, view_func):
|
||||
x = view_func(x)
|
||||
y = view_func(y)
|
||||
|
||||
x_norm = x.norm(dim=1).add_(eps)
|
||||
y_norm = y.norm(dim=1).add_(eps)
|
||||
dot = (x * y).sum(dim=1)
|
||||
|
||||
return dot.abs() / x_norm / y_norm
|
||||
|
||||
def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
|
||||
wd = 1
|
||||
expand_size = [-1] + [1] * (len(p.shape) - 1)
|
||||
for view_func in [self._channel_view, self._layer_view]:
|
||||
|
||||
cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
|
||||
|
||||
if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
|
||||
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
|
||||
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
|
||||
wd = wd_ratio
|
||||
|
||||
return perturb, wd
|
||||
|
||||
return perturb, wd
|
||||
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
|
@@ -81,8 +76,8 @@ class AdamP(Optimizer):
|
|||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||
step_size = group['lr'] / bias_correction1
|
||||
|
@@ -93,15 +88,15 @@ class AdamP(Optimizer):
|
|||
perturb = exp_avg / denom
|
||||
|
||||
# Projection
|
||||
wd_ratio = 1
|
||||
wd_ratio = 1.
|
||||
if len(p.shape) > 1:
|
||||
perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
|
||||
perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
|
||||
|
||||
# Weight decay
|
||||
if group['weight_decay'] > 0:
|
||||
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
|
||||
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
|
||||
|
||||
# Step
|
||||
p.data.add_(-step_size, perturb)
|
||||
p.data.add_(perturb, alpha=-step_size)
|
||||
|
||||
return loss
|
||||
|
|
|
@@ -1,5 +1,8 @@
|
|||
""" AdamW Optimizer
|
||||
Impl copied from PyTorch master
|
||||
|
||||
NOTE: The builtin optim.AdamW is used by the factory; this impl only serves as a Python-based
reference and will be removed someday
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
|
@@ -100,8 +103,8 @@ class AdamW(Optimizer):
|
|||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
||||
|
@@ -112,6 +115,6 @@ class AdamW(Optimizer):
|
|||
|
||||
step_size = group['lr'] / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
p.data.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
return loss
|
||||
|
|
|
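Many hunks in this commit, the AdamW one above included, replace the deprecated positional form of the in-place ops (add_(scalar, tensor), addcmul_(scalar, t1, t2), addcdiv_(scalar, t1, t2)) with the keyword form current PyTorch expects. A minimal sketch of the pattern with throwaway tensors:

    import torch

    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)
    grad = torch.ones(3)
    beta1, beta2 = 0.9, 0.999

    # old: exp_avg.mul_(beta1).add_(1 - beta1, grad)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # old: exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)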
@@ -47,12 +47,13 @@ Original copyrights for above sources are below.
|
|||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class NvLamb(Optimizer):
|
||||
class Lamb(Optimizer):
|
||||
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
|
||||
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
||||
|
||||
|
@@ -82,25 +83,13 @@ class NvLamb(Optimizer):
|
|||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, bias_correction=True,
|
||||
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
|
||||
grad_averaging=True, set_grad_none=True,
|
||||
max_grad_norm=1.0, use_nvlamb=False):
|
||||
defaults = dict(lr=lr, bias_correction=bias_correction,
|
||||
betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
grad_averaging=grad_averaging,
|
||||
max_grad_norm=max_grad_norm)
|
||||
def __init__(
|
||||
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
|
||||
weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
|
||||
defaults = dict(
|
||||
lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
|
||||
super().__init__(params, defaults)
|
||||
self.set_grad_none = set_grad_none
|
||||
self.use_nvlamb = use_nvlamb
|
||||
|
||||
def zero_grad(self):
|
||||
if self.set_grad_none:
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
p.grad = None
|
||||
else:
|
||||
super(NvLamb, self).zero_grad()
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
@@ -109,6 +98,7 @@ class NvLamb(Optimizer):
|
|||
and returns the loss.
|
||||
"""
|
||||
device = self.param_groups[0]["params"][0].device
|
||||
one_tensor = torch.tensor(1.0, device=device)
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
|
@@ -124,22 +114,18 @@ class NvLamb(Optimizer):
|
|||
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
|
||||
global_grad_norm.add_(grad.pow(2).sum())
|
||||
|
||||
global_grad_norm_ = torch.sqrt(global_grad_norm)
|
||||
global_grad_norm = torch.sqrt(global_grad_norm)
|
||||
max_grad_norm = self.defaults['max_grad_norm']
|
||||
|
||||
if global_grad_norm_ > max_grad_norm:
|
||||
clip_global_grad_norm = global_grad_norm_ / max_grad_norm
|
||||
else:
|
||||
clip_global_grad_norm = 1.0
|
||||
clip_global_grad_norm = torch.where(
|
||||
global_grad_norm > max_grad_norm,
|
||||
global_grad_norm / max_grad_norm,
|
||||
one_tensor)
|
||||
|
||||
for group in self.param_groups:
|
||||
bias_correction = 1 if group['bias_correction'] else 0
|
||||
beta1, beta2 = group['betas']
|
||||
grad_averaging = 1 if group['grad_averaging'] else 0
|
||||
if grad_averaging:
|
||||
beta3 = 1 - beta1
|
||||
else:
|
||||
beta3 = 1.0
|
||||
beta3 = 1 - beta1 if grad_averaging else 1.0
|
||||
|
||||
# assume same step across group now to simplify things
|
||||
# per parameter step can be easily support by making it tensor, or pass list into kernel
|
||||
|
@@ -148,8 +134,6 @@ class NvLamb(Optimizer):
|
|||
else:
|
||||
group['step'] = 1
|
||||
|
||||
step_size = group['lr']
|
||||
|
||||
if bias_correction:
|
||||
bias_correction1 = 1 - beta1 ** group['step']
|
||||
bias_correction2 = 1 - beta2 ** group['step']
|
||||
|
@@ -169,36 +153,31 @@ class NvLamb(Optimizer):
|
|||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg_, exp_avg_sq_ = state['exp_avg'], state['exp_avg_sq']
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# m_t
|
||||
exp_avg_.mul_(beta1).add_(grad, alpha=beta3)
|
||||
# v_t
|
||||
exp_avg_sq_.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
# create clones to avoid modifying runner stats
|
||||
exp_avg = exp_avg_.div(bias_correction1)
|
||||
exp_avg_sq = exp_avg_sq_.div(bias_correction2)
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
|
||||
|
||||
# || w_t ||
|
||||
weight_norm = p.data.norm(2.0)
|
||||
# u_t
|
||||
exp_avg_sq_sqrt = torch.sqrt(exp_avg_sq)
|
||||
adam_step = exp_avg.div_(exp_avg_sq_sqrt.add_(group['eps']))
|
||||
if group['weight_decay'] != 0:
|
||||
adam_step.add_(p.data, alpha=group['weight_decay'])
|
||||
# || u_t ||
|
||||
adam_norm = adam_step.norm(2.0)
|
||||
if (group['weight_decay'] != 0 or self.use_nvlamb) and adam_norm > 0 and weight_norm > 0:
|
||||
trust_ratio = weight_norm / adam_norm
|
||||
trust_ratio = trust_ratio.item()
|
||||
else:
|
||||
trust_ratio = 1
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||
update = (exp_avg / bias_correction1).div_(denom)
|
||||
|
||||
state['weight_norm'] = weight_norm
|
||||
state['adam_norm'] = adam_norm
|
||||
state['trust_ratio'] = trust_ratio
|
||||
weight_decay = group['weight_decay']
|
||||
if weight_decay != 0:
|
||||
update.add_(p.data, alpha=weight_decay)
|
||||
|
||||
p.data.add_(adam_step, alpha=-step_size * trust_ratio)
|
||||
trust_ratio = one_tensor
|
||||
if weight_decay != 0 or group['use_nvlamb']:
|
||||
# Layer adaptation. By default, skip layer adaptation on parameters that are
|
||||
# excluded from weight norm, unless use_nvlamb == True, then always enabled.
|
||||
w_norm = p.data.norm(2.0)
|
||||
g_norm = update.norm(2.0)
|
||||
trust_ratio = torch.where(
|
||||
w_norm > 0,
|
||||
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
|
||||
one_tensor,
|
||||
)
|
||||
update.mul_(trust_ratio)
|
||||
p.data.add_(update, alpha=-group['lr'])
|
||||
|
||||
return loss
|
||||
|
|
|
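After the rename from NvLamb to Lamb, the new tests exercise this optimizer through the factory under the name 'lamb', with use_nvlamb now stored in the param-group defaults. A minimal usage sketch on that basis:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    model = nn.Linear(8, 2)
    # use_nvlamb=True applies the trust ratio even to groups excluded from weight decay
    optimizer = create_optimizer_v2(model, 'lamb', lr=1e-3, weight_decay=0.01, use_nvlamb=True)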
@@ -11,82 +11,49 @@ from collections import defaultdict
|
|||
|
||||
class Lookahead(Optimizer):
|
||||
def __init__(self, base_optimizer, alpha=0.5, k=6):
|
||||
# NOTE super().__init__() not called on purpose
|
||||
if not 0.0 <= alpha <= 1.0:
|
||||
raise ValueError(f'Invalid slow update rate: {alpha}')
|
||||
if not 1 <= k:
|
||||
raise ValueError(f'Invalid lookahead steps: {k}')
|
||||
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
|
||||
self.base_optimizer = base_optimizer
|
||||
self.param_groups = self.base_optimizer.param_groups
|
||||
self._base_optimizer = base_optimizer
|
||||
self.param_groups = base_optimizer.param_groups
|
||||
self.defaults = base_optimizer.defaults
|
||||
self.defaults.update(defaults)
|
||||
self.state = defaultdict(dict)
|
||||
# manually add our defaults to the param groups
|
||||
for name, default in defaults.items():
|
||||
for group in self.param_groups:
|
||||
for group in self._base_optimizer.param_groups:
|
||||
group.setdefault(name, default)
|
||||
|
||||
def update_slow(self, group):
|
||||
for fast_p in group["params"]:
|
||||
if fast_p.grad is None:
|
||||
continue
|
||||
param_state = self.state[fast_p]
|
||||
if 'slow_buffer' not in param_state:
|
||||
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
|
||||
param_state['slow_buffer'].copy_(fast_p.data)
|
||||
slow = param_state['slow_buffer']
|
||||
slow.add_(group['lookahead_alpha'], fast_p.data - slow)
|
||||
param_state = self._base_optimizer.state[fast_p]
|
||||
if 'lookahead_slow_buff' not in param_state:
|
||||
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data)
|
||||
param_state['lookahead_slow_buff'].copy_(fast_p.data)
|
||||
slow = param_state['lookahead_slow_buff']
|
||||
slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
|
||||
fast_p.data.copy_(slow)
|
||||
|
||||
def sync_lookahead(self):
|
||||
for group in self.param_groups:
|
||||
for group in self._base_optimizer.param_groups:
|
||||
self.update_slow(group)
|
||||
|
||||
def step(self, closure=None):
|
||||
#assert id(self.param_groups) == id(self.base_optimizer.param_groups)
|
||||
loss = self.base_optimizer.step(closure)
|
||||
for group in self.param_groups:
|
||||
loss = self._base_optimizer.step(closure)
|
||||
for group in self._base_optimizer.param_groups:
|
||||
group['lookahead_step'] += 1
|
||||
if group['lookahead_step'] % group['lookahead_k'] == 0:
|
||||
self.update_slow(group)
|
||||
return loss
|
||||
|
||||
def state_dict(self):
|
||||
fast_state_dict = self.base_optimizer.state_dict()
|
||||
slow_state = {
|
||||
(id(k) if isinstance(k, torch.Tensor) else k): v
|
||||
for k, v in self.state.items()
|
||||
}
|
||||
fast_state = fast_state_dict['state']
|
||||
param_groups = fast_state_dict['param_groups']
|
||||
return {
|
||||
'state': fast_state,
|
||||
'slow_state': slow_state,
|
||||
'param_groups': param_groups,
|
||||
}
|
||||
return self._base_optimizer.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
fast_state_dict = {
|
||||
'state': state_dict['state'],
|
||||
'param_groups': state_dict['param_groups'],
|
||||
}
|
||||
self.base_optimizer.load_state_dict(fast_state_dict)
|
||||
|
||||
# We want to restore the slow state, but share param_groups reference
|
||||
# with base_optimizer. This is a bit redundant but least code
|
||||
slow_state_new = False
|
||||
if 'slow_state' not in state_dict:
|
||||
print('Loading state_dict from optimizer without Lookahead applied.')
|
||||
state_dict['slow_state'] = defaultdict(dict)
|
||||
slow_state_new = True
|
||||
slow_state_dict = {
|
||||
'state': state_dict['slow_state'],
|
||||
'param_groups': state_dict['param_groups'], # this is pointless but saves code
|
||||
}
|
||||
super(Lookahead, self).load_state_dict(slow_state_dict)
|
||||
self.param_groups = self.base_optimizer.param_groups # make both ref same container
|
||||
if slow_state_new:
|
||||
# reapply defaults to catch missing lookahead specific ones
|
||||
for name, default in self.defaults.items():
|
||||
for group in self.param_groups:
|
||||
group.setdefault(name, default)
|
||||
self._base_optimizer.load_state_dict(state_dict)
|
||||
self.param_groups = self._base_optimizer.param_groups
|
||||
|
|
|
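The Lookahead rewrite above keeps all slow-weight state on the wrapped optimizer (_base_optimizer), so checkpointing is simply the base optimizer's state_dict. A minimal sketch wrapping a plain SGD instance, mirroring the constructor shown in this hunk:

    import torch
    import torch.nn as nn
    from timm.optim import Lookahead

    model = nn.Linear(4, 2)
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    opt = Lookahead(base, alpha=0.5, k=6)

    model(torch.randn(2, 4)).sum().backward()
    opt.step()                 # fast-weight step; slow weights sync every k steps
    state = opt.state_dict()   # delegates to the base optimizer after this change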
@@ -0,0 +1,190 @@
|
|||
""" PyTorch MADGRAD optimizer
|
||||
|
||||
MADGRAD: https://arxiv.org/abs/2101.11075
|
||||
|
||||
Code from: https://github.com/facebookresearch/madgrad
|
||||
"""
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.optim
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.optim.optimizer import _params_t
|
||||
else:
|
||||
_params_t = Any
|
||||
|
||||
|
||||
class MADGRAD(torch.optim.Optimizer):
|
||||
"""
|
||||
MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
|
||||
Optimization.
|
||||
|
||||
.. _MADGRAD: https://arxiv.org/abs/2101.11075
|
||||
|
||||
MADGRAD is a general purpose optimizer that can be used in place of SGD or
Adam, and may converge faster and generalize better. Currently GPU-only.
|
||||
Typically, the same learning rate schedule that is used for SGD or Adam may
|
||||
be used. The overall learning rate is not comparable to either method and
|
||||
should be determined by a hyper-parameter sweep.
|
||||
|
||||
MADGRAD requires less weight decay than other methods, often as little as
|
||||
zero. Momentum values used for SGD or Adam's beta1 should work here also.
|
||||
|
||||
On sparse problems both weight_decay and momentum should be set to 0.
|
||||
|
||||
Arguments:
|
||||
params (iterable):
|
||||
Iterable of parameters to optimize or dicts defining parameter groups.
|
||||
lr (float):
|
||||
Learning rate (default: 1e-2).
|
||||
momentum (float):
|
||||
Momentum value in the range [0,1) (default: 0.9).
|
||||
weight_decay (float):
|
||||
Weight decay, i.e. a L2 penalty (default: 0).
|
||||
eps (float):
|
||||
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: _params_t,
|
||||
lr: float = 1e-2,
|
||||
momentum: float = 0.9,
|
||||
weight_decay: float = 0,
|
||||
eps: float = 1e-6,
|
||||
decoupled_decay: bool = False,
|
||||
):
|
||||
if momentum < 0 or momentum >= 1:
|
||||
raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
|
||||
if lr <= 0:
|
||||
raise ValueError(f"Learning rate {lr} must be positive")
|
||||
if weight_decay < 0:
|
||||
raise ValueError(f"Weight decay {weight_decay} must be non-negative")
|
||||
if eps < 0:
|
||||
raise ValueError(f"Eps must be non-negative")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@property
|
||||
def supports_memory_efficient_fp16(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_flat_params(self) -> bool:
|
||||
return True
|
||||
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
# step counter must be stored in state to ensure correct behavior under
|
||||
# optimizer sharding
|
||||
if 'k' not in self.state:
|
||||
self.state['k'] = torch.tensor([0], dtype=torch.long)
|
||||
k = self.state['k'].item()
|
||||
|
||||
for group in self.param_groups:
|
||||
eps = group["eps"]
|
||||
lr = group["lr"] + eps
|
||||
weight_decay = group["weight_decay"]
|
||||
momentum = group["momentum"]
|
||||
|
||||
ck = 1 - momentum
|
||||
lamb = lr * math.pow(k + 1, 0.5)
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
state = self.state[p]
|
||||
|
||||
if "grad_sum_sq" not in state:
|
||||
state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
|
||||
state["s"] = torch.zeros_like(p.data).detach()
|
||||
if momentum != 0:
|
||||
state["x0"] = torch.clone(p.data).detach()
|
||||
|
||||
if momentum != 0.0 and grad.is_sparse:
|
||||
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
|
||||
|
||||
grad_sum_sq = state["grad_sum_sq"]
|
||||
s = state["s"]
|
||||
|
||||
# Apply weight decay
|
||||
if weight_decay != 0:
|
||||
if group['decoupled_decay']:
|
||||
p.data.mul_(1.0 - group['lr'] * weight_decay)
|
||||
else:
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("weight_decay option is not compatible with sparse gradients")
|
||||
grad.add_(p.data, alpha=weight_decay)
|
||||
|
||||
if grad.is_sparse:
|
||||
grad = grad.coalesce()
|
||||
grad_val = grad._values()
|
||||
|
||||
p_masked = p.sparse_mask(grad)
|
||||
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
|
||||
s_masked = s.sparse_mask(grad)
|
||||
|
||||
# Compute x_0 from other known quantities
|
||||
rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
|
||||
x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
|
||||
|
||||
# Dense + sparse op
|
||||
grad_sq = grad * grad
|
||||
grad_sum_sq.add_(grad_sq, alpha=lamb)
|
||||
grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
|
||||
|
||||
rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
|
||||
|
||||
s.add_(grad, alpha=lamb)
|
||||
s_masked._values().add_(grad_val, alpha=lamb)
|
||||
|
||||
# update masked copy of p
|
||||
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
|
||||
# Copy updated masked p to dense p using an add operation
|
||||
p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
|
||||
p.data.add_(p_masked, alpha=-1)
|
||||
else:
|
||||
if momentum == 0:
|
||||
# Compute x_0 from other known quantities
|
||||
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
||||
x0 = p.data.addcdiv(s, rms, value=1)
|
||||
else:
|
||||
x0 = state["x0"]
|
||||
|
||||
# Accumulate second moments
|
||||
grad_sum_sq.addcmul_(grad, grad, value=lamb)
|
||||
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
||||
|
||||
# Update s
|
||||
s.data.add_(grad, alpha=lamb)
|
||||
|
||||
# Step
|
||||
if momentum == 0:
|
||||
p.data.copy_(x0.addcdiv(s, rms, value=-1))
|
||||
else:
|
||||
z = x0.addcdiv(s, rms, value=-1)
|
||||
|
||||
# p is a moving average of z
|
||||
p.data.mul_(1 - ck).add_(z, alpha=ck)
|
||||
|
||||
self.state['k'] += 1
|
||||
return loss
|
|
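As the docstring above notes, MADGRAD's learning rate is not directly comparable to SGD or Adam and it usually wants little or no weight decay. A minimal sketch using the class exported from timm.optim in this commit (the tests also reach it via the factory names 'madgrad' and 'madgradw'):

    import torch
    import torch.nn as nn
    from timm.optim import MADGRAD

    model = nn.Linear(4, 2)
    # momentum plays the role of SGD momentum / Adam beta1; weight decay often left at 0
    opt = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()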
@@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
||||
class Nadam(Optimizer):
|
||||
|
@ -27,8 +27,10 @@ class Nadam(Optimizer):
|
|||
|
||||
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, schedule_decay=4e-3):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, schedule_decay=schedule_decay)
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
defaults = dict(
|
||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
|
||||
super(Nadam, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
|
@ -53,8 +55,8 @@ class Nadam(Optimizer):
|
|||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['m_schedule'] = 1.
|
||||
state['exp_avg'] = grad.new().resize_as_(grad).zero_()
|
||||
state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
# Warming momentum schedule
|
||||
m_schedule = state['m_schedule']
|
||||
|
@ -66,23 +68,21 @@ class Nadam(Optimizer):
|
|||
t = state['step']
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
grad = grad.add(group['weight_decay'], p.data)
|
||||
grad = grad.add(p.data, alpha=group['weight_decay'])
|
||||
|
||||
momentum_cache_t = beta1 * \
|
||||
(1. - 0.5 * (0.96 ** (t * schedule_decay)))
|
||||
momentum_cache_t_1 = beta1 * \
|
||||
(1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
|
||||
momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
|
||||
momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
|
||||
m_schedule_new = m_schedule * momentum_cache_t
|
||||
m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
|
||||
state['m_schedule'] = m_schedule_new
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1. - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
|
||||
exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
|
||||
denom = exp_avg_sq_prime.sqrt_().add_(eps)
|
||||
|
||||
p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom)
|
||||
p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom)
|
||||
p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
|
||||
p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
|
||||
|
||||
return loss
|
||||
|
|
|
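The Nadam hunk above also swaps the legacy `tensor.new().resize_as_(t).zero_()` state initialization for `torch.zeros_like`. A small standalone equivalence check, with illustrative tensor names only:

import torch

grad = torch.randn(4, 4)

# legacy idiom from the removed lines: allocate, resize, then zero in separate steps
exp_avg_old = grad.new().resize_as_(grad).zero_()
# replacement used in the new lines: one call, same shape, dtype and device
exp_avg_new = torch.zeros_like(grad)

assert torch.equal(exp_avg_old, exp_avg_new)
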
@ -1,77 +0,0 @@
"""NovoGrad Optimizer.
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
- https://arxiv.org/abs/1905.11286
"""

import torch
from torch.optim.optimizer import Optimizer
import math


class NovoGrad(Optimizer):
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(NovoGrad, self).__init__(params, defaults)
self._lr = lr
self._beta1 = betas[0]
self._beta2 = betas[1]
self._eps = eps
self._wd = weight_decay
self._grad_averaging = grad_averaging

self._momentum_initialized = False

def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()

if not self._momentum_initialized:
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('NovoGrad does not support sparse gradients')

v = torch.norm(grad)**2
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
state['step'] = 0
state['v'] = v
state['m'] = m
state['grad_ema'] = None
self._momentum_initialized = True

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
state['step'] += 1

step, v, m = state['step'], state['v'], state['m']
grad_ema = state['grad_ema']

grad = p.grad.data
g2 = torch.norm(grad)**2
grad_ema = g2 if grad_ema is None else grad_ema * \
self._beta2 + g2 * (1. - self._beta2)
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)

if self._grad_averaging:
grad *= (1. - self._beta1)

g2 = torch.norm(grad)**2
v = self._beta2*v + (1. - self._beta2)*g2
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
bias_correction1 = 1 - self._beta1 ** step
bias_correction2 = 1 - self._beta2 ** step
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

state['v'], state['m'] = v, m
state['grad_ema'] = grad_ema
p.data.add_(-step_size, m)
return loss

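With novograd.py deleted, the 'novograd' string is handled by the NVIDIA implementation instead; the updated factory further down maps both 'novograd' and 'nvnovograd' to `NvNovoGrad`. A hedged usage sketch assuming a timm checkout that includes this commit; the model and hyperparameters are placeholders:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # placeholder model

# either spelling now constructs the NvNovoGrad optimizer
opt_a = create_optimizer_v2(model, opt='novograd', lr=1e-3, weight_decay=1e-5)
opt_b = create_optimizer_v2(model, opt='nvnovograd', lr=1e-3, weight_decay=1e-5)
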
@ -96,7 +96,7 @@ class NvNovoGrad(Optimizer):
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now

@ -108,11 +108,11 @@ class NvNovoGrad(Optimizer):

grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)

p.data.add_(-group['lr'], exp_avg)
p.data.add_(exp_avg, alpha=-group['lr'])

return loss

@ -1,5 +1,5 @@
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Optional

@ -11,10 +11,10 @@ from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .lamb import NvLamb
from .lamb import Lamb
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .novograd import NovoGrad
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF

@ -47,8 +47,8 @@ def optimizer_kwargs(cfg):
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
optimizer_name=cfg.opt,
learning_rate=cfg.lr,
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum)
if getattr(cfg, 'opt_eps', None) is not None:

@ -72,9 +72,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):


def create_optimizer_v2(
model: nn.Module,
optimizer_name: str = 'sgd',
learning_rate: Optional[float] = None,
model_or_params,
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
filter_bias_and_bn: bool = True,

@ -87,9 +87,9 @@ def create_optimizer_v2(
* expose the parameters interface and leave it up to caller

Args:
model (nn.Module): model containing parameters to optimize
optimizer_name: name of optimizer to create
learning_rate: initial learning rate
model_or_params (nn.Module): model containing parameters to optimize
opt: name of optimizer to create
lr: initial learning rate
weight_decay: weight decay to apply in optimizer
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay

@ -98,59 +98,85 @@ def create_optimizer_v2(
Returns:
Optimizer
"""
opt_lower = optimizer_name.lower()
if weight_decay and filter_bias_and_bn:
skip = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
parameters = add_weight_decay(model, weight_decay, skip)
weight_decay = 0.
if isinstance(model_or_params, nn.Module):
# a model was passed in, extract parameters and add weight decays to appropriate layers
if weight_decay and filter_bias_and_bn:
skip = {}
if hasattr(model_or_params, 'no_weight_decay'):
skip = model_or_params.no_weight_decay()
parameters = add_weight_decay(model_or_params, weight_decay, skip)
weight_decay = 0.
else:
parameters = model_or_params.parameters()
else:
parameters = model.parameters()
# iterable of parameters or param groups passed in
parameters = model_or_params

opt_lower = opt.lower()
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None:
opt_args.setdefault('lr', lr)

# basic SGD & related
if opt_lower == 'sgd' or opt_lower == 'nesterov':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'nadam':
optimizer = Nadam(parameters, **opt_args)
elif opt_lower == 'radam':
optimizer = RAdam(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'sgdp':
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

# adaptive
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'nadam':
try:
# NOTE PyTorch >= 1.10 should have native NAdam
optimizer = optim.Nadam(parameters, **opt_args)
except AttributeError:
optimizer = Nadam(parameters, **opt_args)
elif opt_lower == 'radam':
optimizer = RAdam(parameters, **opt_args)
elif opt_lower == 'adamax':
optimizer = optim.Adamax(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'radabelief':
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
elif opt_lower == 'adadelta':
optimizer = optim.Adadelta(parameters, **opt_args)
elif opt_lower == 'adagrad':
opt_args.setdefault('eps', 1e-8)
optimizer = optim.Adagrad(parameters, **opt_args)
elif opt_lower == 'adafactor':
if not learning_rate:
opt_args['lr'] = None
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'madgradw':
optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'rmsproptf':
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'novograd':
optimizer = NovoGrad(parameters, **opt_args)
elif opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'lamb':
optimizer = NvLamb(parameters, **opt_args)

# second order
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)

# NVIDIA fused optimizers, require APEX to be installed
elif opt_lower == 'fusedsgd':

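The refactored `create_optimizer_v2` above accepts either an `nn.Module` or an iterable of parameters/param groups as its first argument, and the `opt`/`lr` keyword names replace `optimizer_name`/`learning_rate`. A minimal usage sketch under those assumptions; the model and hyperparameter values are illustrative only:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

# pass a model: bias/norm params are filtered out of weight decay by default
opt = create_optimizer_v2(model, opt='madgradw', lr=1e-3, weight_decay=0.05)

# or pass parameters / param groups directly and manage weight decay yourself
head_opt = create_optimizer_v2(model[-1].parameters(), opt='adamw', lr=1e-4)
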
@ -4,21 +4,21 @@ Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxi
"""
import math
import torch
from torch.optim.optimizer import Optimizer, required
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)

def __setstate__(self, state):
super(RAdam, self).__setstate__(state)

def step(self, closure=None):

loss = None
if closure is not None:
loss = closure()

@ -47,105 +47,40 @@ class RAdam(Optimizer):
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

state['step'] += 1
buffered = self.buffer[int(state['step'] % 10)]
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
num_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
num_sma_max = 2 / (1 - beta2) - 1
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = num_sma

# more conservative since it's an approximated value
if N_sma >= 5:
if num_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
(1 - beta2_t) *
(num_sma - 4) / (num_sma_max - 4) *
(num_sma - 2) / num_sma *
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
buffered[2] = step_size

if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

# more conservative since it's an approximated value
if N_sma >= 5:
if num_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
p_data_fp32.add_(-step_size, exp_avg)

p.data.copy_(p_data_fp32)

return loss


class PlainRAdam(Optimizer):

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

super(PlainRAdam, self).__init__(params, defaults)

def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)

def step(self, closure=None):

loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:

for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')

p_data_fp32 = p.data.float()

state = self.state[p]

if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)

state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

# more conservative since it's an approximated value
if N_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
p_data_fp32.add_(-step_size, exp_avg)
p_data_fp32.add_(exp_avg, alpha=-step_size)

p.data.copy_(p_data_fp32)

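One visible RAdam change above is that the step-size lookup buffer moves from a `self.buffer` attribute into the param-group defaults (`group['buffer']`), which means it is carried in `state_dict()` along with the other group values. A small sketch of round-tripping the optimizer state under that assumption; the model and settings are illustrative:

import torch
import torch.nn as nn
from timm.optim import RAdam

model = nn.Linear(4, 4)
opt = RAdam(model.parameters(), lr=1e-3)

model(torch.randn(2, 4)).sum().backward()
opt.step()

# the per-group 'buffer' entries are serialized alongside lr, betas, etc.
sd = opt.state_dict()
opt_restored = RAdam(model.parameters(), lr=1e-3)
opt_restored.load_state_dict(sd)
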
@ -58,8 +58,9 @@ class RMSpropTF(Optimizer):
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))

defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
defaults = dict(
lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
super(RMSpropTF, self).__init__(params, defaults)

def __setstate__(self, state):

@ -103,34 +104,34 @@ class RMSpropTF(Optimizer):
state['step'] += 1

if group['weight_decay'] != 0:
if 'decoupled_decay' in group and group['decoupled_decay']:
p.data.add_(-group['weight_decay'], p.data)
if group['decoupled_decay']:
p.data.mul_(1. - group['lr'] * group['weight_decay'])
else:
grad = grad.add(group['weight_decay'], p.data)
grad = grad.add(p.data, alpha=group['weight_decay'])

# Tensorflow order of ops for updating squared avg
square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
# square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
# square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original

if group['centered']:
grad_avg = state['grad_avg']
grad_avg.add_(one_minus_alpha, grad - grad_avg)
# grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt
grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt
# grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original
else:
avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt

if group['momentum'] > 0:
buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer
if 'lr_in_momentum' in group and group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
p.data.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg)
p.data.add_(-group['lr'], buf)
p.data.add_(buf, alpha=-group['lr'])
else:
p.data.addcdiv_(-group['lr'], grad, avg)
p.data.addcdiv_(grad, avg, value=-group['lr'])

return loss

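The reworked RMSpropTF weight-decay branch above separates decoupled decay, which scales the weights directly by `1 - lr * weight_decay`, from the coupled form that folds the decay term into the gradient. A standalone sketch of the two behaviors with made-up tensors:

import torch

lr, wd = 0.01, 1e-4
p = torch.ones(3)
grad = torch.zeros(3)

# decoupled decay: shrink the parameter itself, leave the gradient untouched
p_decoupled = p.clone().mul_(1. - lr * wd)

# coupled decay: add wd * p into the gradient before the RMSprop update
grad_coupled = grad.add(p, alpha=wd)
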
@ -9,49 +9,21 @@ MIT license
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required
import math

from .adamp import projection


class SGDP(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
defaults = dict(
lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
super(SGDP, self).__init__(params, defaults)

def _channel_view(self, x):
return x.view(x.size(0), -1)

def _layer_view(self, x):
return x.view(1, -1)

def _cosine_similarity(self, x, y, eps, view_func):
x = view_func(x)
y = view_func(y)

x_norm = x.norm(dim=1).add_(eps)
y_norm = y.norm(dim=1).add_(eps)
dot = (x * y).sum(dim=1)

return dot.abs() / x_norm / y_norm

def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
wd = 1
expand_size = [-1] + [1] * (len(p.shape) - 1)
for view_func in [self._channel_view, self._layer_view]:

cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
wd = wd_ratio

return perturb, wd

return perturb, wd

def step(self, closure=None):
loss = None
if closure is not None:

@ -75,22 +47,22 @@ class SGDP(Optimizer):

# SGD
buf = state['momentum']
buf.mul_(momentum).add_(1 - dampening, grad)
buf.mul_(momentum).add_(grad, alpha=1. - dampening)
if nesterov:
d_p = grad + momentum * buf
else:
d_p = buf

# Projection
wd_ratio = 1
wd_ratio = 1.
if len(p.shape) > 1:
d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])

# Weight decay
if weight_decay != 0:
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

# Step
p.data.add_(-group['lr'], d_p)
p.data.add_(d_p, alpha=-group['lr'])

return loss

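For the SGDP weight-decay line above, a quick worked check of the multiplier `1. - lr * weight_decay * wd_ratio / (1 - momentum)` with illustrative values that are not taken from the diff:

lr, weight_decay, momentum = 0.1, 1e-4, 0.9
wd_ratio = 0.1  # group default applied when the projection triggers

scale = 1. - lr * weight_decay * wd_ratio / (1 - momentum)
print(scale)  # 0.1 * 1e-4 * 0.1 / 0.1 = 1e-5, so scale = 1 - 1e-5 = 0.99999
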