# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from distutils.version import LooseVersion
from typing import List

import torch
from mmcv.runner.optimizer.builder import OPTIMIZERS
from torch import Tensor
from torch.optim import AdamW as _AdamW


def adamw(params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor], max_exp_avg_sqs: List[Tensor],
          state_steps: List[int], amsgrad: bool, beta1: float, beta2: float,
          lr: float, weight_decay: float, eps: float):
    r"""Apply one AdamW update, in place, to every tensor in ``params``.

    See :class:`~torch.optim.AdamW` for the algorithm itself.

    Args:
        params: Tensors to update; modified in place.
        grads: Gradient for each entry of ``params`` (same order).
        exp_avgs: Running average of the gradient per parameter; modified
            in place.
        exp_avg_sqs: Running average of the squared gradient per parameter;
            modified in place.
        max_exp_avg_sqs: Element-wise running maximum of ``exp_avg_sqs``;
            only indexed (and modified in place) when ``amsgrad`` is true.
        state_steps: Step count per parameter, used as the exponent of the
            bias corrections. A count of 0 makes the first bias correction
            zero, so the step-size division fails.
        amsgrad: Normalize by the running maximum of the second moment
            instead of its current value.
        beta1: Decay coefficient of the gradient average.
        beta2: Decay coefficient of the squared-gradient average.
        lr: Learning rate.
        weight_decay: Decoupled weight-decay coefficient.
        eps: Constant added to the denominator.
    """
    for idx, weight in enumerate(params):
        gradient = grads[idx]
        mean = exp_avgs[idx]
        sq_mean = exp_avg_sqs[idx]
        n_steps = state_steps[idx]

        # Decoupled weight decay: shrink the parameter itself rather than
        # adding a penalty term to the gradient.
        weight.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1**n_steps
        bias_correction2 = 1 - beta2**n_steps

        # Refresh the running averages of the gradient and its square.
        mean.mul_(beta1).add_(gradient, alpha=1 - beta1)
        sq_mean.mul_(beta2).addcmul_(gradient, gradient, value=1 - beta2)

        if amsgrad:
            # Fold the new second moment into the running maximum and
            # normalize by that maximum.
            running_max = max_exp_avg_sqs[idx]
            torch.maximum(running_max, sq_mean, out=running_max)
            variance = running_max
        else:
            variance = sq_mean

        # ``sqrt()`` returns a new tensor, so the in-place ``add_`` below
        # never touches the stored optimizer state.
        denom = (variance.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        weight.addcdiv_(mean, denom, value=-(lr / bias_correction1))


if LooseVersion(torch.__version__) <= LooseVersion('1.9.0'):

    @OPTIMIZERS.register_module(force=True)
    class AdamW(_AdamW):
        """AdamW with a patched ``step`` for torch <= 1.9.0.

        Works around the torch 1.8 failure
        ``UnboundLocalError: local variable 'beta1' referenced before
        assignment``.
        Bugfix reference: https://github.com/pytorch/pytorch/issues/55740

        ``step`` reads ``amsgrad`` and ``betas`` from the parameter group
        before iterating over its parameters, so they are always bound, even
        when no parameter in the group has a gradient.
        """

        @torch.no_grad()
        def step(self, closure=None):
            """Performs a single optimization step.

            Args:
                closure (callable, optional): A closure that reevaluates the
                    model and returns the loss.

            Returns:
                The value returned by ``closure``, or ``None`` when no
                closure is given.

            Raises:
                RuntimeError: If a parameter has a sparse gradient.
            """
            loss = None
            if closure is not None:
                # The method runs under ``no_grad``; the closure needs
                # gradients re-enabled to recompute the loss.
                with torch.enable_grad():
                    loss = closure()

            for group in self.param_groups:
                params_with_grad = []
                grads = []
                exp_avgs = []
                exp_avg_sqs = []
                max_exp_avg_sqs = []
                state_steps = []
                # Bound here, outside the parameter loop, so that they exist
                # even if every parameter in the group is skipped below.
                amsgrad = group['amsgrad']
                beta1, beta2 = group['betas']

                for p in group['params']:
                    if p.grad is None:
                        continue
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError(
                            'AdamW does not support sparse gradients')
                    grads.append(p.grad)

                    state = self.state[p]

                    # Lazy state initialization on the first step that sees
                    # this parameter with a gradient.
                    if not state:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient
                        # values
                        state['exp_avg_sq'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq.
                            # grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(
                                p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if amsgrad:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # Advance the step counter first, then record the
                    # post-increment value for the bias corrections.
                    state['step'] += 1
                    state_steps.append(state['step'])

                adamw(params_with_grad, grads, exp_avgs, exp_avg_sqs,
                      max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2,
                      group['lr'], group['weight_decay'], group['eps'])

            return loss