# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from paddle.optimizer.lr import LRScheduler


class CyclicalCosineDecay(LRScheduler):
    """
    Cyclical cosine learning rate decay.

    A learning rate schedule which can be referred in
    https://arxiv.org/pdf/2012.12645.pdf. Within every ``cycle`` epochs the
    learning rate follows a half cosine wave from ``learning_rate`` towards
    ``eta_min`` and then restarts from ``learning_rate``.
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 cycle=1,
                 last_epoch=-1,
                 eta_min=0.0,
                 verbose=False):
        """
        Args:
            learning_rate(float): initial (maximum) learning rate
            T_max(int): maximum epoch num. Stored on the instance for
                reference; it is not used when computing the learning rate.
            cycle(int): period of the cosine decay, in epochs. Must be positive.
            last_epoch (int, optional): The index of last epoch. Can be set to
                restart training. Default: -1, means initial learning rate.
            eta_min(float): minimum learning rate during training
            verbose(bool): whether to print learning rate for each epoch

        Raises:
            ValueError: if ``cycle`` is not positive.
        """
        if cycle <= 0:
            raise ValueError(
                "Expected positive cycle, but got {}".format(cycle))
        # These attributes must exist before the base constructor runs:
        # paddle's LRScheduler.__init__ performs an initial step(), which
        # calls get_lr(). When training is resumed (last_epoch != -1) that
        # call reads self.cycle and self.eta_min.
        self.T_max = T_max
        self.cycle = cycle
        self.eta_min = eta_min
        super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)

    def get_lr(self):
        """Return the learning rate for the epoch stored in ``self.last_epoch``."""
        if self.last_epoch == 0:
            return self.base_lr
        # Position inside the current cycle, in [0, cycle).
        relative_epoch = self.last_epoch % self.cycle
        cosine_term = 1 + math.cos(math.pi * relative_epoch / self.cycle)
        return self.eta_min + 0.5 * (self.base_lr - self.eta_min) * cosine_term


class OneCycleDecay(LRScheduler):
    """
    One Cycle learning rate decay
    A learning rate which can be referred in https://arxiv.org/abs/1708.07120
    Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR

    The schedule is made of consecutive phases; in each phase the learning
    rate is annealed from the phase's ``start_lr`` to its ``end_lr``.

    Args:
        max_lr(float): peak learning rate of the cycle.
        epochs(int): number of epochs; must be a positive integer.
        steps_per_epoch(int): number of scheduler steps per epoch; must be a
            positive integer. ``total_steps = epochs * steps_per_epoch``.
        pct_start(float): float in [0, 1]; fraction of ``total_steps`` spent
            in the first (increasing) phase.
        anneal_strategy(str): ``'cos'`` or ``'linear'``.
        div_factor(float): ``initial_lr = max_lr / div_factor``.
        final_div_factor(float): ``min_lr = initial_lr / final_div_factor``.
        three_phase(bool): if True, use three phases
            (initial -> max -> initial -> min); otherwise two
            (initial -> max -> min).
        last_epoch(int): forwarded to ``LRScheduler``; used here as the
            index of the last step.
        verbose(bool): forwarded to ``LRScheduler``.

    Raises:
        ValueError: if ``epochs``, ``steps_per_epoch``, ``pct_start`` or
            ``anneal_strategy`` is invalid.
    """

    def __init__(self,
                 max_lr,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate total_steps. The isinstance check comes first so that
        # None (the default) or other non-integers raise the intended
        # ValueError instead of a TypeError from the comparison.
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError(
                "Expected positive integer epochs, but got {}".format(epochs))
        if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
            raise ValueError(
                "Expected positive integer steps_per_epoch, but got {}".format(
                    steps_per_epoch))

        # Validate pct_start before it is used to build the phases.
        if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
            raise ValueError(
                "Expected float between 0 and 1 pct_start, but got {}".format(
                    pct_start))

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(
                "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
                format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        self.total_steps = epochs * steps_per_epoch

        self.max_lr = max_lr
        self.initial_lr = self.max_lr / div_factor
        self.min_lr = self.initial_lr / final_div_factor

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': self.max_lr,
                    'end_lr': self.initial_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.min_lr,
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': self.initial_lr,
                    'end_lr': self.max_lr,
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': self.max_lr,
                    'end_lr': self.min_lr,
                },
            ]

        # Must come last: everything get_lr() reads has to be set before the
        # base constructor is invoked.
        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """
        Return the learning rate for the step stored in ``self.last_epoch``.

        Raises:
            ValueError: if the step index exceeds ``total_steps``.
        """
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                "Tried to step {} times. The specified number of total steps is {}"
                .format(step_num + 1, self.total_steps))

        computed_lr = 0.0
        start_step = 0
        last_index = len(self._schedule_phases) - 1
        for i, phase in enumerate(self._schedule_phases):
            end_step = phase['end_step']
            # Use the first phase that has not finished yet; the last phase
            # is the fallback for any remaining step.
            if step_num <= end_step or i == last_index:
                span = end_step - start_step
                # A phase of non-positive length (e.g. when
                # pct_start * total_steps == 1, or pct_start == 1.0) has
                # nothing to interpolate over: treat it as complete instead
                # of dividing by zero.
                pct = (step_num - start_step) / span if span > 0 else 1.0
                computed_lr = self.anneal_func(phase['start_lr'],
                                               phase['end_lr'], pct)
                break
            start_step = end_step

        return computed_lr