# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS


@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """Switch the training recipe during the training loop.

    Currently, the train pipeline, the batch augmentations and the loss
    module can be switched.

    Args:
        schedule (list): Every item of the schedule list should be a dict, and
            the dict should have ``action_epoch`` and some of
            ``train_pipeline``, ``batch_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): The epoch at which to switch the
              training recipe. The switch is applied right before that epoch
              starts (it is matched against ``runner.epoch + 1``). Every
              ``action_epoch`` in the schedule must be unique.
            - ``train_pipeline`` (list, optional): The new data pipeline of the
              train dataset. If not specified, keep the original settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations during training. See :mod:`Batch Augmentations
              <mmpretrain.models.utils.batch_augments>` for more details.
              A dict is built into a ``RandomBatchAugment``; any other
              non-None value is used as-is. If None, disable batch
              augmentations. If not specified, keep the original settings.
            - ``loss`` (dict, optional): The new loss module config. A dict
              is built through the ``MODELS`` registry; any other value is
              used as-is. If not specified, keep the original settings.

    Example:
        To use this hook in config files.

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            # Work on a copy so building objects below never mutates the
            # caller's config.
            recipe = deepcopy(recipe)
            if 'train_pipeline' in recipe:
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                # None (disable) and pre-built objects are kept unchanged.
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
        # Sorted by epoch so ``before_train`` can replay switches in order
        # and stop at the first future one.
        self.schedule = OrderedDict(sorted(recipes.items()))

    def before_train(self, runner) -> None:
        """Replay past switches when resuming from a checkpoint.

        Every recipe whose ``action_epoch`` is not later than the current
        epoch (``runner.epoch``) is applied in ascending epoch order, so the
        resumed run uses the same recipe it had before interruption.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                # ``self.schedule`` is sorted, so every later entry is in
                # the future as well.
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner) -> None:
        """Apply the recipe scheduled for the epoch that is about to start.

        Args:
            runner (Runner): The runner of the training process.
        """
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """Apply every component present in ``recipe`` and log each switch.

        Args:
            runner (Runner): The runner of the training process.
            recipe (dict): A prepared schedule item (without
                ``action_epoch``).
            extra_info (str): Suffix appended to the log messages.
        """
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """Replace the batch augmentations of the model's data preprocessor."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """Replace the pipeline of the train dataloader's dataset.

        Dataset wrappers are traversed recursively: ``datasets`` (e.g.
        concat datasets) and ``dataset`` (other wrappers) attributes are
        followed until an object with a ``pipeline`` attribute is found.

        Raises:
            RuntimeError: If no ``pipeline`` can be reached.
        """

        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual dataset
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrapper
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)

        # To restart the iterator of dataloader when `persistent_workers=True`
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """Replace the loss module of the model (or of its head).

        Raises:
            RuntimeError: If neither the model nor ``model.head`` has a
                ``loss_module`` attribute.
        """
        model = runner.model
        # Use the default (root) wrapper registry, as
        # ``_switch_batch_augments`` does. Passing a child registry would only
        # search that child, missing MMEngine's own wrappers such as
        # MMDistributedDataParallel, so the model would not be unwrapped
        # under distributed training.
        if is_model_wrapper(model):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')