EasyCV/easycv/utils/mmlab_utils.py

245 lines
9.4 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging
import mmcv
import numpy as np
import torch
from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess
# Module kind -> EasyCV registry that stores modules of that kind.
EASYCV_REGISTRY_MAP = dict(
    model=MODELS,
    backbone=BACKBONES,
    neck=NECKS,
    head=HEADS,
)

# Name of the mmlab toolbox handled by this file.
MMDET = 'mmdet'

# mmlab toolbox types that MMAdapter.check_env accepts.
SUPPORT_MMLAB_TYPES = [MMDET]

# Evaluated at module level, so this is the module namespace itself; the
# subclass copies of mmlab modules created by MMAdapter are stored in it.
_MMLAB_COPIES = locals()
class MMAdapter:
    """Make selected mmlab (currently only mmdet) modules usable from EasyCV.

    Selected mmlab modules are copied into the EasyCV registries and wrapped,
    and all EasyCV modules are merged back into the mmlab registries.
    """

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config: list of dicts, one per mmlab module to adapt, in
                the following format:
                [
                    dict(type='mmdet', name='MaskRCNN', module='model'),  # means using mmdet MaskRCNN
                    # dict(type='mmdet', name='ResNet', module='backbone'),  # comment out, means use my ResNet
                    dict(name='FPN', module='neck'),  # type is missing, use mmdet default
                ]
        """
        self.default_mmtype = 'mmdet'
        # Every mmlab type referenced by the config; entries without a
        # 'type' key count as the default type.
        self.mmtype_list = {
            module_cfg.get('type', self.default_mmtype)
            for module_cfg in modules_config
        }
        self.check_env()
        self.fix_conflicts()
        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Check the requested mmlab types and make sure mmdet is importable.

        If importing mmdet fails, a `pip install mmdet` is attempted through
        `run_in_subprocess`.

        Raises:
            AssertionError: if a requested type is not in SUPPORT_MMLAB_TYPES.
            ValueError: if the installation attempt fails.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES

        try:
            import mmdet  # noqa: F401  (imported only to probe availability)
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # Catch Exception rather than everything so that
            # KeyboardInterrupt/SystemExit still propagate, and chain the
            # cause so the real installation error stays visible.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

    def fix_conflicts(self):
        """Remove hook names from mmcv's HOOKS registry that would conflict."""
        # mmdet and easycv both register
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register and wrap every configured mmlab module, then merge the
        EasyCV registries into the registries of each mmlab type in use."""
        for module_cfg in self.modules_config:
            # 'type' is optional; fall back to the default exactly as
            # __init__ does, instead of raising KeyError.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Apply the wrapper matching ``mmtype`` to the resolved module."""
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Copy every EasyCV registry entry into the mmlab registries."""
        # Add all my module to mmlab module registry, if duplicated, replace with my module.
        # To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
        # In order to support our backbone, register our modules into mmdet.
        # If not specified mmdet type, use our modules by default.
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass copy of an mmlab module in the EasyCV registry.

        Args:
            mmtype: mmlab type key, e.g. 'mmdet'.
            module_type: key of EASYCV_REGISTRY_MAP ('model', 'backbone', ...).
            module_name: registered name of the mmlab module, or the class.
            force: passed through to ``register_module``.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        # Add mmlab module to my module registry.
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # `module_name` may be a class (see _get_mm_module_obj); type() needs
        # a str name, so use the class's own name in that case.
        copy_name = module_name if isinstance(
            module_name, str) else module_name.__name__
        # Copy a duplicate to avoid directly modifying the properties of the original object
        _MMLAB_COPIES[copy_name] = type(copy_name, (model_obj, ), dict())
        easycv_registry_type.register_module(
            _MMLAB_COPIES[copy_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` to a class.

        A str is looked up in the mmlab registry selected by ``mmtype`` and
        ``module_type``; a class is returned unchanged.

        Raises:
            ValueError: if the name is not registered, or ``module_name`` is
                neither a str nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name not in mm_module_dict:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
            return mm_module_dict[module_name]

        if inspect.isclass(module_name):
            return module_name

        raise ValueError(
            'Only support type `str` and `class` object, but get type {}'.
            format(type(module_name)))

    def _get_mmtype_registry_map(self):
        """Return {mmtype: {module kind: mmlab registry}} for supported types."""
        # Imported here so mmdet is only required after check_env has run.
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import NECKS as MMNECKS
        from mmdet.models.builder import HEADS as MMHEADS

        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map
class MMDetWrapper:
    """Patch mmdet model classes so they can be called the EasyCV way."""

    def wrap_module(self, cls, module_type):
        """Patch ``cls`` in place; only 'model' modules are patched."""
        if module_type != 'model':
            return
        self._wrap_model_forward(cls)
        self._wrap_model_forward_test(cls)

    def _wrap_model_forward(self, cls):
        """Replace ``cls.forward`` with a version driven by a ``mode`` flag."""
        origin_forward = cls.forward

        def _new_forward(self, img, mode='train', **kwargs):
            # img_metas arrives as a keyword and is handed to the original
            # forward positionally; return_loss is derived from `mode`.
            metas = kwargs.pop('img_metas', None)
            return origin_forward(
                self, img, metas, return_loss=(mode == 'train'), **kwargs)

        cls.forward = _new_forward

    def _wrap_model_forward_test(self, cls):
        """Replace ``cls.forward_test`` with a version that returns a dict of
        per-image boxes, scores, classes and masks."""
        from mmdet.core import encode_mask_results

        origin_forward_test = cls.forward_test

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs['rescale'] = True  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')
            result = origin_forward_test(self, img, img_metas, **kwargs)

            # ============result process to adapt to easycv============
            first = result[0]
            if isinstance(first, tuple):
                # encode mask results
                encoded = []
                for boxes_part, masks_part in result:
                    encoded.append(
                        (boxes_part, encode_mask_results(masks_part)))
                result = encoded
            elif isinstance(first, dict) and 'ins_results' in first:
                # This logic is only used in panoptic segmentation test.
                for entry in result:
                    boxes_part, masks_part = entry['ins_results']
                    entry['ins_results'] = (boxes_part,
                                            encode_mask_results(masks_part))

            all_boxes, all_scores, all_classes, all_masks = [], [], [], []
            for item in result:
                if isinstance(item, tuple):
                    per_class_boxes, per_class_segms = item
                    if isinstance(per_class_segms, tuple):
                        per_class_segms = per_class_segms[0]  # ms rcnn
                else:
                    per_class_boxes = item
                    per_class_segms = None

                stacked = np.vstack(per_class_boxes)
                # One label per box: the index of the per-class array that
                # the box came from.
                class_ids = np.concatenate([
                    np.full(arr.shape[0], cls_idx, dtype=np.int32)
                    for cls_idx, arr in enumerate(per_class_boxes)
                ])

                masks = []
                if per_class_segms is not None and len(class_ids) > 0:
                    # non empty
                    masks = mmcv.concat_list(per_class_segms)
                    if isinstance(masks[0], torch.Tensor):
                        masks = torch.stack(masks, dim=0)
                        masks = masks.detach().cpu().numpy()
                    else:
                        masks = np.stack(masks, axis=0)

                # A fifth column, when present, carries the score.
                if stacked.shape[1] == 5:
                    conf_scores = stacked[:, 4]
                    stacked = stacked[:, 0:4]
                else:
                    conf_scores = None
                assert stacked.shape[1] == 4

                all_boxes.append(stacked)
                all_scores.append(conf_scores)
                all_classes.append(class_ids)
                all_masks.append(masks)

            assert len(img_metas) == 1
            return {
                'detection_boxes': all_boxes,
                'detection_scores': all_scores,
                'detection_classes': all_classes,
                'detection_masks': all_masks,
                'img_metas': img_metas[0]
            }

        cls.forward_test = _new_forward_test
def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under ``cfg['mmlab_modules']``.

    Does nothing when the key is missing or its value is empty.

    Args:
        cfg: config object exposing ``get``; ``mmlab_modules`` is the list of
            module configs passed to ``MMAdapter``.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Adapt whenever at least one module is configured. The previous
    # `len(...) > 1` check silently ignored single-module configs.
    if not mmlab_modules_cfg:
        return
    adapter = MMAdapter(mmlab_modules_cfg)
    adapter.adapt_mmlab_modules()