EasyCV/easycv/utils/mmlab_utils.py

323 lines
12 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess
# EasyCV registries, keyed by the ``module`` field of an ``mmlab_modules``
# config entry.
EASYCV_REGISTRY_MAP = dict(
    model=MODELS,
    backbone=BACKBONES,
    neck=NECKS,
    head=HEADS,
)

# mmlab toolboxes that can currently be adapted.
MMDET = 'mmdet'
SUPPORT_MMLAB_TYPES = [MMDET]

# At module scope ``locals()`` is the module namespace itself, so the class
# copies that ``MMAdapter`` stores here become attributes of this module.
_MMLAB_COPIES = locals()
class MMAdapter:
    """Bridge selected mmlab (currently only mmdet) modules into EasyCV."""

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config (list[dict]): mmlab modules to adapt, each entry
                being ``dict(type=..., name=..., module=...)``, e.g.::

                    [
                        # use mmdet's MaskRCNN
                        dict(type='mmdet', name='MaskRCNN', module='model'),
                        # commented out: keep using EasyCV's own ResNet
                        # dict(type='mmdet', name='ResNet', module='backbone'),
                        # `type` omitted: falls back to the default 'mmdet'
                        dict(name='FPN', module='neck'),
                    ]

                ``module`` is one of ``'model'``, ``'backbone'``, ``'neck'``
                or ``'head'``. ``name`` is either the name of a class
                registered in the mmlab registry or the class object itself.
        """
        self.default_mmtype = 'mmdet'
        # mmlab toolboxes referenced by the config (a set, despite the name).
        self.mmtype_list = set()
        for module_cfg in modules_config:
            self.mmtype_list.add(
                module_cfg.get('type', self.default_mmtype))  # default mmdet
        self.check_env()
        self.fix_conflicts()
        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested mmlab types and make sure mmdet is importable.

        If ``import mmdet`` fails, an installation via ``pip install mmdet``
        is attempted.

        Raises:
            AssertionError: if a type outside ``SUPPORT_MMLAB_TYPES`` is used.
            ValueError: if mmdet is missing and installing it fails.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES

        try:
            import mmdet  # noqa: F401
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # Catch regular errors only, so Ctrl-C / SystemExit still
            # propagate, and keep the underlying failure as the cause.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

    def fix_conflicts(self):
        """Remove mmcv hook registrations that clash between mmdet and EasyCV.

        Both libraries register ``YOLOXLrUpdaterHook`` in mmcv's ``HOOKS``
        registry. The existing entry is dropped, presumably so that the
        other registration does not fail as a duplicate.
        """
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register the configured mmlab modules in EasyCV and vice versa.

        Each configured module is copied into the EasyCV registry and
        patched for EasyCV's calling conventions. Afterwards every EasyCV
        module is registered into the mmlab registries of the used types.
        """
        for module_cfg in self.modules_config:
            # Same fallback as in __init__: an entry without `type` means
            # mmdet (indexing with ['type'] raised KeyError here).
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)
        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Patch the mmlab class so it follows EasyCV's calling conventions.

        NOTE: this patches the class obtained from the mmlab registry. The
        EasyCV copy created in ``_merge_mmlab_module_to_easycv`` is a
        subclass of it and therefore inherits the patched methods.
        """
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Register every EasyCV module into the mmlab registries of ``mmtype``.

        EasyCV modules replace mmlab modules of the same name. This lets an
        mmlab model (e.g. MaskRCNN) build its sub-modules from EasyCV
        implementations whenever the config does not ask for the mmlab one.
        """
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass copy of an mmlab module in the EasyCV registry.

        Args:
            mmtype (str): mmlab toolbox, e.g. ``'mmdet'``.
            module_type (str): one of the ``EASYCV_REGISTRY_MAP`` keys.
            module_name (str | type): registered mmlab name or the class.
            force (bool): overwrite an existing EasyCV entry of the same name.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # `type()` needs a str name, but `module_name` may also be a class.
        copy_name = module_name if isinstance(module_name,
                                              str) else model_obj.__name__
        # Copy a duplicate to avoid directly modifying the properties of the
        # original object.
        _MMLAB_COPIES[copy_name] = type(copy_name, (model_obj, ), dict())
        easycv_registry_type.register_module(
            _MMLAB_COPIES[copy_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` to a class.

        A string is looked up in the mmlab registry of
        ``mmtype``/``module_type``. A class is returned unchanged.

        Raises:
            ValueError: if the name is not registered, or ``module_name`` is
                neither a str nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name not in mm_module_dict:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
            return mm_module_dict[module_name]
        if inspect.isclass(module_name):
            return module_name
        raise ValueError(
            'Only support type `str` and `class` object, but get type {}'.
            format(type(module_name)))

    def _get_mmtype_registry_map(self):
        """Return ``{mmtype: {module kind: mmlab registry}}``."""
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import NECKS as MMNECKS
        from mmdet.models.builder import HEADS as MMHEADS
        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map
class MMDetWrapper:
    """Monkey-patch mmdet classes so they follow EasyCV's calling conventions.

    Constructing an instance re-registers mmdet's ``RPNHead`` (via
    ``refactor_modules``). ``wrap_module`` then patches a given class in place.
    """

    def __init__(self):
        self.refactor_modules()

    def wrap_module(self, cls, module_type):
        """Patch ``cls`` in place, at most once.

        For ``module_type == 'model'`` the class's ``__init__``, ``forward``
        and ``forward_test`` are replaced. The class is then marked with
        ``is_wrap = True`` so later calls return early.
        """
        # `is_wrap` is also visible on subclasses of an already patched class.
        if hasattr(cls, 'is_wrap') and cls.is_wrap:
            return
        if module_type == 'model':
            self._wrap_model_init(cls)
            self._wrap_model_forward(cls)
            self._wrap_model_forward_test(cls)
        cls.is_wrap = True

    def refactor_modules(self):
        # Replace mmdet's registered RPNHead with a variant accepting `norm_cfg`.
        update_rpn_head()

    def _wrap_model_init(self, cls):
        """Make ``cls.__init__`` call ``self.init_weights()`` after construction."""
        origin_init = cls.__init__

        def _new_init(self, *args, **kwargs):
            origin_init(self, *args, **kwargs)
            # NOTE(review): presumably because nothing else in EasyCV calls
            # init_weights on mmdet models — confirm.
            self.init_weights()

        setattr(cls, '__init__', _new_init)

    def _wrap_model_forward(self, cls):
        """Adapt ``cls.forward`` to EasyCV's ``forward(img, mode=..., **kwargs)``.

        ``img_metas`` is popped from kwargs and passed positionally.
        ``mode == 'train'`` maps to ``return_loss=True``. Any other mode maps
        to ``return_loss=False``.
        """
        origin_forward = cls.forward

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)
            if mode == 'train':
                return origin_forward(
                    self, img, img_metas, return_loss=True, **kwargs)
            else:
                return origin_forward(
                    self, img, img_metas, return_loss=False, **kwargs)

        setattr(cls, 'forward', _new_forward)

    def _wrap_model_forward_test(self, cls):
        """Wrap ``cls.forward_test`` to return EasyCV-style detection outputs.

        The wrapped method always passes ``rescale=True``, RLE-encodes mask
        results with mmdet's ``encode_mask_results`` and returns a dict of
        per-image lists:

        - ``detection_boxes``: box arrays with 4 columns.
        - ``detection_scores``: the 5th box column, or None when absent.
        - ``detection_classes``: class index per box.
        - ``detection_masks``: stacked masks, or ``[]`` when there are none.
        - ``img_metas``: the first entry of ``img_metas``, which must contain
          exactly one element.
        """
        from mmdet.core import encode_mask_results
        origin_forward_test = cls.forward_test

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            # Moved here from mmdet's single_gpu_test; overrides any caller value.
            kwargs.update({'rescale': True})  # move from single_gpu_test
            # NOTE(review): logged on every call.
            logging.info('Set rescale to True for `model.forward_test`!')
            result = origin_forward_test(self, img, img_metas, **kwargs)

            # ============result process to adapt to easycv============
            # encode mask results
            # (tuple entries are (bbox_results, mask_results) pairs).
            if isinstance(result[0], tuple):
                result = [(bbox_results, encode_mask_results(mask_results))
                          for bbox_results, mask_results in result]
            # This logic is only used in panoptic segmentation test.
            # NOTE(review): dict entries are not tuples, so the loop below
            # would call np.vstack on the dict itself — looks unsupported.
            elif isinstance(result[0], dict) and 'ins_results' in result[0]:
                for j in range(len(result)):
                    bbox_results, mask_results = result[j]['ins_results']
                    result[j]['ins_results'] = (
                        bbox_results, encode_mask_results(mask_results))

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            detection_masks = []
            for res_i in result:
                if isinstance(res_i, tuple):
                    bbox_result, segm_result = res_i
                    if isinstance(segm_result, tuple):
                        segm_result = segm_result[0]  # ms rcnn
                else:
                    bbox_result, segm_result = res_i, None
                # bbox_result is a sequence of per-class box arrays; the class
                # label is the position in that sequence.
                bboxes = np.vstack(bbox_result)
                labels = [
                    np.full(bbox.shape[0], i, dtype=np.int32)
                    for i, bbox in enumerate(bbox_result)
                ]
                labels = np.concatenate(labels)
                # draw segmentation masks
                # (inherited comment: masks are only collected and stacked
                # here, nothing is drawn).
                segms = []
                if segm_result is not None and len(labels) > 0:  # non empty
                    segms = mmcv.concat_list(segm_result)
                    if isinstance(segms[0], torch.Tensor):
                        segms = torch.stack(
                            segms, dim=0).detach().cpu().numpy()
                    else:
                        segms = np.stack(segms, axis=0)
                # A 5th column, when present, holds the scores.
                scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
                bboxes = bboxes[:, 0:4] if bboxes.shape[1] == 5 else bboxes
                assert bboxes.shape[1] == 4

                detection_boxes.append(bboxes)
                detection_scores.append(scores)
                detection_classes.append(labels)
                detection_masks.append(segms)

            # Only single-image batches are supported by this output format.
            assert len(img_metas) == 1
            outputs = {
                'detection_boxes': detection_boxes,
                'detection_scores': detection_scores,
                'detection_classes': detection_classes,
                'detection_masks': detection_masks,
                'img_metas': img_metas[0]
            }
            return outputs

        setattr(cls, 'forward_test', _new_forward_test)
def update_rpn_head():
    """Re-register mmdet's ``RPNHead`` as a variant that accepts ``norm_cfg``.

    The current ``'RPNHead'`` entry is removed from mmdet's ``HEADS``
    registry. It is replaced by a subclass of ``mmdet.models.RPNHead`` whose
    stacked convolutions (``num_convs > 1``) are ``ConvModule`` layers built
    with ``norm_cfg``. Calling this again simply repeats the replacement.
    """
    logging.warning('refactor mmdet.models.RPNHead, add `norm_cfg`')
    from mmdet.models.builder import HEADS
    HEADS._module_dict.pop('RPNHead', None)
    from mmdet.models import RPNHead as _RPNHead

    @HEADS.register_module()
    class RPNHead(_RPNHead):
        """RPN head with norm.

        Args:
            in_channels (int): Number of channels in the input feature map.
            init_cfg (dict or list[dict], optional): Initialization config dict.
            num_convs (int): Number of convolution layers in the head. Default 1.
            norm_cfg (dict, optional): Norm layer config for every
                ``ConvModule`` built when ``num_convs > 1``. It is not used
                when ``num_convs == 1``, where a plain ``nn.Conv2d`` is kept.
                Default None.
        """  # noqa: W605

        def __init__(self,
                     in_channels,
                     init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                     num_convs=1,
                     norm_cfg=None,
                     **kwargs):
            # Set before the parent constructor, which presumably builds the
            # layers via `_init_layers` (that method reads `self.norm_cfg`).
            self.norm_cfg = norm_cfg
            super(RPNHead, self).__init__(
                in_channels, init_cfg=init_cfg, num_convs=num_convs, **kwargs)

        def _init_layers(self):
            """Initialize layers of the head."""
            if self.num_convs > 1:
                rpn_convs = []
                for i in range(self.num_convs):
                    if i == 0:
                        in_channels = self.in_channels
                    else:
                        in_channels = self.feat_channels
                    # use ``inplace=False`` to avoid error: one of the variables
                    # needed for gradient computation has been modified by an
                    # inplace operation.
                    rpn_convs.append(
                        ConvModule(
                            in_channels,
                            self.feat_channels,
                            3,
                            padding=1,
                            norm_cfg=self.norm_cfg,
                            inplace=False))
                self.rpn_conv = nn.Sequential(*rpn_convs)
            else:
                self.rpn_conv = nn.Conv2d(
                    self.in_channels, self.feat_channels, 3, padding=1)
            self.rpn_cls = nn.Conv2d(
                self.feat_channels,
                self.num_base_priors * self.cls_out_channels, 1)
            self.rpn_reg = nn.Conv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)
def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under ``cfg['mmlab_modules']``.

    Args:
        cfg: mapping-like config supporting ``.get``. Its optional
            ``mmlab_modules`` entry is a list of module dicts in the format
            accepted by ``MMAdapter``.

    Does nothing when the entry is missing, ``None`` or empty. Any non-empty
    list, including a single module, is adapted.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', None) or []
    if len(mmlab_modules_cfg) > 0:
        adapter = MMAdapter(mmlab_modules_cfg)
        adapter.adapt_mmlab_modules()