# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging

import mmcv
import numpy as np
import torch

from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess

# Maps the `module` field of an `mmlab_modules` entry to the EasyCV registry.
EASYCV_REGISTRY_MAP = {
    'model': MODELS,
    'backbone': BACKBONES,
    'neck': NECKS,
    'head': HEADS
}

MMDET = 'mmdet'
SUPPORT_MMLAB_TYPES = [MMDET]

# Subclass copies of mmlab modules are stored in this module's namespace (at
# module scope `locals()` and `globals()` are the very same dict), so every
# copy is reachable as an attribute of this module under its class name.
_MMLAB_COPIES = globals()

# Marker attribute set on the functions installed by `MMDetWrapper`. It keeps a
# class from being wrapped twice, e.g. when adaptation runs more than once or
# when a subclass inherits an already wrapped method.
_EASYCV_WRAPPED_FLAG = '_easycv_mmlab_wrapped'


class MMAdapter:

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config (list[dict]): mmlab modules to adapt, in the format:
                [
                    dict(type='mmdet', name='MaskRCNN', module='model'),  # means using mmdet MaskRCNN
                    # dict(type='mmdet', name='ResNet', module='backbone'),  # comment out, means use my ResNet
                    dict(name='FPN', module='neck'),  # type is missing, use mmdet default
                ]
                `type` is the mmlab library (optional, defaults to 'mmdet').
                `name` is the name registered in that library's registry, or
                the class itself. `module` is a key of `EASYCV_REGISTRY_MAP`.
        """
        self.default_mmtype = MMDET
        self.mmtype_list = {
            module_cfg.get('type', self.default_mmtype)
            for module_cfg in modules_config
        }
        self.check_env()
        self.fix_conflicts()

        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested mmlab types and make sure mmdet is present.

        If `import mmdet` fails, installation via pip is attempted.

        Raises:
            AssertionError: if an unsupported mmlab type is requested.
            ValueError: if mmdet is missing and the pip install fails.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES

        try:
            import mmdet  # noqa: F401  (import only checks availability)
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as err:
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from err

    def fix_conflicts(self):
        """Drop hook registrations that both mmdet and easycv provide.

        The entries are removed from mmcv's `HOOKS` registry. Presumably this
        lets the later mmdet registration succeed without a duplicate-name
        error; confirm against the mmcv registry behaviour.
        """
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register configured mmlab modules in EasyCV, then mirror EasyCV's registries into mmlab."""
        for module_cfg in self.modules_config:
            # `type` is optional, exactly as in `__init__`.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Patch the resolved mmlab class so its API matches EasyCV's."""
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        # Add all my module to mmlab module registry, if duplicated, replace with my module.
        # To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
        # In order to support our backbone, register our modules into mmdet.
        # If not specified mmdet type, use our modules by default.
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass copy of an mmlab module in the EasyCV registry.

        Args:
            mmtype (str): mmlab library type, e.g. 'mmdet'.
            module_type (str): key of `EASYCV_REGISTRY_MAP`.
            module_name (str | type): registered name or the class itself.
            force (bool): forwarded to `register_module`.
        """
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        # Add mmlab module to my module registry.
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # `module_name` may be a class (see `_get_mm_module_obj`), but `type()`
        # needs a str name.
        copy_name = module_name if isinstance(
            module_name, str) else module_obj.__name__
        # Copy a duplicate to avoid directly modifying the properties of the original object
        _MMLAB_COPIES[copy_name] = type(copy_name, (module_obj, ), dict())
        easycv_registry_type.register_module(
            _MMLAB_COPIES[copy_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve `module_name` to a class.

        Args:
            mmtype (str): mmlab library type, e.g. 'mmdet'.
            module_type (str): 'model', 'backbone', 'neck' or 'head'.
            module_name (str | type): name in the mmlab registry, or a class.

        Returns:
            type: the resolved class.

        Raises:
            ValueError: if the name is not registered, or `module_name` is
                neither a str nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name not in mm_module_dict:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
            return mm_module_dict[module_name]
        if inspect.isclass(module_name):
            return module_name
        raise ValueError(
            'Only support type `str` and `class` object, but get type {}'.
            format(type(module_name)))

    def _get_mmtype_registry_map(self):
        """Return `{mmtype: {module_type: mmlab registry}}`."""
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import NECKS as MMNECKS
        from mmdet.models.builder import HEADS as MMHEADS
        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map


def _split_mmdet_result(res_i):
    """Convert one image's mmdet result into EasyCV detection fields.

    Args:
        res_i: a per-class list of bbox arrays, or a tuple
            `(bbox_result, segm_result)`. `segm_result` may itself be a
            tuple, whose first item is used (MS R-CNN).

    Returns:
        tuple: `(bboxes, scores, labels, segms)`, where:
            - bboxes: shape (N, 4).
            - scores: the fifth bbox column, or None if the boxes have only
              4 columns.
            - labels: int32 class indices, taken from the position in the
              per-class list.
            - segms: stacked masks, or `[]` when there are no masks or no
              detections.
    """
    # NOTE(review): dict results (panoptic branch) also reach this path and
    # `np.vstack` on a dict looks unsupported; confirm intended handling.
    if isinstance(res_i, tuple):
        bbox_result, segm_result = res_i
        if isinstance(segm_result, tuple):
            segm_result = segm_result[0]  # ms rcnn
    else:
        bbox_result, segm_result = res_i, None

    bboxes = np.vstack(bbox_result)
    labels = np.concatenate([
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(bbox_result)
    ])

    # draw segmentation masks
    segms = []
    if segm_result is not None and len(labels) > 0:  # non empty
        segms = mmcv.concat_list(segm_result)
        if isinstance(segms[0], torch.Tensor):
            segms = torch.stack(segms, dim=0).detach().cpu().numpy()
        else:
            segms = np.stack(segms, axis=0)

    has_scores = bboxes.shape[1] == 5
    scores = bboxes[:, 4] if has_scores else None
    bboxes = bboxes[:, 0:4] if has_scores else bboxes
    assert bboxes.shape[1] == 4
    return bboxes, scores, labels, segms


class MMDetWrapper:
    """Patch mmdet model classes so their forward API matches EasyCV's."""

    def wrap_module(self, cls, module_type):
        """Wrap `forward` and `forward_test` of `cls` when it is a model."""
        if module_type == 'model':
            self._wrap_model_forward(cls)
            self._wrap_model_forward_test(cls)

    def _wrap_model_forward(self, cls):
        """Replace `cls.forward` with EasyCV's `forward(img, mode, **kwargs)`.

        `mode == 'train'` maps to mmdet's `return_loss=True`; any other mode
        maps to `return_loss=False`.
        """
        origin_forward = cls.forward
        if getattr(origin_forward, _EASYCV_WRAPPED_FLAG, False):
            return  # already wrapped (here or on a base class)

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)
            return origin_forward(
                self, img, img_metas, return_loss=(mode == 'train'), **kwargs)

        setattr(_new_forward, _EASYCV_WRAPPED_FLAG, True)
        setattr(cls, 'forward', _new_forward)

    def _wrap_model_forward_test(self, cls):
        """Replace `cls.forward_test` so it returns EasyCV's output dict.

        The wrapped method forces `rescale=True`. It encodes mask results with
        mmdet's `encode_mask_results` and returns a dict with the keys
        'detection_boxes', 'detection_scores', 'detection_classes' and
        'detection_masks' (one list entry per result), plus 'img_metas'
        (`img_metas[0]`). Exactly one entry in `img_metas` is asserted.
        """
        origin_forward_test = cls.forward_test
        if getattr(origin_forward_test, _EASYCV_WRAPPED_FLAG, False):
            return  # already wrapped (here or on a base class)

        from mmdet.core import encode_mask_results

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs.update({'rescale': True})  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')
            result = origin_forward_test(self, img, img_metas, **kwargs)

            # ============result process to adapt to easycv============
            # encode mask results
            if isinstance(result[0], tuple):
                result = [(bbox_results, encode_mask_results(mask_results))
                          for bbox_results, mask_results in result]
            # This logic is only used in panoptic segmentation test.
            elif isinstance(result[0], dict) and 'ins_results' in result[0]:
                for res in result:
                    bbox_results, mask_results = res['ins_results']
                    res['ins_results'] = (bbox_results,
                                          encode_mask_results(mask_results))

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            detection_masks = []
            for res_i in result:
                bboxes, scores, labels, segms = _split_mmdet_result(res_i)
                detection_boxes.append(bboxes)
                detection_scores.append(scores)
                detection_classes.append(labels)
                detection_masks.append(segms)

            assert len(img_metas) == 1
            outputs = {
                'detection_boxes': detection_boxes,
                'detection_scores': detection_scores,
                'detection_classes': detection_classes,
                'detection_masks': detection_masks,
                'img_metas': img_metas[0]
            }
            return outputs

        setattr(_new_forward_test, _EASYCV_WRAPPED_FLAG, True)
        setattr(cls, 'forward_test', _new_forward_test)


def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under `cfg['mmlab_modules']`, if any.

    Args:
        cfg: config object supporting `.get`. Its optional `mmlab_modules`
            entry is a list of dicts in the format accepted by `MMAdapter`.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Any non-empty list must be adapted, including a single entry such as
    # `[dict(type='mmdet', name='MaskRCNN', module='model')]`.
    if mmlab_modules_cfg:
        adapter = MMAdapter(mmlab_modules_cfg)
        adapter.adapt_mmlab_modules()