mirror of https://github.com/alibaba/EasyCV.git
[feature]: support mmdet models config (#25)
* support mmdet models * add mmlab_models_usage_guide.md * remove tools/test.pypull/59/head
parent
10266f54a7
commit
b5fb2b70c7
|
@ -0,0 +1,138 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='MaskRCNN',
|
||||
# EasyCV backbone
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(1, 2, 3, 4),
|
||||
frozen_stages=1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=True),
|
||||
# mmdet backbone
|
||||
# backbone=dict(
|
||||
# type='ResNet',
|
||||
# depth=50,
|
||||
# num_stages=4,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
# frozen_stages=1,
|
||||
# norm_cfg=dict(type='BN', requires_grad=True),
|
||||
# norm_eval=True,
|
||||
# style='pytorch',
|
||||
# init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
num_outs=5),
|
||||
rpn_head=dict(
|
||||
type='RPNHead',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
anchor_generator=dict(
|
||||
type='AnchorGenerator',
|
||||
scales=[8],
|
||||
ratios=[0.5, 1.0, 2.0],
|
||||
strides=[4, 8, 16, 32, 64]),
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[.0, .0, .0, .0],
|
||||
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
roi_head=dict(
|
||||
type='StandardRoIHead',
|
||||
bbox_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
bbox_head=dict(
|
||||
type='Shared2FCBBoxHead',
|
||||
in_channels=256,
|
||||
fc_out_channels=1024,
|
||||
roi_feat_size=7,
|
||||
num_classes=80,
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[0., 0., 0., 0.],
|
||||
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
||||
reg_class_agnostic=False,
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
mask_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
mask_head=dict(
|
||||
type='FCNMaskHead',
|
||||
num_convs=4,
|
||||
in_channels=256,
|
||||
conv_out_channels=256,
|
||||
num_classes=80,
|
||||
loss_mask=dict(
|
||||
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(
|
||||
rpn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.7,
|
||||
neg_iou_thr=0.3,
|
||||
min_pos_iou=0.3,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=256,
|
||||
pos_fraction=0.5,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=False),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1,
|
||||
debug=False),
|
||||
rpn_proposal=dict(
|
||||
nms_pre=2000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.5,
|
||||
neg_iou_thr=0.5,
|
||||
min_pos_iou=0.5,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=512,
|
||||
pos_fraction=0.25,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=True),
|
||||
mask_size=28,
|
||||
pos_weight=-1,
|
||||
debug=False)),
|
||||
test_cfg=dict(
|
||||
rpn=dict(
|
||||
nms_pre=1000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.5),
|
||||
max_per_img=100,
|
||||
mask_thr_binary=0.5)))
|
||||
|
||||
mmlab_modules = [
|
||||
dict(type='mmdet', name='MaskRCNN', module='model'),
|
||||
# dict(type=MMDET, name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
|
||||
dict(type='mmdet', name='FPN', module='neck'),
|
||||
dict(type='mmdet', name='RPNHead', module='head'),
|
||||
dict(type='mmdet', name='StandardRoIHead', module='head'),
|
||||
]
|
|
@ -0,0 +1,69 @@
|
|||
# Use mmdetection's models in EasyCV
|
||||
|
||||
For details of mmdetection, please refer to :https://github.com/open-mmlab/mmdetection
|
||||
|
||||
**We only support mmdet's models and do not support other series in mmlab and other modules such as transforms, dataset api, etc. are not supported either.**
|
||||
|
||||
The models module of EasyCV is divided into four modules: `backbone`, `head`, `neck`, and `model`.
|
||||
|
||||
So we support the models combination of EasyCV and mmdet from these four levels.
|
||||
|
||||
**We will not adapt the other apis involved in these four levels modules, we package the entire api for use.**
|
||||
|
||||
> **Note: **
|
||||
>
|
||||
> **If you want to combine the models part of mmdet and easycv, please pay attention to the compatibility between the apis, we do not guarantee that the api of EasyCV and mmdet are compatible.**
|
||||
|
||||
Take the `MaskRCNN` model as an example, please refer to [mask_rcnn_r50_fpn.py](https://github.com/alibaba/EasyCV/tree/master/configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py). Except for the backbone, other parts in this model are all mmdet apis.
|
||||
|
||||
The framework of `MaskRCNN` can be divided into the following parts from the `backbone`, `head`, `neck`, and `model` levels:
|
||||
|
||||
- backbone: `ResNet`
|
||||
|
||||
- head:`RPNHead`, `StandardRoIHead`
|
||||
|
||||
- neck: `FPN`
|
||||
|
||||
- model: `MaskRCNN`
|
||||
|
||||
The configuration adapt for mmdet is as follows:
|
||||
|
||||
```python
|
||||
mmlab_modules = [
|
||||
dict(type='mmdet', name='MaskRCNN', module='model'),
|
||||
# dict(type='mmdet', name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
|
||||
dict(type='mmdet', name='FPN', module='neck'),
|
||||
dict(type='mmdet', name='RPNHead', module='head'),
|
||||
dict(type='mmdet', name='StandardRoIHead', module='head'),
|
||||
]
|
||||
```
|
||||
|
||||
> Parameters:
|
||||
>
|
||||
> - type: the name of the open source, only `mmdet` is supported
|
||||
> - name: the name of api
|
||||
> - Module: The name of the module to which the api belongs, only `backbone`,`head`,`neck`,`model` are supported.
|
||||
|
||||
In this configuration , the `head`, `neck`, and `model` parts specify the type as `mmdet`, except for `backbone`.
|
||||
|
||||
**No configured api will use the EasyCV api by default, , such as backbone (ResNet).**
|
||||
|
||||
**For other explicitly configured type as `mmdet`, we will use the mmdet api.**
|
||||
|
||||
Which is:
|
||||
|
||||
- `MaskRCNN`(model): Use mmdet's `MaskRCNN` api.
|
||||
|
||||
- `ResNet`(backbone): Use EasyCV's `ResNet` api.
|
||||
|
||||
> Note that the parameters of the `ResNet`of mmdet and EasyCV are different. Please pay attention to it!.
|
||||
|
||||
- `RPNHead`(head): Use mmdet's `RPNHead` api.
|
||||
|
||||
> Note that all the other apis configured in `RPNHead`, such as `AnchorGenerator`, `DeltaXYWHBBoxCoder`, etc., are all mmdet's apis, because we package the entire api for use.
|
||||
|
||||
- `StandardRoIHead`(head): Use mmdet's `StandardRoIHead` api.
|
||||
|
||||
> Note that all the other apis configured in `StandardRoIHead`, such as `SingleRoIExtractor`, `SingleRoIExtractor`, etc., are all mmdet's apis, because we package the entire api for use.
|
||||
|
||||
- `FPN`(neck): Use mmdet's `FPN` api.
|
|
@ -119,9 +119,16 @@ def single_gpu_test(model, data_loader, mode='test', use_fp16=False, **kwargs):
|
|||
results[k].append(v)
|
||||
|
||||
if 'img_metas' in data:
|
||||
batch_size = len(data['img_metas'].data[0])
|
||||
if isinstance(data['img_metas'], list):
|
||||
batch_size = len(data['img_metas'][0].data[0])
|
||||
else:
|
||||
batch_size = len(data['img_metas'].data[0])
|
||||
|
||||
else:
|
||||
batch_size = data['img'].size(0)
|
||||
if isinstance(data['img'], list):
|
||||
batch_size = data['img'][0].size(0)
|
||||
else:
|
||||
batch_size = data['img'].size(0)
|
||||
|
||||
for _ in range(batch_size):
|
||||
prog_bar.update()
|
||||
|
|
|
@ -151,7 +151,7 @@ def train_model(model,
|
|||
if validate:
|
||||
interval = cfg.eval_config.pop('interval', 1)
|
||||
for idx, eval_pipe in enumerate(cfg.eval_pipelines):
|
||||
data = eval_pipe.data
|
||||
data = eval_pipe.get('data', None) or cfg.data.val
|
||||
dist_eval = eval_pipe.get('dist_eval', False)
|
||||
|
||||
evaluator_cfg = eval_pipe.evaluators[0]
|
||||
|
|
|
@ -473,7 +473,8 @@ class CocoMaskEvaluator(Evaluator):
|
|||
groundtruth_masks_shape = self._image_id_to_mask_shape_map[image_id]
|
||||
detection_masks = detections_dict[
|
||||
standard_fields.DetectionResultFields.detection_masks]
|
||||
if groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
|
||||
if len(detection_masks
|
||||
) and groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
|
||||
raise ValueError(
|
||||
'Spatial shape of groundtruth masks and detection masks '
|
||||
'are incompatible: {} vs {}'.format(groundtruth_masks_shape,
|
||||
|
@ -601,6 +602,9 @@ class CocoMaskEvaluator(Evaluator):
|
|||
else:
|
||||
groundtruth_is_crowd = groundtruth_is_crowd_list[idx]
|
||||
|
||||
gt_masks = np.array(
|
||||
[self._ann_to_mask(mask, height, width) for mask in gt_masks],
|
||||
dtype=np.uint8)
|
||||
groundtruth_dict = {
|
||||
'groundtruth_boxes': gt_boxes_absolute,
|
||||
'groundtruth_instance_masks': gt_masks,
|
||||
|
@ -609,6 +613,11 @@ class CocoMaskEvaluator(Evaluator):
|
|||
}
|
||||
self.add_single_ground_truth_image_info(image_id, groundtruth_dict)
|
||||
|
||||
detection_masks = np.array([
|
||||
self._ann_to_mask(mask, height, width)
|
||||
for mask in detection_masks
|
||||
],
|
||||
dtype=np.uint8)
|
||||
# add detection info
|
||||
detection_dict = {
|
||||
'detection_masks': detection_masks,
|
||||
|
@ -621,6 +630,27 @@ class CocoMaskEvaluator(Evaluator):
|
|||
self.clear()
|
||||
return eval_dict
|
||||
|
||||
def _ann_to_mask(self, segmentation, height, width):
|
||||
from xtcocotools import mask as maskUtils
|
||||
segm = segmentation
|
||||
h = height
|
||||
w = width
|
||||
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = segm
|
||||
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
||||
|
||||
|
||||
@EVALUATORS.register_module
|
||||
class CoCoPoseTopDownEvaluator(Evaluator):
|
||||
|
|
|
@ -1644,20 +1644,21 @@ class LoadAnnotations:
|
|||
Returns:
|
||||
numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# if isinstance(mask_ann, list):
|
||||
# # polygon -- a single object might consist of multiple parts
|
||||
# # we merge all parts into one mask rle code
|
||||
# rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
|
||||
# rle = maskUtils.merge(rles)
|
||||
# elif isinstance(mask_ann['counts'], list):
|
||||
# # uncompressed RLE
|
||||
# rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
|
||||
# else:
|
||||
# # rle
|
||||
# rle = mask_ann
|
||||
# mask = maskUtils.decode(rle)
|
||||
# return mask
|
||||
import xtcocotools.mask as maskUtils
|
||||
|
||||
if isinstance(mask_ann, list):
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif isinstance(mask_ann['counts'], list):
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
|
||||
else:
|
||||
# rle
|
||||
rle = mask_ann
|
||||
mask = maskUtils.decode(rle)
|
||||
return mask
|
||||
|
||||
def process_polygons(self, polygons):
|
||||
"""Convert polygons to list of ndarray and filter invalid polygons.
|
||||
|
@ -1687,20 +1688,20 @@ class LoadAnnotations:
|
|||
If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
|
||||
:obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
|
||||
# h, w = results['img_info']['height'], results['img_info']['width']
|
||||
# gt_masks = results['ann_info']['masks']
|
||||
# if self.poly2mask:
|
||||
# gt_masks = BitmapMasks(
|
||||
# [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
|
||||
# else:
|
||||
# gt_masks = PolygonMasks(
|
||||
# [self.process_polygons(polygons) for polygons in gt_masks], h,
|
||||
# w)
|
||||
# results['gt_masks'] = gt_masks
|
||||
# results['mask_fields'].append('gt_masks')
|
||||
# return results
|
||||
h, w = results['img_info']['height'], results['img_info']['width']
|
||||
gt_masks = results['ann_info']['masks']
|
||||
if self.poly2mask:
|
||||
gt_masks = BitmapMasks(
|
||||
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
|
||||
else:
|
||||
gt_masks = PolygonMasks(
|
||||
[self.process_polygons(polygons) for polygons in gt_masks], h,
|
||||
w)
|
||||
results['gt_masks'] = gt_masks
|
||||
results['mask_fields'].append('gt_masks')
|
||||
return results
|
||||
|
||||
def _load_semantic_seg(self, results):
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
|
|
@ -70,6 +70,10 @@ class DetDataset(BaseDataset):
|
|||
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
groundtruth_dict['groundtruth_instance_masks'] = [
|
||||
self.data_source.get_ann_info(idx).get('masks', None)
|
||||
for idx in range(len(results['img_metas']))
|
||||
]
|
||||
|
||||
for evaluator in evaluators:
|
||||
eval_result.update(evaluator.evaluate(results, groundtruth_dict))
|
||||
|
|
|
@ -2,25 +2,9 @@
|
|||
from .backbones import * # noqa: F401,F403
|
||||
from .builder import build_backbone, build_head, build_loss, build_model
|
||||
from .classification import *
|
||||
from .detection import *
|
||||
from .heads import *
|
||||
from .loss import *
|
||||
from .pose import TopDown
|
||||
from .registry import BACKBONES, HEADS, LOSSES, MODELS, NECKS
|
||||
from .selfsup import *
|
||||
|
||||
try:
|
||||
from .detection.yolox.yolox import YOLOX
|
||||
except:
|
||||
import logging
|
||||
logging.warning(
|
||||
'Import YOLOX failed! Please check if mmcv and CUDA & Pytorch match.'
|
||||
'You may try: `pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html`.'
|
||||
'e.g.: `pip install mmcv-full==1.3.18 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html`'
|
||||
)
|
||||
try:
|
||||
from .detection.yolox_edge.yolox_edge import YOLOX_EDGE
|
||||
except:
|
||||
import logging
|
||||
logging.warning(
|
||||
'Import YOLOX EDGE model failed! Please check if mmcv and CUDA & Pytorch match.'
|
||||
)
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from importlib import import_module
|
||||
|
||||
from mmcv import Config, check_file_exist, import_modules_from_strings
|
||||
from mmcv import Config, import_modules_from_strings
|
||||
|
||||
from .user_config_params_utils import check_value_type
|
||||
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
|
||||
from .test_util import run_in_subprocess
|
||||
|
||||
EASYCV_REGISTRY_MAP = {
|
||||
'model': MODELS,
|
||||
'backbone': BACKBONES,
|
||||
'neck': NECKS,
|
||||
'head': HEADS
|
||||
}
|
||||
MMDET = 'mmdet'
|
||||
SUPPORT_MMLAB_TYPES = [MMDET]
|
||||
_MMLAB_COPIES = locals()
|
||||
|
||||
|
||||
class MMAdapter:
|
||||
|
||||
def __init__(self, modules_config):
|
||||
"""Adapt mmlab apis.
|
||||
Args: modules_config is as follow format:
|
||||
[
|
||||
dict(type='mmdet', name='MaskRCNN', module='model'), # means using mmdet MaskRCNN
|
||||
# dict(type='mmdet, name='ResNet', module='backbone'), # comment out, means use my ResNet
|
||||
dict(name='FPN', module='neck'), # type is missing, use mmdet default
|
||||
]
|
||||
"""
|
||||
self.default_mmtype = 'mmdet'
|
||||
self.mmtype_list = set([])
|
||||
|
||||
for module_cfg in modules_config:
|
||||
mmtype = module_cfg.get('type',
|
||||
self.default_mmtype) # default mmdet
|
||||
self.mmtype_list.add(mmtype)
|
||||
|
||||
self.check_env()
|
||||
self.fix_conflicts()
|
||||
|
||||
self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
|
||||
self.modules_config = modules_config
|
||||
|
||||
def check_env(self):
|
||||
assert self.mmtype_list.issubset(
|
||||
SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES
|
||||
install_success = False
|
||||
try:
|
||||
import mmdet
|
||||
install_success = True
|
||||
except ModuleNotFoundError as e:
|
||||
logging.warning(e)
|
||||
logging.warning('Try to install mmdet...')
|
||||
|
||||
if not install_success:
|
||||
try:
|
||||
run_in_subprocess('pip install mmdet')
|
||||
except:
|
||||
raise ValueError(
|
||||
'Failed to install mmdet, '
|
||||
'please refer to https://github.com/open-mmlab/mmdetection to install.'
|
||||
)
|
||||
|
||||
def fix_conflicts(self):
|
||||
# mmdet and easycv both register
|
||||
if MMDET in self.mmtype_list:
|
||||
mmcv_conflict_list = ['YOLOXLrUpdaterHook']
|
||||
from mmcv.runner.hooks import HOOKS
|
||||
for conflict in mmcv_conflict_list:
|
||||
HOOKS._module_dict.pop(conflict, None)
|
||||
|
||||
def adapt_mmlab_modules(self):
|
||||
for module_cfg in self.modules_config:
|
||||
mmtype = module_cfg['type']
|
||||
module_name, module_type = module_cfg['name'], module_cfg['module']
|
||||
self._merge_mmlab_module_to_easycv(mmtype, module_type,
|
||||
module_name)
|
||||
self.wrap_module(mmtype, module_type, module_name)
|
||||
|
||||
for mmtype in self.mmtype_list:
|
||||
self._merge_all_easycv_modules_to_mmlab(mmtype)
|
||||
|
||||
def wrap_module(self, mmtype, module_type, module_name):
|
||||
module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
|
||||
if mmtype == MMDET:
|
||||
MMDetWrapper().wrap_module(module_obj, module_type)
|
||||
|
||||
def _merge_all_easycv_modules_to_mmlab(self, mmtype):
|
||||
# Add all my module to mmlab module registry, if duplicated, replace with my module.
|
||||
# To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
|
||||
# In order to support our backbone, register our modules into mmdet.
|
||||
# If not specified mmdet type, use our modules by default.
|
||||
for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
|
||||
registry_type._module_dict.update(
|
||||
EASYCV_REGISTRY_MAP[key]._module_dict)
|
||||
|
||||
def _merge_mmlab_module_to_easycv(self,
|
||||
mmtype,
|
||||
module_type,
|
||||
module_name,
|
||||
force=True):
|
||||
model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
|
||||
# Add mmlab module to my module registry.
|
||||
easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
|
||||
# Copy a duplicate to avoid directly modifying the properties of the original object
|
||||
_MMLAB_COPIES[module_name] = type(module_name, (model_obj, ), dict())
|
||||
easycv_registry_type.register_module(
|
||||
_MMLAB_COPIES[module_name], force=force)
|
||||
|
||||
def _get_mm_module_obj(self, mmtype, module_type, module_name):
|
||||
if isinstance(module_name, str):
|
||||
mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
|
||||
mm_module_dict = mm_registry_type._module_dict
|
||||
if module_name in mm_module_dict:
|
||||
module_obj = mm_module_dict[module_name]
|
||||
else:
|
||||
raise ValueError('Not find {} object in {}'.format(
|
||||
module_name, mmtype))
|
||||
elif inspect.isclass(module_name):
|
||||
module_obj = module_name
|
||||
else:
|
||||
raise ValueError(
|
||||
'Only support type `str` and `class` object, but get type {}'.
|
||||
format(type(module_name)))
|
||||
return module_obj
|
||||
|
||||
def _get_mmtype_registry_map(self):
|
||||
from mmdet.models.builder import MODELS as MMMODELS
|
||||
from mmdet.models.builder import BACKBONES as MMBACKBONES
|
||||
from mmdet.models.builder import NECKS as MMNECKS
|
||||
from mmdet.models.builder import HEADS as MMHEADS
|
||||
registry_map = {
|
||||
MMDET: {
|
||||
'model': MMMODELS,
|
||||
'backbone': MMBACKBONES,
|
||||
'neck': MMNECKS,
|
||||
'head': MMHEADS
|
||||
}
|
||||
}
|
||||
return registry_map
|
||||
|
||||
|
||||
class MMDetWrapper:
|
||||
|
||||
def wrap_module(self, cls, module_type):
|
||||
if module_type == 'model':
|
||||
self._wrap_model_forward(cls)
|
||||
self._wrap_model_forward_test(cls)
|
||||
|
||||
def _wrap_model_forward(self, cls):
|
||||
origin_forward = cls.forward
|
||||
|
||||
def _new_forward(self, img, mode='train', **kwargs):
|
||||
img_metas = kwargs.pop('img_metas', None)
|
||||
|
||||
if mode == 'train':
|
||||
return origin_forward(
|
||||
self, img, img_metas, return_loss=True, **kwargs)
|
||||
else:
|
||||
return origin_forward(
|
||||
self, img, img_metas, return_loss=False, **kwargs)
|
||||
|
||||
setattr(cls, 'forward', _new_forward)
|
||||
|
||||
def _wrap_model_forward_test(self, cls):
|
||||
from mmdet.core import encode_mask_results
|
||||
|
||||
origin_forward_test = cls.forward_test
|
||||
|
||||
def _new_forward_test(self, img, img_metas=None, **kwargs):
|
||||
kwargs.update({'rescale': True}) # move from single_gpu_test
|
||||
logging.info('Set rescale to True for `model.forward_test`!')
|
||||
|
||||
result = origin_forward_test(self, img, img_metas, **kwargs)
|
||||
# ============result process to adapt to easycv============
|
||||
# encode mask results
|
||||
if isinstance(result[0], tuple):
|
||||
result = [(bbox_results, encode_mask_results(mask_results))
|
||||
for bbox_results, mask_results in result]
|
||||
# This logic is only used in panoptic segmentation test.
|
||||
elif isinstance(result[0], dict) and 'ins_results' in result[0]:
|
||||
for j in range(len(result)):
|
||||
bbox_results, mask_results = result[j]['ins_results']
|
||||
result[j]['ins_results'] = (
|
||||
bbox_results, encode_mask_results(mask_results))
|
||||
|
||||
detection_boxes = []
|
||||
detection_scores = []
|
||||
detection_classes = []
|
||||
detection_masks = []
|
||||
for res_i in result:
|
||||
if isinstance(res_i, tuple):
|
||||
bbox_result, segm_result = res_i
|
||||
if isinstance(segm_result, tuple):
|
||||
segm_result = segm_result[0] # ms rcnn
|
||||
else:
|
||||
bbox_result, segm_result = res_i, None
|
||||
bboxes = np.vstack(bbox_result)
|
||||
labels = [
|
||||
np.full(bbox.shape[0], i, dtype=np.int32)
|
||||
for i, bbox in enumerate(bbox_result)
|
||||
]
|
||||
labels = np.concatenate(labels)
|
||||
# draw segmentation masks
|
||||
segms = []
|
||||
if segm_result is not None and len(labels) > 0: # non empty
|
||||
segms = mmcv.concat_list(segm_result)
|
||||
if isinstance(segms[0], torch.Tensor):
|
||||
segms = torch.stack(
|
||||
segms, dim=0).detach().cpu().numpy()
|
||||
else:
|
||||
segms = np.stack(segms, axis=0)
|
||||
|
||||
scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
|
||||
bboxes = bboxes[:, 0:4] if bboxes.shape[1] == 5 else bboxes
|
||||
assert bboxes.shape[1] == 4
|
||||
|
||||
detection_boxes.append(bboxes)
|
||||
detection_scores.append(scores)
|
||||
detection_classes.append(labels)
|
||||
detection_masks.append(segms)
|
||||
|
||||
assert len(img_metas) == 1
|
||||
outputs = {
|
||||
'detection_boxes': detection_boxes,
|
||||
'detection_scores': detection_scores,
|
||||
'detection_classes': detection_classes,
|
||||
'detection_masks': detection_masks,
|
||||
'img_metas': img_metas[0]
|
||||
}
|
||||
return outputs
|
||||
|
||||
setattr(cls, 'forward_test', _new_forward_test)
|
||||
|
||||
|
||||
def dynamic_adapt_for_mmlab(cfg):
|
||||
mmlab_modules_cfg = cfg.get('mmlab_modules', [])
|
||||
if len(mmlab_modules_cfg) > 1:
|
||||
adapter = MMAdapter(mmlab_modules_cfg)
|
||||
adapter.adapt_mmlab_modules()
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from mmcv.parallel import scatter_kwargs
|
||||
from tests.ut_config import (COCO_CLASSES, DET_DATA_SMALL_COCO_LOCAL,
|
||||
IMG_NORM_CFG_255)
|
||||
|
||||
from easycv.apis.test import single_gpu_test
|
||||
from easycv.datasets import build_dataloader, build_dataset
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
||||
|
||||
|
||||
class MMLabUtilTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def _get_model(self):
|
||||
config_path = 'configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py'
|
||||
cfg = mmcv_config_fromfile(config_path)
|
||||
dynamic_adapt_for_mmlab(cfg)
|
||||
model = build_model(cfg.model)
|
||||
|
||||
return model
|
||||
|
||||
def _get_dataset(self, mode='train'):
|
||||
if mode == 'train':
|
||||
pipeline = [
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(1333, 640), (1333, 672), (1333, 704),
|
||||
(1333, 736), (1333, 768), (1333, 800)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True),
|
||||
dict(type='MMRandomFlip', flip_ratio=0.5),
|
||||
dict(type='MMNormalize', **IMG_NORM_CFG_255),
|
||||
dict(type='MMPad', size_divisor=32),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||
'ori_img_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction',
|
||||
'img_norm_cfg'))
|
||||
]
|
||||
else:
|
||||
pipeline = [
|
||||
dict(
|
||||
type='MMMultiScaleFlipAug',
|
||||
img_scale=(1333, 800),
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='MMResize', keep_ratio=True),
|
||||
dict(type='MMRandomFlip'),
|
||||
dict(type='MMNormalize', **IMG_NORM_CFG_255),
|
||||
dict(type='MMPad', size_divisor=32),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||
'ori_img_shape', 'img_shape',
|
||||
'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction', 'img_norm_cfg')),
|
||||
])
|
||||
]
|
||||
|
||||
data_root = DET_DATA_SMALL_COCO_LOCAL
|
||||
dataset_cfg = dict(
|
||||
type='DetDataset',
|
||||
data_source=dict(
|
||||
type='DetSourceCoco',
|
||||
ann_file=os.path.join(data_root,
|
||||
'instances_train2017_20.json'),
|
||||
img_prefix=os.path.join(data_root, 'train2017'),
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile', to_float32=True),
|
||||
dict(
|
||||
type='LoadAnnotations', with_bbox=True, with_mask=True)
|
||||
],
|
||||
classes=COCO_CLASSES,
|
||||
filter_empty_gt=False,
|
||||
iscrowd=False),
|
||||
pipeline=pipeline)
|
||||
return build_dataset(dataset_cfg)
|
||||
|
||||
def xxtest_model_train(self):
|
||||
model = self._get_model()
|
||||
model = model.cuda()
|
||||
model.train()
|
||||
|
||||
dataset = self._get_dataset()
|
||||
data_loader = build_dataloader(
|
||||
dataset, imgs_per_gpu=3, workers_per_gpu=1, num_gpus=1, dist=False)
|
||||
for i, data_batch in enumerate(data_loader):
|
||||
input_args, kwargs = scatter_kwargs(None, data_batch, [-1])
|
||||
for key in ['img', 'gt_bboxes', 'gt_labels']:
|
||||
if isinstance(kwargs[0][key], (list, tuple)):
|
||||
kwargs[0][key] = [
|
||||
kwargs[0][key][i].cuda()
|
||||
for i in range(len(kwargs[0][key]))
|
||||
]
|
||||
else:
|
||||
kwargs[0][key] = kwargs[0][key].cuda()
|
||||
output = model(**kwargs[0], mode='train')
|
||||
self.assertEqual(len(output['loss_rpn_cls']), 5)
|
||||
self.assertEqual(len(output['loss_rpn_bbox']), 5)
|
||||
self.assertEqual(output['loss_cls'].shape, torch.Size([]))
|
||||
self.assertEqual(output['acc'].shape, torch.Size([1]))
|
||||
self.assertEqual(output['loss_bbox'].shape, torch.Size([]))
|
||||
self.assertEqual(output['loss_mask'].shape, torch.Size([1]))
|
||||
|
||||
def test_model_test(self):
|
||||
model = self._get_model()
|
||||
model = model.cuda()
|
||||
dataset = self._get_dataset(mode='test')
|
||||
data_loader = build_dataloader(
|
||||
dataset, imgs_per_gpu=1, workers_per_gpu=1, num_gpus=1, dist=False)
|
||||
results = single_gpu_test(model, data_loader, mode='test')
|
||||
self.assertEqual(len(results['detection_boxes']), 20)
|
||||
self.assertEqual(len(results['detection_scores']), 20)
|
||||
self.assertEqual(len(results['detection_classes']), 20)
|
||||
self.assertEqual(len(results['detection_masks']), 20)
|
||||
self.assertEqual(len(results['img_metas']), 20)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -28,6 +28,7 @@ from easycv.models import build_model
|
|||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
|
||||
mmcv_config_fromfile, rebuild_config)
|
||||
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
||||
|
||||
# from tools.fuse_conv_bn import fuse_module
|
||||
|
||||
|
@ -143,6 +144,9 @@ def main():
|
|||
if cfg.get('oss_io_config', None) is not None:
|
||||
io.access_oss(**cfg.oss_io_config)
|
||||
|
||||
# dynamic adapt mmdet models
|
||||
dynamic_adapt_for_mmlab(cfg)
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
|
145
tools/test.py
145
tools/test.py
|
@ -1,145 +0,0 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
||||
|
||||
from easycv.core.evaluation.builder import build_evaluator
|
||||
from easycv.datasets import build_dataloader, build_dataset
|
||||
from easycv.models import build_model
|
||||
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
|
||||
# from mmcv import Config
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile, traverse_replace
|
||||
from easycv.utils.logger import get_root_logger
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(
|
||||
os.path.abspath(
|
||||
osp.join(os.path.dirname(os.path.dirname(__file__)), '../')))
|
||||
|
||||
|
||||
def single_gpu_test(model, data_loader):
|
||||
model.eval()
|
||||
func = lambda **x: model(mode='test', **x)
|
||||
results = nondist_forward_collect(func, data_loader,
|
||||
len(data_loader.dataset))
|
||||
return results
|
||||
|
||||
|
||||
def multi_gpu_test(model, data_loader):
|
||||
model.eval()
|
||||
func = lambda **x: model(mode='test', **x)
|
||||
rank, world_size = get_dist_info()
|
||||
results = dist_forward_collect(func, data_loader, rank,
|
||||
len(data_loader.dataset))
|
||||
return results
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='MMDet test (and eval) a model')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('checkpoint', help='checkpoint file')
|
||||
parser.add_argument(
|
||||
'--work_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default=29500,
|
||||
help='port only works when launcher=="slurm"')
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
choices=['classification', 'pose'],
|
||||
default='classification',
|
||||
help='model type')
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
cfg = mmcv_config_fromfile(args.config)
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# update configs according to CLI args
|
||||
if args.work_dir is not None:
|
||||
cfg.work_dir = args.work_dir
|
||||
|
||||
cfg.model.pretrained = None # ensure to use checkpoint rather than pretraining
|
||||
|
||||
# check memcached package exists
|
||||
if importlib.util.find_spec('mc') is None:
|
||||
traverse_replace(cfg, 'memcached', False)
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
if args.launcher == 'none':
|
||||
distributed = False
|
||||
else:
|
||||
distributed = True
|
||||
if args.launcher == 'slurm':
|
||||
cfg.dist_params['port'] = args.port
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
|
||||
# logger
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
log_file = osp.join(cfg.work_dir, 'test_{}.log'.format(timestamp))
|
||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
||||
|
||||
# build the dataloader
|
||||
dataset = build_dataset(cfg.data.val)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
imgs_per_gpu=cfg.data.imgs_per_gpu,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
|
||||
# build the model and load checkpoint
|
||||
model = build_model(cfg.model)
|
||||
load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||
|
||||
if not distributed:
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
outputs = single_gpu_test(model, data_loader)
|
||||
else:
|
||||
model = MMDistributedDataParallel(
|
||||
model.cuda(),
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
broadcast_buffers=False)
|
||||
outputs = multi_gpu_test(model, data_loader) # dict{key: np.ndarray}
|
||||
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
if args.model_type == 'pose':
|
||||
evaluators = build_evaluator(
|
||||
cfg.eval_pipelines[0]['evaluators'][0])
|
||||
dataset.evaluate(outputs, evaluators)
|
||||
else:
|
||||
for name, val in outputs.items():
|
||||
dataset.evaluate(
|
||||
torch.from_numpy(val), name, logger, topk=(1, 5))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -29,6 +29,7 @@ from easycv.models import build_model
|
|||
from easycv.utils.collect_env import collect_env
|
||||
from easycv.utils.flops_counter import get_model_info
|
||||
from easycv.utils.logger import get_root_logger
|
||||
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
||||
from easycv.utils.config_tools import traverse_replace
|
||||
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
|
||||
mmcv_config_fromfile, rebuild_config)
|
||||
|
@ -149,6 +150,9 @@ def main():
|
|||
if args.load_from is not None:
|
||||
cfg.load_from = args.load_from
|
||||
|
||||
# dynamic adapt mmdet models
|
||||
dynamic_adapt_for_mmlab(cfg)
|
||||
|
||||
cfg.gpus = args.gpus
|
||||
|
||||
# check memcached package exists
|
||||
|
|
Loading…
Reference in New Issue