# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess

# Maps the `module` field of an `mmlab_modules` entry ('model', 'backbone',
# 'neck', 'head') to the EasyCV registry holding modules of that kind.
EASYCV_REGISTRY_MAP = {
    'model': MODELS,
    'backbone': BACKBONES,
    'neck': NECKS,
    'head': HEADS
}
# Value accepted in the `type` field of an `mmlab_modules` entry.
MMDET = 'mmdet'
# mmlab types the adapter can handle; enforced by `MMAdapter.check_env`.
SUPPORT_MMLAB_TYPES = [MMDET]
# At module scope `locals()` is this module's globals dict, so the class copies
# stored here by `MMAdapter._merge_mmlab_module_to_easycv` become module-level
# attributes of this file.
# NOTE(review): presumably so the copies can be resolved as `<module>.<name>`
# (e.g. when pickling models) — confirm before changing this.
_MMLAB_COPIES = locals()


class MMAdapter:
    """Make selected mmlab modules usable from EasyCV configs.

    For every entry of ``modules_config`` the named mmlab module is copied into
    the matching EasyCV registry and, for mmdet models, patched by
    ``MMDetWrapper``. Afterwards all EasyCV modules are merged into the mmlab
    registries, so mmlab models build EasyCV sub-modules wherever names collide.
    """

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config (list[dict]): one dict per module, e.g.::

                [
                    # use mmdet's MaskRCNN
                    dict(type='mmdet', name='MaskRCNN', module='model'),
                    # commented out, so EasyCV's own ResNet is used
                    # dict(type='mmdet', name='ResNet', module='backbone'),
                    # `type` omitted: defaults to 'mmdet'
                    dict(name='FPN', module='neck'),
                ]

                ``module`` must be a key of ``EASYCV_REGISTRY_MAP``; ``name``
                is a registered module name (str) or the class itself.
        """
        self.default_mmtype = 'mmdet'
        self.mmtype_list = set()

        for module_cfg in modules_config:
            mmtype = module_cfg.get('type',
                                    self.default_mmtype)  # default mmdet
            self.mmtype_list.add(mmtype)

        self.check_env()
        self.fix_conflicts()

        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested mmlab types and make sure mmdet is importable.

        Raises:
            AssertionError: a requested type is not in ``SUPPORT_MMLAB_TYPES``.
            ValueError: mmdet is missing and ``pip install mmdet`` failed.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES
        try:
            import mmdet  # noqa: F401
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate, and keep the original failure as the cause.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

        # mmdet was added to site-packages after the interpreter started; drop
        # cached import-finder state so the later `import mmdet` can find it.
        import importlib
        importlib.invalidate_caches()

    def fix_conflicts(self):
        """Drop hooks registered by both mmdet and EasyCV from mmcv's HOOKS."""
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register/wrap each configured mmlab module, then merge registries."""
        for module_cfg in self.modules_config:
            # `type` is optional (see `__init__`); fall back to the default.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Patch the resolved mmlab class so it follows EasyCV conventions."""
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Copy every EasyCV registry entry into the mmlab registries.

        Entries with the same name are overwritten by the EasyCV module. For
        example, an mmdet MaskRCNN then builds EasyCV's backbone rather than
        mmdet's, unless that backbone was explicitly requested from mmdet.
        """
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass copy of an mmlab module in the EasyCV registry.

        Args:
            mmtype (str): mmlab type, e.g. 'mmdet'.
            module_type (str): key of ``EASYCV_REGISTRY_MAP``.
            module_name (str | type): registered name or the class itself.
            force (bool): overwrite an existing EasyCV entry of the same name.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # `type()` needs a str name; class objects are accepted above, so use
        # the class's own name for them.
        copy_name = module_name if isinstance(module_name,
                                              str) else model_obj.__name__
        # Subclass copy (stored as a module global) so the EasyCV registry
        # entry is a distinct class from mmlab's.
        _MMLAB_COPIES[copy_name] = type(copy_name, (model_obj, ), dict())
        easycv_registry_type.register_module(
            module=_MMLAB_COPIES[copy_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` (str or class) to an mmlab class.

        Raises:
            ValueError: unknown name, or ``module_name`` is neither str nor class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name in mm_module_dict:
                module_obj = mm_module_dict[module_name]
            else:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
        elif inspect.isclass(module_name):
            module_obj = module_name
        else:
            raise ValueError(
                'Only support type `str` and `class` object, but get type {}'.
                format(type(module_name)))
        return module_obj

    def _get_mmtype_registry_map(self):
        """Return ``{mmtype: {module_type: mmlab registry}}``."""
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import NECKS as MMNECKS
        from mmdet.models.builder import HEADS as MMHEADS
        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map


class MMDetWrapper:
    """Patch mmdet model classes to follow EasyCV calling conventions.

    Only ``module_type == 'model'`` classes are patched. Their ``__init__``,
    ``forward`` and ``forward_test`` are replaced on the class passed in.
    Constructing a wrapper also re-registers mmdet's ``RPNHead`` via
    ``update_rpn_head``.

    Patching is idempotent per method. A method that already resolves to a
    wrapper installed here, directly or through inheritance, is not wrapped
    again. Otherwise a second adaptation would stack wrappers: the outer
    ``forward`` wrapper would pass ``img_metas`` as ``mode`` and
    ``return_loss`` twice.
    """

    # Marker attribute set on every wrapper function installed by this class.
    _WRAPPED_FLAG = '_easycv_mmdet_wrapped'

    def __init__(self):
        self.refactor_modules()

    def wrap_module(self, cls, module_type):
        """Patch ``cls`` in place when it is a detector ('model') class."""
        if module_type == 'model':
            self._wrap_model_init(cls)
            self._wrap_model_forward(cls)
            self._wrap_model_forward_test(cls)

    def refactor_modules(self):
        update_rpn_head()

    def _is_wrapped(self, func):
        """True if ``func`` is a wrapper previously installed by this class."""
        return getattr(func, self._WRAPPED_FLAG, False)

    def _mark_wrapped(self, func):
        setattr(func, self._WRAPPED_FLAG, True)

    def _wrap_model_init(self, cls):
        """Make construction of ``cls`` call ``init_weights()`` afterwards."""
        origin_init = cls.__init__
        if self._is_wrapped(origin_init):
            return

        def _new_init(self, *args, **kwargs):
            origin_init(self, *args, **kwargs)
            self.init_weights()

        self._mark_wrapped(_new_init)
        setattr(cls, '__init__', _new_init)

    def _wrap_model_forward(self, cls):
        """Translate EasyCV's ``forward(img, mode=..., **kw)`` to mmdet's API.

        ``img_metas`` is taken from the keyword arguments. ``mode == 'train'``
        maps to ``return_loss=True``; any other mode maps to
        ``return_loss=False``.
        """
        origin_forward = cls.forward
        if self._is_wrapped(origin_forward):
            return

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)

            if mode == 'train':
                return origin_forward(
                    self, img, img_metas, return_loss=True, **kwargs)
            else:
                return origin_forward(
                    self, img, img_metas, return_loss=False, **kwargs)

        self._mark_wrapped(_new_forward)
        setattr(cls, 'forward', _new_forward)

    def _wrap_model_forward_test(self, cls):
        """Convert mmdet test outputs into EasyCV's detection dict.

        The wrapped ``forward_test`` forces ``rescale=True``, RLE-encodes mask
        results with ``encode_mask_results``, and returns a dict with
        'detection_boxes', 'detection_scores', 'detection_classes' and
        'detection_masks' (one entry per result) plus 'img_metas' (the single
        element of ``img_metas``; exactly one is asserted).
        """
        origin_forward_test = cls.forward_test
        if self._is_wrapped(origin_forward_test):
            return

        from mmdet.core import encode_mask_results

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs.update({'rescale': True})  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')

            result = origin_forward_test(self, img, img_metas, **kwargs)
            # ============result process to adapt to easycv============
            # encode mask results
            if isinstance(result[0], tuple):
                result = [(bbox_results, encode_mask_results(mask_results))
                          for bbox_results, mask_results in result]
            # This logic is only used in panoptic segmentation test.
            # NOTE(review): dict results then reach the loop below as
            # `bbox_result`, where np.vstack is applied to them — confirm
            # whether the panoptic path is expected to work here.
            elif isinstance(result[0], dict) and 'ins_results' in result[0]:
                for j in range(len(result)):
                    bbox_results, mask_results = result[j]['ins_results']
                    result[j]['ins_results'] = (
                        bbox_results, encode_mask_results(mask_results))

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            detection_masks = []
            for res_i in result:
                if isinstance(res_i, tuple):
                    bbox_result, segm_result = res_i
                    if isinstance(segm_result, tuple):
                        segm_result = segm_result[0]  # ms rcnn
                else:
                    bbox_result, segm_result = res_i, None
                # bbox_result holds one array per class; the class label is
                # that array's position.
                bboxes = np.vstack(bbox_result)
                labels = [
                    np.full(bbox.shape[0], i, dtype=np.int32)
                    for i, bbox in enumerate(bbox_result)
                ]
                labels = np.concatenate(labels)
                # draw segmentation masks
                segms = []
                if segm_result is not None and len(labels) > 0:  # non empty
                    segms = mmcv.concat_list(segm_result)
                    if isinstance(segms[0], torch.Tensor):
                        segms = torch.stack(
                            segms, dim=0).detach().cpu().numpy()
                    else:
                        segms = np.stack(segms, axis=0)

                # A 5th column, when present, is split off as the score.
                scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
                bboxes = bboxes[:, 0:4] if bboxes.shape[1] == 5 else bboxes
                assert bboxes.shape[1] == 4

                detection_boxes.append(bboxes)
                detection_scores.append(scores)
                detection_classes.append(labels)
                detection_masks.append(segms)

            assert len(img_metas) == 1
            outputs = {
                'detection_boxes': detection_boxes,
                'detection_scores': detection_scores,
                'detection_classes': detection_classes,
                'detection_masks': detection_masks,
                'img_metas': img_metas[0]
            }
            return outputs

        self._mark_wrapped(_new_forward_test)
        setattr(cls, 'forward_test', _new_forward_test)


def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under ``cfg['mmlab_modules']``.

    Args:
        cfg: config object supporting ``.get`` (e.g. a dict or mmcv Config).
            Its optional ``mmlab_modules`` entry is a list of module configs
            in the format accepted by ``MMAdapter``.

    Does nothing when ``mmlab_modules`` is missing, ``None`` or empty.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Any non-empty list is adapted, including one with a single entry
    # (e.g. only the mmdet model).
    if not mmlab_modules_cfg:
        return
    adapter = MMAdapter(mmlab_modules_cfg)
    adapter.adapt_mmlab_modules()


def update_rpn_head():
    """Replace mmdet's registered ``RPNHead`` with a subclass supporting norm.

    The ``'RPNHead'`` entry is popped from mmdet's HEADS registry and a
    subclass of ``mmdet.models.RPNHead`` is registered in its place. The base
    is the module attribute, which popping the registry entry does not change.
    The subclass adds ``norm_cfg`` to the stacked convs.
    """
    logging.warning('refactor mmdet.models.RPNHead, add `norm_cfg`')
    from mmdet.models.builder import HEADS
    HEADS._module_dict.pop('RPNHead', None)
    from mmdet.models import RPNHead as _RPNHead

    @HEADS.register_module()
    class RPNHead(_RPNHead):
        """RPN head with norm.

        Args:
            in_channels (int): Number of channels in the input feature map.
            init_cfg (dict or list[dict], optional): Initialization config dict.
            num_convs (int): Number of convolution layers in the head. Default 1.
            norm_cfg (dict, optional): Norm config for each ``ConvModule``
                built when ``num_convs > 1``; unused when ``num_convs == 1``.
                Default None.
        """  # noqa: W605

        def __init__(self,
                     in_channels,
                     init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                     num_convs=1,
                     norm_cfg=None,
                     **kwargs):
            self.num_convs = num_convs
            self.norm_cfg = norm_cfg
            # Forward `num_convs`: mmdet 2.x RPNHead.__init__ assigns
            # `self.num_convs` from its own argument before `_init_layers`
            # runs, so omitting it would reset the value to its default of 1
            # and the normed multi-conv stack below would never be built.
            super(RPNHead, self).__init__(
                in_channels,
                init_cfg=init_cfg,
                num_convs=num_convs,
                **kwargs)

        def _init_layers(self):
            """Initialize layers of the head."""
            if self.num_convs > 1:
                rpn_convs = []
                for i in range(self.num_convs):
                    if i == 0:
                        in_channels = self.in_channels
                    else:
                        in_channels = self.feat_channels
                    # use ``inplace=False`` to avoid error: one of the variables
                    # needed for gradient computation has been modified by an
                    # inplace operation.
                    rpn_convs.append(
                        ConvModule(
                            in_channels,
                            self.feat_channels,
                            3,
                            padding=1,
                            norm_cfg=self.norm_cfg,
                            inplace=False))
                self.rpn_conv = nn.Sequential(*rpn_convs)
            else:
                self.rpn_conv = nn.Conv2d(
                    self.in_channels, self.feat_channels, 3, padding=1)
            self.rpn_cls = nn.Conv2d(
                self.feat_channels,
                self.num_base_priors * self.cls_out_channels, 1)
            self.rpn_reg = nn.Conv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)