# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess

try:
    from mmcv.runner.hooks import HOOKS
    HOOKS._module_dict.pop('YOLOXLrUpdaterHook', None)
    from mmdet.models.builder import MODELS as MMMODELS
    from mmdet.models.builder import BACKBONES as MMBACKBONES
    from mmdet.models.builder import NECKS as MMNECKS
    from mmdet.models.builder import HEADS as MMHEADS
    from mmdet.core import BitmapMasks, PolygonMasks, encode_mask_results
    from mmdet.core.mask import mask2bbox
except ImportError:
    pass

# Maps the module kind named in a module config (its `module` field) to the
# easycv registry that holds modules of that kind.
EASYCV_REGISTRY_MAP = {
    'model': MODELS,
    'backbone': BACKBONES,
    'neck': NECKS,
    'head': HEADS
}
# Value of the `type` field that selects the mmdet toolbox.
MMDET = 'mmdet'
# mmlab toolboxes that `MMAdapter.check_env` accepts.
SUPPORT_MMLAB_TYPES = [MMDET]
# At module level `locals()` is the module namespace dict, so every class
# stored in `_MMLAB_COPIES[name]` also becomes a module-level attribute `name`.
_MMLAB_COPIES = locals()


class MMAdapter:
    """Make selected mmlab (currently only mmdet) modules usable from easycv.

    For every configured module the adapter registers a subclass copy of the
    mmlab class in the matching easycv registry, wraps it through
    ``MMDetWrapper`` and finally pushes all easycv modules back into the
    mmlab registries.
    """

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config (list[dict]): modules to take from mmlab, in the
                following format:
                [
                    dict(type='mmdet', name='MaskRCNN', module='model'),  # means using mmdet MaskRCNN
                    # dict(type='mmdet', name='ResNet', module='backbone'),  # comment out, means use my ResNet
                    dict(name='FPN', module='neck'),  # type is missing, use mmdet default
                ]
        """
        self.default_mmtype = 'mmdet'
        # Distinct mmlab toolboxes referenced by the config; a missing
        # ``type`` falls back to the default toolbox.
        self.mmtype_list = {
            module_cfg.get('type', self.default_mmtype)
            for module_cfg in modules_config
        }

        self.check_env()
        self.fix_conflicts()

        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested toolboxes and make sure mmdet is importable.

        If ``mmdet`` cannot be imported, an installation through pip is
        attempted.

        Raises:
            AssertionError: if an unsupported mmlab type is requested.
            ValueError: if installing mmdet fails.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES
        try:
            import mmdet  # noqa: F401
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # Only real errors are converted; KeyboardInterrupt / SystemExit
            # propagate unchanged instead of being reported as a ValueError.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

    def fix_conflicts(self):
        """Drop mmcv hook registrations that both mmdet and easycv register."""
        # mmdet and easycv both register
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register and wrap every configured module, then sync registries."""
        for module_cfg in self.modules_config:
            # ``type`` is optional (see ``__init__``): fall back to the default
            # toolbox instead of raising KeyError.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Wrap the mmlab class behind ``module_name`` for use in easycv."""
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Copy every easycv registry entry into the matching mmlab registry."""
        # Add all my module to mmlab module registry, if duplicated, replace with my module.
        # To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
        # In order to support our backbone, register our modules into mmdet.
        # If not specified mmdet type, use our modules by default.
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass copy of an mmlab module in the easycv registry.

        Args:
            mmtype (str): mmlab toolbox name, e.g. ``'mmdet'``.
            module_type (str): key of ``EASYCV_REGISTRY_MAP``.
            module_name (str | type): registered name of the mmlab module, or
                the class itself.
            force (bool): passed to ``register_module``.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        # Add mmlab module to my module registry.
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # ``type()`` needs a string name; when a class was passed in, reuse
        # that class's own name for the copy.
        if isinstance(module_name, str):
            copy_name = module_name
        else:
            copy_name = model_obj.__name__
        # Copy a duplicate to avoid directly modifying the properties of the original object
        _MMLAB_COPIES[copy_name] = type(copy_name, (model_obj, ), dict())
        easycv_registry_type.register_module(
            _MMLAB_COPIES[copy_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` to a class.

        Args:
            mmtype (str): mmlab toolbox name.
            module_type (str): registry kind inside that toolbox.
            module_name (str | type): a registered name to look up, or a class
                that is returned unchanged.

        Returns:
            type: the resolved class.

        Raises:
            ValueError: if the name is not registered, or ``module_name`` is
                neither a string nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name not in mm_module_dict:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
            return mm_module_dict[module_name]
        if inspect.isclass(module_name):
            return module_name
        raise ValueError(
            'Only support type `str` and `class` object, but get type {}'.
            format(type(module_name)))

    def _get_mmtype_registry_map(self):
        """Return ``{mmtype: {module kind: mmlab registry}}``.

        The registries are imported here rather than taken from the
        module-level guarded import, so they are available even when mmdet
        only became importable after ``check_env`` installed it.
        """
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import HEADS as MMHEADS
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import NECKS as MMNECKS

        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map


class MMDetWrapper:
    """Patch mmdet model classes in place so they follow easycv's call style."""

    def __init__(self):
        self.refactor_modules()

    def wrap_module(self, cls, module_type):
        """Patch ``cls`` once; only classes of module type ``'model'`` are touched.

        A class whose ``is_wrap`` attribute is truthy is left alone.
        """
        if getattr(cls, 'is_wrap', False):
            return
        if module_type != 'model':
            return
        self._wrap_model_init(cls)
        self._wrap_model_forward(cls)
        self._wrap_model_forward_test(cls)
        cls.is_wrap = True

    def refactor_modules(self):
        """Apply module replacements that have to exist before wrapping."""
        update_rpn_head()

    def _wrap_model_init(self, cls):
        """Make ``cls.__init__`` call ``init_weights()`` after construction."""
        origin_init = cls.__init__

        def _new_init(self, *args, **kwargs):
            origin_init(self, *args, **kwargs)
            self.init_weights()

        cls.__init__ = _new_init

    def _wrap_model_forward(self, cls):
        """Replace ``cls.forward`` with a ``forward(img, mode=..., **kwargs)`` entry.

        ``img_metas`` is taken out of the keyword arguments and passed
        positionally; ``mode == 'train'`` is translated to ``return_loss``.
        """
        origin_forward = cls.forward

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)
            return origin_forward(
                self,
                img,
                img_metas,
                return_loss=bool(mode == 'train'),
                **kwargs)

        cls.forward = _new_forward

    def _wrap_model_forward_test(self, cls):
        """Replace ``cls.forward_test`` so that it returns an easycv-style dict.

        The wrapped method forces ``rescale=True``, encodes mask results and
        returns per-image lists under the keys ``detection_boxes``,
        ``detection_scores``, ``detection_classes`` and ``detection_masks``,
        plus ``img_metas`` (the first and only entry of the given metas).
        """
        origin_forward_test = cls.forward_test

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs['rescale'] = True  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')

            result = origin_forward_test(self, img, img_metas, **kwargs)
            # ============result process to adapt to easycv============
            head = result[0]
            if isinstance(head, tuple):
                # (bbox, mask) pairs: encode the mask part.
                result = [(boxes_part, encode_mask_results(masks_part))
                          for boxes_part, masks_part in result]
            elif isinstance(head, dict) and 'ins_results' in head:
                # This logic is only used in panoptic segmentation test.
                for entry in result:
                    boxes_part, masks_part = entry['ins_results']
                    entry['ins_results'] = (boxes_part,
                                            encode_mask_results(masks_part))

            outputs = {
                'detection_boxes': [],
                'detection_scores': [],
                'detection_classes': [],
                'detection_masks': [],
            }
            for single in result:
                bbox_result, segm_result = single, None
                if isinstance(single, tuple):
                    bbox_result, segm_result = single
                    if isinstance(segm_result, tuple):
                        segm_result = segm_result[0]  # ms rcnn

                bboxes = np.vstack(bbox_result)
                # One label per box: the index of the per-class array it came from.
                labels = np.concatenate([
                    np.full(per_cls.shape[0], cls_idx, dtype=np.int32)
                    for cls_idx, per_cls in enumerate(bbox_result)
                ])

                segms = []
                if segm_result is not None and len(labels) > 0:  # non empty
                    segms = mmcv.concat_list(segm_result)
                    if isinstance(segms[0], torch.Tensor):
                        stacked = torch.stack(segms, dim=0)
                        segms = stacked.detach().cpu().numpy()
                    else:
                        segms = np.stack(segms, axis=0)

                # A fifth column, when present, holds the scores.
                if bboxes.shape[1] == 5:
                    scores = bboxes[:, 4]
                    bboxes = bboxes[:, :4]
                else:
                    scores = None
                assert bboxes.shape[1] == 4

                outputs['detection_boxes'].append(bboxes)
                outputs['detection_scores'].append(scores)
                outputs['detection_classes'].append(labels)
                outputs['detection_masks'].append(segms)

            assert len(img_metas) == 1
            outputs['img_metas'] = img_metas[0]
            return outputs

        cls.forward_test = _new_forward_test


def update_rpn_head():
    """Swap mmdet's registered ``RPNHead`` for a variant that accepts ``norm_cfg``.

    The existing ``'RPNHead'`` entry is removed from mmdet's ``HEADS``
    registry and a subclass defined here is registered under the same name.
    """
    logging.warning('refactor mmdet.models.RPNHead, add `norm_cfg`')
    from mmdet.models.builder import HEADS
    HEADS._module_dict.pop('RPNHead', None)
    from mmdet.models import RPNHead as MMDetRPNHead

    @HEADS.register_module()
    class RPNHead(MMDetRPNHead):
        """RPN head with norm.

        Args:
            in_channels (int): Number of channels in the input feature map.
            init_cfg (dict or list[dict], optional): Initialization config dict.
            num_convs (int): Number of convolution layers in the head. Default 1.
            norm_cfg (dict, optional): Passed as ``norm_cfg`` to every
                ``ConvModule`` built when ``num_convs > 1``. Default None.
        """

        def __init__(self,
                     in_channels,
                     init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                     num_convs=1,
                     norm_cfg=None,
                     **kwargs):
            # Stored before the parent constructor runs so that it is
            # available to ``_init_layers``.
            self.norm_cfg = norm_cfg
            super().__init__(
                in_channels, init_cfg=init_cfg, num_convs=num_convs, **kwargs)

        def _init_layers(self):
            """Initialize layers of the head."""
            if self.num_convs > 1:
                # use ``inplace=False`` to avoid error: one of the variables
                # needed for gradient computation has been modified by an
                # inplace operation.
                stacked_convs = [
                    ConvModule(
                        self.in_channels if idx == 0 else self.feat_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False) for idx in range(self.num_convs)
                ]
                self.rpn_conv = nn.Sequential(*stacked_convs)
            else:
                self.rpn_conv = nn.Conv2d(
                    self.in_channels, self.feat_channels, 3, padding=1)

            cls_channels = self.num_base_priors * self.cls_out_channels
            self.rpn_cls = nn.Conv2d(self.feat_channels, cls_channels, 1)
            reg_channels = self.num_base_priors * 4
            self.rpn_reg = nn.Conv2d(self.feat_channels, reg_channels, 1)

        def forward_single(self, x):
            """Forward feature map of a single scale level."""
            # inplace=False to fix gradient computation has been modified by F.relu() when run with PyTorch 1.10.
            feat = F.relu(self.rpn_conv(x), inplace=False)
            rpn_cls_score = self.rpn_cls(feat)
            rpn_bbox_pred = self.rpn_reg(feat)
            return rpn_cls_score, rpn_bbox_pred


def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under ``cfg['mmlab_modules']``, if any.

    Args:
        cfg: config object exposing ``get``; its ``mmlab_modules`` entry is
            the list of module configs passed to ``MMAdapter``.

    Nothing happens when the entry is missing, ``None`` or empty. A config
    with a single module is adapted as well.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Truthiness check instead of ``len(...) > 1``: one configured module must
    # be adapted too, and a ``None`` value is treated like an empty list.
    if not mmlab_modules_cfg:
        return
    adapter = MMAdapter(mmlab_modules_cfg)
    adapter.adapt_mmlab_modules()