mirror of https://github.com/alibaba/EasyCV.git
376 lines
14 KiB
Python
376 lines
14 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
# flake8: noqa
|
|
import copy
|
|
import inspect
|
|
import logging
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import ConvModule
|
|
|
|
from easycv.framework.errors import TypeError, ValueError
|
|
from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
|
|
from .test_util import run_in_subprocess
|
|
|
|
# Identifier of the (currently only) OpenMMLab toolbox this module adapts.
MMDET = 'mmdet'

# mmdet is treated as an optional dependency: if anything in this group
# fails to import, the names below are simply left undefined.
try:
    from mmcv.runner.hooks import HOOKS
    import mmdet
    # Drop the 'YOLOXLrUpdaterHook' entry from mmcv's hook registry
    # (no-op when the entry is absent).
    HOOKS._module_dict.pop('YOLOXLrUpdaterHook', None)
    from mmdet.core import BitmapMasks, PolygonMasks, encode_mask_results
    from mmdet.core.mask import mask2bbox
except ImportError:
    pass

# Filled lazily by ``MMAdapter.reset_mm_registry``:
#   MM_REGISTRY          -> the live mmlab registries, keyed by mmlab type
#   MM_ORIGINAL_REGISTRY -> deep-copied snapshot used to restore them
MM_REGISTRY = None
MM_ORIGINAL_REGISTRY = None

# EasyCV registry for each supported module kind.
EASYCV_REGISTRY_MAP = dict(
    model=MODELS,
    backbone=BACKBONES,
    neck=NECKS,
    head=HEADS,
)

SUPPORT_MMLAB_TYPES = [MMDET]
# At module level ``locals()`` is the module namespace itself; the adapter
# stores its generated classes there under '<mmtype>.<module>.<name>' keys.
_MMLAB_COPIES = locals()
|
|
|
|
|
|
class MMAdapter:
    """Make selected mmlab (currently only mmdet) modules usable from EasyCV.

    For every configured module a subclass of the mmlab class is registered
    in the matching EasyCV registry, and all EasyCV modules are merged back
    into the mmlab registries so both sides can build each other's modules.
    """

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config (list[dict]): modules to take from mmlab, e.g.::

                [
                    # use mmdet MaskRCNN
                    dict(type='mmdet', name='MaskRCNN', module='model'),
                    # commented out, means use the EasyCV ResNet
                    # dict(type='mmdet', name='ResNet', module='backbone'),
                    # type is missing, use the default (mmdet)
                    dict(name='FPN', module='neck'),
                ]
        """
        self.default_mmtype = 'mmdet'
        self.mmtype_list = set()

        for module_cfg in modules_config:
            # A missing ``type`` falls back to the default mmlab type.
            self.mmtype_list.add(
                module_cfg.get('type', self.default_mmtype))

        self.check_env()

        # Disabled for now; meant to be re-enabled in the future.
        # self.fix_conflicts()

        self.MMTYPE_REGISTRY_MAP = MMAdapter.reset_mm_registry()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested mmlab types and make sure mmdet is present.

        If ``import mmdet`` fails, ``pip install mmdet`` is attempted through
        ``run_in_subprocess``.

        Raises:
            AssertionError: if an unsupported mmlab type was requested.
            ValueError: if the automatic installation fails.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES

        try:
            import mmdet  # noqa: F401
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')
        else:
            return

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # ``Exception`` (not a bare except) so that KeyboardInterrupt and
            # SystemExit are not turned into an installation error; the
            # original failure is kept as the cause.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

    def fix_conflicts(self):
        """Drop hooks from mmcv's registry that mmdet and EasyCV both register."""
        # mmdet and easycv both register
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register and wrap every configured module, then sync registries."""
        for module_cfg in self.modules_config:
            # ``type`` is optional, exactly as in ``__init__``; indexing it
            # directly raised KeyError for entries relying on the default.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Wrap the EasyCV-side copy of a module with the mmlab-specific wrapper."""
        module_obj = self._get_mm_module_obj_in_easycv(mmtype, module_type,
                                                       module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Copy every EasyCV module into the matching mmlab registry.

        Entries with the same name are replaced by the EasyCV module.
        """
        # Add all my module to mmlab module registry, if duplicated, replace with my module.
        # To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
        # In order to support our backbone, register our modules into mmdet.
        # If not specified mmdet type, use our modules by default.
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass of the mmlab module in the EasyCV registry.

        Args:
            mmtype (str): mmlab type, e.g. ``'mmdet'``.
            module_type (str): key of ``EASYCV_REGISTRY_MAP``.
            module_name (str): name of the module in the mmlab registry.
            force (bool): forwarded to ``register_module``.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        # Add mmlab module to my module registry.
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # Copy a duplicate to avoid directly modifying the properties of the original object
        key = '.'.join([mmtype, module_type, module_name])
        _MMLAB_COPIES[key] = type(module_name, (model_obj, ), dict())
        easycv_registry_type.register_module(_MMLAB_COPIES[key], force=force)

    def _get_mm_module_obj_in_easycv(self, mmtype, module_type, module_name):
        """Return the subclass created by ``_merge_mmlab_module_to_easycv``."""
        key = '.'.join([mmtype, module_type, module_name])
        return _MMLAB_COPIES[key]

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` to a class.

        Args:
            mmtype (str): mmlab type used to pick the registry group.
            module_type (str): registry kind inside that group.
            module_name (str | type): a registered name, or a class which is
                returned unchanged.

        Raises:
            ValueError: if the name is not in the mmlab registry.
            TypeError: if ``module_name`` is neither ``str`` nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name not in mm_module_dict:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
            return mm_module_dict[module_name]

        if inspect.isclass(module_name):
            return module_name

        raise TypeError(
            'Only support type `str` and `class` object, but get type {}'.
            format(type(module_name)))

    @staticmethod
    def reset_mm_registry():
        """Restore the mmlab registries to their first-seen content.

        The first call imports mmdet's MODELS/BACKBONES/NECKS/HEADS registries
        and keeps a deep-copied snapshot of them. Every call then resets each
        live registry's ``_module_dict`` to a deep copy of that snapshot.

        Returns:
            dict: ``{mmtype: {module_type: registry}}`` of live registries.
        """
        global MM_ORIGINAL_REGISTRY
        global MM_REGISTRY

        if MM_REGISTRY is None:
            from mmdet.models.builder import MODELS as MMMODELS
            from mmdet.models.builder import BACKBONES as MMBACKBONES
            from mmdet.models.builder import NECKS as MMNECKS
            from mmdet.models.builder import HEADS as MMHEADS
            MM_REGISTRY = {
                MMDET: {
                    'model': MMMODELS,
                    'backbone': MMBACKBONES,
                    'neck': MMNECKS,
                    'head': MMHEADS
                }
            }
            MM_ORIGINAL_REGISTRY = copy.deepcopy(MM_REGISTRY)

        for mmtype, registries in MM_ORIGINAL_REGISTRY.items():
            for k, ori_v in registries.items():
                MM_REGISTRY[mmtype][k]._module_dict = copy.deepcopy(
                    ori_v._module_dict)

        return MM_REGISTRY
|
|
|
|
|
|
class MMDetWrapper:
    """Patch an mmdet model class so it follows the EasyCV calling convention.

    The patched class initialises its weights right after construction,
    accepts ``forward(img, mode=..., img_metas=...)`` and returns a dict of
    detection arrays from ``forward_test``.
    """

    def __init__(self, refactor_modules=True):
        """Optionally run ``refactor_modules`` on construction."""
        if refactor_modules:
            self.refactor_modules()

    def wrap_module(self, cls, module_type):
        """Patch ``cls`` in place; classes already marked ``is_wrap`` are skipped."""
        if getattr(cls, 'is_wrap', False):
            return
        if module_type == 'model':
            self._wrap_model_init(cls)
            self._wrap_model_forward(cls)
            self._wrap_model_forward_test(cls)
            # NOTE(review): the source lost its indentation; the marker is
            # assumed to be set only for wrapped 'model' classes - confirm.
            cls.is_wrap = True

    def refactor_modules(self):
        """Re-register mmdet modules that need changes."""
        update_rpn_head()

    def _wrap_model_init(self, cls):
        """Make ``cls.__init__`` call ``init_weights()`` after the original init."""
        origin_init = cls.__init__

        def _new_init(self, *args, **kwargs):
            origin_init(self, *args, **kwargs)
            self.init_weights()

        cls.__init__ = _new_init

    def _wrap_model_forward(self, cls):
        """Translate the ``mode`` argument into the original ``return_loss`` flag."""
        origin_forward = cls.forward

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)
            # Only 'train' asks the original forward for losses.
            return origin_forward(
                self, img, img_metas, return_loss=(mode == 'train'), **kwargs)

        cls.forward = _new_forward

    def _wrap_model_forward_test(self, cls):
        """Convert the original ``forward_test`` output into a dict of arrays."""
        origin_forward_test = cls.forward_test

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs['rescale'] = True  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')

            result = origin_forward_test(self, img, img_metas, **kwargs)
            # ============result process to adapt to easycv============
            head = result[0]
            if isinstance(head, tuple):
                # encode mask results
                result = [(det, encode_mask_results(seg))
                          for det, seg in result]
            elif isinstance(head, dict) and 'ins_results' in head:
                # This logic is only used in panoptic segmentation test.
                for item in result:
                    det, seg = item['ins_results']
                    item['ins_results'] = (det, encode_mask_results(seg))

            all_boxes, all_scores, all_classes, all_masks = [], [], [], []
            for per_image in result:
                if isinstance(per_image, tuple):
                    bbox_result, segm_result = per_image
                    if isinstance(segm_result, tuple):
                        segm_result = segm_result[0]  # ms rcnn
                else:
                    bbox_result, segm_result = per_image, None

                # bbox_result is indexed by class: stack the boxes and build
                # one class index per stacked row.
                bboxes = np.vstack(bbox_result)
                labels = np.concatenate([
                    np.full(cls_boxes.shape[0], cls_idx, dtype=np.int32)
                    for cls_idx, cls_boxes in enumerate(bbox_result)
                ])

                segms = []
                if segm_result is not None and len(labels) > 0:  # non empty
                    segms = mmcv.concat_list(segm_result)
                    if isinstance(segms[0], torch.Tensor):
                        segms = torch.stack(
                            segms, dim=0).detach().cpu().numpy()
                    else:
                        segms = np.stack(segms, axis=0)

                # A fifth column, when present, holds the scores.
                if bboxes.shape[1] == 5:
                    scores = bboxes[:, 4]
                    bboxes = bboxes[:, 0:4]
                else:
                    scores = None
                assert bboxes.shape[1] == 4

                all_boxes.append(bboxes)
                all_scores.append(scores)
                all_classes.append(labels)
                all_masks.append(segms)

            assert len(img_metas) == 1
            return {
                'detection_boxes': all_boxes,
                'detection_scores': all_scores,
                'detection_classes': all_classes,
                'detection_masks': all_masks,
                'img_metas': img_metas[0]
            }

        cls.forward_test = _new_forward_test
|
|
|
|
|
|
def update_rpn_head():
    """Replace mmdet's registered ``RPNHead`` with a ``norm_cfg``-aware subclass.

    The existing 'RPNHead' entry is removed from mmdet's HEADS registry and a
    subclass defined here is registered under the same name.
    """
    logging.warning('refactor mmdet.models.RPNHead, add `norm_cfg`')
    from mmdet.models.builder import HEADS
    HEADS._module_dict.pop('RPNHead', None)
    from mmdet.models import RPNHead as _RPNHead

    @HEADS.register_module()
    class RPNHead(_RPNHead):
        """RPN head with norm.

        Args:
            in_channels (int): Number of channels in the input feature map.
            init_cfg (dict or list[dict], optional): Initialization config dict.
            num_convs (int): Number of convolution layers in the head. Default 1.
            norm_cfg (dict, optional): Passed as ``norm_cfg`` to each
                ``ConvModule`` built when ``num_convs > 1``. Default None.
        """  # noqa: W605

        def __init__(self,
                     in_channels,
                     init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                     num_convs=1,
                     norm_cfg=None,
                     **kwargs):
            # Stored before the parent constructor runs because
            # ``_init_layers`` reads it (presumably invoked by the parent
            # ``__init__`` - confirm against mmdet).
            self.norm_cfg = norm_cfg
            super().__init__(
                in_channels, init_cfg=init_cfg, num_convs=num_convs, **kwargs)

        def _init_layers(self):
            """Initialize layers of the head."""
            if self.num_convs <= 1:
                self.rpn_conv = nn.Conv2d(
                    self.in_channels, self.feat_channels, 3, padding=1)
            else:
                # use ``inplace=False`` to avoid error: one of the variables
                # needed for gradient computation has been modified by an
                # inplace operation.
                stacked = [
                    ConvModule(
                        # only the first conv sees the raw input channels
                        self.in_channels if idx == 0 else self.feat_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False) for idx in range(self.num_convs)
                ]
                self.rpn_conv = nn.Sequential(*stacked)

            self.rpn_cls = nn.Conv2d(
                self.feat_channels,
                self.num_base_priors * self.cls_out_channels, 1)
            self.rpn_reg = nn.Conv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)

        def forward_single(self, x):
            """Forward feature map of a single scale level."""
            # inplace=False to fix gradient computation has been modified by F.relu() when run with PyTorch 1.10.
            feat = F.relu(self.rpn_conv(x), inplace=False)
            return self.rpn_cls(feat), self.rpn_reg(feat)
|
|
|
|
|
|
def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed under ``cfg['mmlab_modules']``.

    Does nothing when the key is missing, empty or ``None``.

    Args:
        cfg: config object exposing ``get``; ``mmlab_modules`` is a list of
            module dicts in the format accepted by ``MMAdapter``.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Any non-empty list must be adapted: the previous ``len(...) > 1`` check
    # silently ignored configs that declare exactly one mmlab module.
    if mmlab_modules_cfg:
        adapter = MMAdapter(mmlab_modules_cfg)
        adapter.adapt_mmlab_modules()
|
|
|
|
|
|
def remove_adapt_for_mmlab(cfg):
    """Restore the mmlab registries, undoing a previous adaptation.

    An ``MMAdapter`` is built from ``cfg['mmlab_modules']`` (empty list when
    the key is missing) and its ``reset_mm_registry`` is invoked.
    """
    modules = cfg.get('mmlab_modules', [])
    MMAdapter(modules).reset_mm_registry()
|