# Copyright (c) OpenMMLab. All rights reserved.
import logging
import math
import os.path as osp
import tempfile
from unittest.mock import Mock

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.hooks import EarlyStoppingHook
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.testing import RunnerTestCase


class ToyModel(BaseModel):
    """Tiny model with a single ``nn.Linear(2, 1)`` layer for runner tests."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_sample, mode='tensor'):
        """Run the linear layer on the stacked inputs.

        Args:
            inputs: Sequence of tensors, stacked into one batch tensor.
            data_sample: Sequence of label tensors, stacked the same way.
            mode (str): ``'loss'`` returns a dict holding the loss; any
                other value returns the raw layer outputs.

        Returns:
            The layer outputs, or ``dict(loss=...)`` when ``mode`` is
            ``'loss'``.
        """
        # Both sequences are always stacked, whatever the mode.
        targets = torch.stack(data_sample)
        predictions = self.linear(torch.stack(inputs))
        if mode != 'loss':
            return predictions
        # NOTE(review): if `data_sample` holds scalar tensors, `targets` is
        # 1-D while `predictions` ends in a size-1 dim, so the subtraction
        # broadcasts -- confirm this is acceptable for a toy loss.
        return dict(loss=(targets - predictions).sum())


class DummyDataset(Dataset):
    """Fixed-size in-memory dataset: 12 random 2-d inputs, all labelled 1."""

    METAINFO = dict()  # type: ignore
    # Both tensors are created once, at class-definition time, and are
    # shared by every instance.
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index`` as an inputs/data_sample dict."""
        return {
            'inputs': self.data[index],
            'data_sample': self.label[index],
        }


class DummyMetric(BaseMetric):
    """Metric that replays a scripted ``acc`` sequence, one value per call."""

    default_prefix: str = 'test'

    def __init__(self, length):
        super().__init__()
        self.length = length
        self.best_idx = length // 2
        # Index of the next scripted value handed out by `compute_metrics`.
        self.cur_idx = 0
        # Twelve scripted values: the six-value pattern repeated twice.
        self.vals = [90, 91, 92, 88, 89, 90] * 2

    def process(self, *args, **kwargs):
        """Ignore the batch and record a placeholder result."""
        self.results.append(0)

    def compute_metrics(self, *args, **kwargs):
        """Return the next scripted accuracy as ``{'acc': value}``.

        Raises:
            IndexError: If called more times than there are scripted values.
        """
        position = self.cur_idx
        # Look the value up before advancing, so a failed lookup leaves
        # `cur_idx` untouched.
        score = self.vals[position]
        self.cur_idx = position + 1
        return {'acc': score}


def get_mock_runner():
    """Build a ``Mock`` runner whose ``train_loop.stop_training`` is False.

    Returns:
        Mock: A fresh mock runner with a fresh mock ``train_loop``.
    """
    mock_runner = Mock()
    # Keyword arguments to `Mock` are set as attributes on the new mock.
    mock_runner.train_loop = Mock(stop_training=False)
    return mock_runner


class TestEarlyStoppingHook(RunnerTestCase):
    """Tests for ``EarlyStoppingHook`` construction and stop decisions."""

    def setUp(self):
        # Scratch directory; `test_with_runner` places its work dir in it.
        # NOTE(review): the parent `setUp` is not invoked here -- confirm
        # that skipping it is intentional.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    @staticmethod
    def _feed_metrics(hook, runner, metrics):
        """Pass each metrics dict to ``hook.after_val_epoch`` in order and
        stop as soon as ``runner.train_loop.stop_training`` is truthy."""
        for metric in metrics:
            hook.after_val_epoch(runner, metric)
            if runner.train_loop.stop_training:
                break

    def test_init(self):
        """The rule is inferred from the monitor name, or must be valid."""
        # Monitors expected to resolve to the 'greater' rule, with a
        # negative initial best score.
        for monitor in ('acc', 'ACC', 'mAP_50'):
            with self.subTest(monitor=monitor):
                hook = EarlyStoppingHook(monitor=monitor)
                self.assertEqual(hook.rule, 'greater')
                self.assertLess(hook.best_score, 0)

        # Monitors expected to resolve to the 'less' rule, with a positive
        # initial best score.
        for monitor in ('loss', 'Loss', 'ce_loss'):
            with self.subTest(monitor=monitor):
                hook = EarlyStoppingHook(monitor=monitor)
                self.assertEqual(hook.rule, 'less')
                self.assertGreater(hook.best_score, 0)

        with self.assertRaises(ValueError):
            # `rule` should be passed.
            EarlyStoppingHook(monitor='recall')

        with self.assertRaises(ValueError):
            # Invalid `rule`
            EarlyStoppingHook(monitor='accuracy/top1', rule='the world')

    def test_before_run(self):
        """``before_run`` rejects a train loop without ``stop_training``."""
        runner = Mock()
        runner.train_loop = object()

        # Build the hook outside the context manager so that only
        # `before_run` can satisfy the expected AssertionError.
        hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')

        # `train_loop` must contain `stop_training` variable.
        with self.assertRaises(AssertionError):
            hook.before_run(runner)

    def test_after_val_epoch(self):
        """Check the stop decisions made from a stream of metric dicts."""
        runner = get_mock_runner()
        metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
        hook = EarlyStoppingHook(monitor='acc', rule='greater')

        with self.assertWarns(UserWarning):
            # Skip early stopping process since the evaluation results does
            # not include the key 'acc'
            hook.after_val_epoch(runner, metrics)

        # if `monitor` does not match and strict=True, crash the training.
        # The hook is constructed first so that only `after_val_epoch` can
        # produce the expected RuntimeError.
        hook = EarlyStoppingHook(monitor='acc', rule='greater', strict=True)
        with self.assertRaises(RuntimeError):
            hook.after_val_epoch(runner, metrics)

        # Check largest value
        runner = get_mock_runner()
        metrics = [{'accuracy/top1': i / 9.} for i in range(8)]
        hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')
        self._feed_metrics(hook, runner, metrics)
        self.assertAlmostEqual(hook.best_score, 7 / 9)

        # Check smallest value
        runner = get_mock_runner()
        metrics = [{'loss': i / 9.} for i in range(8, 0, -1)]
        hook = EarlyStoppingHook(monitor='loss')
        self._feed_metrics(hook, runner, metrics)
        self.assertAlmostEqual(hook.best_score, 1 / 9)

        # Check stop training
        runner = get_mock_runner()
        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
        hook = EarlyStoppingHook(
            monitor='accuracy/top1', rule='greater', min_delta=1)
        self._feed_metrics(hook, runner, metrics)
        self.assertTrue(runner.train_loop.stop_training)

        # Check finite
        runner = get_mock_runner()
        metrics = [{'accuracy/top1': math.inf} for _ in range(5)]
        hook = EarlyStoppingHook(
            monitor='accuracy/top1', rule='greater', min_delta=1)
        self._feed_metrics(hook, runner, metrics)
        self.assertTrue(runner.train_loop.stop_training)

        # Check patience
        runner = get_mock_runner()
        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
        hook = EarlyStoppingHook(
            monitor='accuracy/top1', rule='greater', min_delta=1, patience=10)
        self._feed_metrics(hook, runner, metrics)
        self.assertFalse(runner.train_loop.stop_training)

        # Check stopping_threshold
        runner = get_mock_runner()
        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
        hook = EarlyStoppingHook(
            monitor='accuracy/top1',
            rule='greater',
            stopping_threshold=98.5,
            patience=0)
        self._feed_metrics(hook, runner, metrics)
        # 98 + 4/7 is the first linspace value above the 98.5 threshold.
        self.assertAlmostEqual(hook.best_score.item(), 98 + 4 / 7, places=5)

    def test_with_runner(self):
        """Run a real ``Runner`` and check the epoch at which it stops."""
        max_epoch = 10
        work_dir = osp.join(self.temp_dir.name, 'runner_test')
        early_stop_cfg = dict(
            type='EarlyStoppingHook',
            monitor='test/acc',
            rule='greater',
            min_delta=1,
            patience=3,
        )
        # One model instance is shared by the runner and the optimizer, so
        # the optimizer updates the model that is actually trained.
        model = ToyModel()
        runner = Runner(
            model=model,
            work_dir=work_dir,
            train_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=dict(type=DummyMetric, length=max_epoch),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(model.parameters())),
            train_cfg=dict(
                by_epoch=True, max_epochs=max_epoch, val_interval=1),
            val_cfg=dict(),
            custom_hooks=[early_stop_cfg],
            experiment_name='earlystop_test')
        runner.train()
        # Training is expected to end well before `max_epoch`; presumably
        # the scripted `DummyMetric` values stop improving by `min_delta`
        # for `patience` validations -- confirm against the hook.
        self.assertEqual(runner.epoch, 6)