# EasyCV/easycv/core/optimizer/adam.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from distutils.version import LooseVersion
from typing import List

import torch
from mmcv.runner.optimizer.builder import OPTIMIZERS
from torch import Tensor
from torch.optim import AdamW as _AdamW


def adamw(params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor], max_exp_avg_sqs: List[Tensor],
          state_steps: List[int], amsgrad: bool, beta1: float, beta2: float,
          lr: float, weight_decay: float, eps: float):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(
                max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() /
                     math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
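
# Usage sketch (illustrative only, kept as comments so the module is unchanged):
# the functional ``adamw`` above updates the given tensors in place from
# pre-built optimizer state. A hypothetical direct call with throwaway tensors
# could look like this -- normally ``AdamW.step`` below does this bookkeeping:
#
#   param = torch.zeros(4)
#   adamw([param], grads=[torch.ones(4)], exp_avgs=[torch.zeros(4)],
#         exp_avg_sqs=[torch.zeros(4)], max_exp_avg_sqs=[], state_steps=[1],
#         amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3, weight_decay=1e-2,
#         eps=1e-8)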


if LooseVersion(torch.__version__) <= LooseVersion('1.9.0'):

    @OPTIMIZERS.register_module(force=True)
    class AdamW(_AdamW):
        """Workaround for a torch 1.8 bug:
        ``UnboundLocalError: local variable 'beta1' referenced before assignment``.

        Bugfix reference: https://github.com/pytorch/pytorch/issues/55740
        """

        @torch.no_grad()
        def step(self, closure=None):
            """Performs a single optimization step.

            Args:
                closure (callable, optional): A closure that reevaluates the
                    model and returns the loss.
            """
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()

            for group in self.param_groups:
                params_with_grad = []
                grads = []
                exp_avgs = []
                exp_avg_sqs = []
                state_sums = []  # unused below
                max_exp_avg_sqs = []
                state_steps = []
                amsgrad = group['amsgrad']
                beta1, beta2 = group['betas']

                for p in group['params']:
                    if p.grad is None:
                        continue
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError(
                            'AdamW does not support sparse gradients')
                    grads.append(p.grad)

                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(
                                p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if amsgrad:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

                adamw(params_with_grad, grads, exp_avgs, exp_avg_sqs,
                      max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2,
                      group['lr'], group['weight_decay'], group['eps'])

            return loss
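
# Config sketch (an assumption about typical EasyCV/mmcv usage, not part of this
# module): because the class is registered in OPTIMIZERS (``force=True`` lets it
# override mmcv's default 'AdamW' entry on torch <= 1.9.0), a config can select
# it by name and mmcv builds it for a model, e.g.:
#
#   from mmcv.runner import build_optimizer
#   optimizer_cfg = dict(
#       type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.05)
#   optimizer = build_optimizer(model, optimizer_cfg)  # model: an nn.Module
#   loss = model(inputs).mean()   # hypothetical forward pass
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()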