[feature]: support mmdet models config (#25)

* support mmdet models

* add mmlab_models_usage_guide.md

* remove tools/test.py
pull/59/head
Cathy0908 2022-05-11 17:44:06 +08:00 committed by GitHub
parent 10266f54a7
commit b5fb2b70c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 667 additions and 196 deletions

View File

@ -0,0 +1,138 @@
# model settings
# Mask R-CNN with a ResNet-50 backbone and an FPN neck. The `mmlab_modules`
# list at the end of this file declares which of the types named here are
# taken from mmdet; types not listed there are resolved by EasyCV.
model = dict(
    type='MaskRCNN',
    # EasyCV backbone
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        # NOTE: the EasyCV variant uses (1, 2, 3, 4) here, while the
        # commented-out mmdet variant below uses (0, 1, 2, 3).
        out_indices=(1, 2, 3, 4),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True),
    # mmdet backbone
    # backbone=dict(
    #     type='ResNet',
    #     depth=50,
    #     num_stages=4,
    #     out_indices=(0, 1, 2, 3),
    #     frozen_stages=1,
    #     norm_cfg=dict(type='BN', requires_grad=True),
    #     norm_eval=True,
    #     style='pytorch',
    #     init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        # one entry per backbone output listed in `out_indices`
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    # first stage: region proposal head
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            # five strides, matching `num_outs=5` of the neck
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    # second stage: box branch (`bbox_*`) and mask branch (`mask_*`)
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            # same value as `output_size` of the bbox RoI layer above
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            # kept equal to `bbox_head.num_classes`
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        # target assignment / sampling used for the RPN head
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        # proposal generation settings used while training
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # target assignment / sampling used for the RoI head
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
# Declares which of the types referenced in `model` are mmdet classes.
# Each entry gives the library (`type`), the registered class name (`name`)
# and the registry level it belongs to (`module`).
mmlab_modules = [
    dict(type='mmdet', name='MaskRCNN', module='model'),
    # dict(type='mmdet', name='ResNet', module='backbone'),  # comment out, use EasyCV ResNet
    dict(type='mmdet', name='FPN', module='neck'),
    dict(type='mmdet', name='RPNHead', module='head'),
    dict(type='mmdet', name='StandardRoIHead', module='head'),
]

View File

@ -0,0 +1,69 @@
# Use mmdetection's models in EasyCV
For details of mmdetection, please refer to: https://github.com/open-mmlab/mmdetection
**Only the models of mmdet are supported. Other projects in the mmlab series, and other mmdet modules such as transforms and dataset APIs, are not supported.**
The models module of EasyCV is divided into four modules: `backbone`, `head`, `neck`, and `model`.
So we support the models combination of EasyCV and mmdet from these four levels.
**The other APIs used inside the modules at these four levels are not adapted individually; each module is packaged and used as a whole.**
> **Note:**
>
> **If you want to combine the models part of mmdet and easycv, please pay attention to the compatibility between the apis, we do not guarantee that the api of EasyCV and mmdet are compatible.**
Take the `MaskRCNN` model as an example, please refer to [mask_rcnn_r50_fpn.py](https://github.com/alibaba/EasyCV/tree/master/configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py). Except for the backbone, other parts in this model are all mmdet apis.
The framework of `MaskRCNN` can be divided into the following parts at the `backbone`, `head`, `neck`, and `model` levels:
- backbone: `ResNet`
- head: `RPNHead`, `StandardRoIHead`
- neck: `FPN`
- model: `MaskRCNN`
The configuration that adapts these modules to mmdet is as follows:
```python
mmlab_modules = [
dict(type='mmdet', name='MaskRCNN', module='model'),
# dict(type='mmdet', name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
dict(type='mmdet', name='FPN', module='neck'),
dict(type='mmdet', name='RPNHead', module='head'),
dict(type='mmdet', name='StandardRoIHead', module='head'),
]
```
> Parameters:
>
> - type: the name of the open-source library; only `mmdet` is supported
> - name: the name of the API
> - module: the name of the module to which the API belongs; only `backbone`, `head`, `neck`, and `model` are supported.
In this configuration, the `head`, `neck`, and `model` parts specify the type as `mmdet`; the `backbone` does not.
**Any API that is not configured here uses the EasyCV API by default, such as the backbone (`ResNet`).**
**For other explicitly configured type as `mmdet`, we will use the mmdet api.**
Which is:
- `MaskRCNN`(model): Use mmdet's `MaskRCNN` api.
- `ResNet`(backbone): Use EasyCV's `ResNet` api.
> Note that the parameters of the `ResNet` of mmdet and EasyCV are different. Please pay attention to this!
- `RPNHead`(head): Use mmdet's `RPNHead` api.
> Note that all the other apis configured in `RPNHead`, such as `AnchorGenerator`, `DeltaXYWHBBoxCoder`, etc., are all mmdet's apis, because we package the entire api for use.
- `StandardRoIHead`(head): Use mmdet's `StandardRoIHead` api.
> Note that all the other apis configured in `StandardRoIHead`, such as `SingleRoIExtractor`, `Shared2FCBBoxHead`, etc., are all mmdet's apis, because we package the entire api for use.
- `FPN`(neck): Use mmdet's `FPN` api.

View File

@ -119,9 +119,16 @@ def single_gpu_test(model, data_loader, mode='test', use_fp16=False, **kwargs):
results[k].append(v)
if 'img_metas' in data:
batch_size = len(data['img_metas'].data[0])
if isinstance(data['img_metas'], list):
batch_size = len(data['img_metas'][0].data[0])
else:
batch_size = len(data['img_metas'].data[0])
else:
batch_size = data['img'].size(0)
if isinstance(data['img'], list):
batch_size = data['img'][0].size(0)
else:
batch_size = data['img'].size(0)
for _ in range(batch_size):
prog_bar.update()

View File

@ -151,7 +151,7 @@ def train_model(model,
if validate:
interval = cfg.eval_config.pop('interval', 1)
for idx, eval_pipe in enumerate(cfg.eval_pipelines):
data = eval_pipe.data
data = eval_pipe.get('data', None) or cfg.data.val
dist_eval = eval_pipe.get('dist_eval', False)
evaluator_cfg = eval_pipe.evaluators[0]

View File

@ -473,7 +473,8 @@ class CocoMaskEvaluator(Evaluator):
groundtruth_masks_shape = self._image_id_to_mask_shape_map[image_id]
detection_masks = detections_dict[
standard_fields.DetectionResultFields.detection_masks]
if groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
if len(detection_masks
) and groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
raise ValueError(
'Spatial shape of groundtruth masks and detection masks '
'are incompatible: {} vs {}'.format(groundtruth_masks_shape,
@ -601,6 +602,9 @@ class CocoMaskEvaluator(Evaluator):
else:
groundtruth_is_crowd = groundtruth_is_crowd_list[idx]
gt_masks = np.array(
[self._ann_to_mask(mask, height, width) for mask in gt_masks],
dtype=np.uint8)
groundtruth_dict = {
'groundtruth_boxes': gt_boxes_absolute,
'groundtruth_instance_masks': gt_masks,
@ -609,6 +613,11 @@ class CocoMaskEvaluator(Evaluator):
}
self.add_single_ground_truth_image_info(image_id, groundtruth_dict)
detection_masks = np.array([
self._ann_to_mask(mask, height, width)
for mask in detection_masks
],
dtype=np.uint8)
# add detection info
detection_dict = {
'detection_masks': detection_masks,
@ -621,6 +630,27 @@ class CocoMaskEvaluator(Evaluator):
self.clear()
return eval_dict
def _ann_to_mask(self, segmentation, height, width):
    """Decode one segmentation annotation into a mask array.

    Three input forms are distinguished:

    * a ``list``: treated as polygons; every part is converted with
      ``frPyObjects`` and the parts are merged into a single RLE,
    * a mapping whose ``'counts'`` entry is a ``list``: treated as an
      uncompressed RLE and converted with ``frPyObjects``,
    * anything else: passed to ``decode`` unchanged.

    Args:
        segmentation: The annotation in one of the forms above.
        height: Height passed to ``frPyObjects``.
        width: Width passed to ``frPyObjects``.

    Returns:
        Whatever ``xtcocotools.mask.decode`` returns for the resulting RLE.
    """
    from xtcocotools import mask as maskUtils

    # isinstance (not ``type(...) == list``) so list subclasses are accepted,
    # matching how LoadAnnotations._poly2mask performs the same dispatch.
    if isinstance(segmentation, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = maskUtils.frPyObjects(segmentation, height, width)
        rle = maskUtils.merge(rles)
    elif isinstance(segmentation['counts'], list):
        # uncompressed RLE
        rle = maskUtils.frPyObjects(segmentation, height, width)
    else:
        # rle
        rle = segmentation
    return maskUtils.decode(rle)
@EVALUATORS.register_module
class CoCoPoseTopDownEvaluator(Evaluator):

View File

@ -1644,20 +1644,21 @@ class LoadAnnotations:
Returns:
numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
"""
raise NotImplementedError
# if isinstance(mask_ann, list):
# # polygon -- a single object might consist of multiple parts
# # we merge all parts into one mask rle code
# rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
# rle = maskUtils.merge(rles)
# elif isinstance(mask_ann['counts'], list):
# # uncompressed RLE
# rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
# else:
# # rle
# rle = mask_ann
# mask = maskUtils.decode(rle)
# return mask
import xtcocotools.mask as maskUtils
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def process_polygons(self, polygons):
"""Convert polygons to list of ndarray and filter invalid polygons.
@ -1687,20 +1688,20 @@ class LoadAnnotations:
If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
:obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
"""
raise NotImplementedError
from mmdet.core import BitmapMasks, PolygonMasks
# h, w = results['img_info']['height'], results['img_info']['width']
# gt_masks = results['ann_info']['masks']
# if self.poly2mask:
# gt_masks = BitmapMasks(
# [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
# else:
# gt_masks = PolygonMasks(
# [self.process_polygons(polygons) for polygons in gt_masks], h,
# w)
# results['gt_masks'] = gt_masks
# results['mask_fields'].append('gt_masks')
# return results
h, w = results['img_info']['height'], results['img_info']['width']
gt_masks = results['ann_info']['masks']
if self.poly2mask:
gt_masks = BitmapMasks(
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
else:
gt_masks = PolygonMasks(
[self.process_polygons(polygons) for polygons in gt_masks], h,
w)
results['gt_masks'] = gt_masks
results['mask_fields'].append('gt_masks')
return results
def _load_semantic_seg(self, results):
"""Private function to load semantic segmentation annotations.

View File

@ -70,6 +70,10 @@ class DetDataset(BaseDataset):
self.data_source.get_ann_info(idx)['groundtruth_is_crowd']
for idx in range(len(results['img_metas']))
]
groundtruth_dict['groundtruth_instance_masks'] = [
self.data_source.get_ann_info(idx).get('masks', None)
for idx in range(len(results['img_metas']))
]
for evaluator in evaluators:
eval_result.update(evaluator.evaluate(results, groundtruth_dict))

View File

@ -2,25 +2,9 @@
from .backbones import * # noqa: F401,F403
from .builder import build_backbone, build_head, build_loss, build_model
from .classification import *
from .detection import *
from .heads import *
from .loss import *
from .pose import TopDown
from .registry import BACKBONES, HEADS, LOSSES, MODELS, NECKS
from .selfsup import *
try:
from .detection.yolox.yolox import YOLOX
except:
import logging
logging.warning(
'Import YOLOX failed! Please check if mmcv and CUDA & Pytorch match.'
'You may try: `pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html`.'
'e.g.: `pip install mmcv-full==1.3.18 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html`'
)
try:
from .detection.yolox_edge.yolox_edge import YOLOX_EDGE
except:
import logging
logging.warning(
'Import YOLOX EDGE model failed! Please check if mmcv and CUDA & Pytorch match.'
)

View File

@ -1,13 +1,11 @@
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
from importlib import import_module
from mmcv import Config, check_file_exist, import_modules_from_strings
from mmcv import Config, import_modules_from_strings
from .user_config_params_utils import check_value_type

View File

@ -0,0 +1,244 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging
import mmcv
import numpy as np
import torch
from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
from .test_util import run_in_subprocess
# Maps the `module` field of an `mmlab_modules` entry to the EasyCV registry
# used for modules of that level.
EASYCV_REGISTRY_MAP = {
    'model': MODELS,
    'backbone': BACKBONES,
    'neck': NECKS,
    'head': HEADS
}
# Library identifier used in the `type` field of an `mmlab_modules` entry.
MMDET = 'mmdet'
# Library types accepted by `MMAdapter.check_env`.
SUPPORT_MMLAB_TYPES = [MMDET]
# At module level `locals()` is the module namespace, so every class stored in
# this mapping by `MMAdapter._merge_mmlab_module_to_easycv` also becomes an
# attribute of this module.
_MMLAB_COPIES = locals()
class MMAdapter:
    """Make selected mmlab modules buildable through the EasyCV registries.

    For every configured module the mmlab class is looked up in the mmlab
    registry, a subclass of it is registered into the matching EasyCV
    registry, and (for mmdet) the class is passed to ``MMDetWrapper``.
    Afterwards all EasyCV modules are copied into the mmlab registries.
    """

    def __init__(self, modules_config):
        """Adapt mmlab apis.

        Args:
            modules_config: list of dicts in the following format:
                [
                    dict(type='mmdet', name='MaskRCNN', module='model'),  # means using mmdet MaskRCNN
                    # dict(type='mmdet', name='ResNet', module='backbone'),  # comment out, means use my ResNet
                    dict(name='FPN', module='neck'),  # type is missing, use mmdet default
                ]
        """
        self.default_mmtype = MMDET
        self.mmtype_list = set()
        for module_cfg in modules_config:
            # default mmdet
            self.mmtype_list.add(module_cfg.get('type', self.default_mmtype))

        self.check_env()
        self.fix_conflicts()

        self.MMTYPE_REGISTRY_MAP = self._get_mmtype_registry_map()
        self.modules_config = modules_config

    def check_env(self):
        """Validate the requested library types and ensure mmdet is importable.

        If ``import mmdet`` fails, ``pip install mmdet`` is attempted through
        ``run_in_subprocess``.

        Raises:
            AssertionError: A requested type is not in ``SUPPORT_MMLAB_TYPES``.
            ValueError: ``run_in_subprocess`` raised while installing mmdet.
        """
        assert self.mmtype_list.issubset(
            SUPPORT_MMLAB_TYPES), 'Only support %s now !' % SUPPORT_MMLAB_TYPES

        try:
            import mmdet  # noqa: F401
            return
        except ModuleNotFoundError as e:
            logging.warning(e)
            logging.warning('Try to install mmdet...')

        try:
            run_in_subprocess('pip install mmdet')
        except Exception as e:
            # Catch Exception rather than everything so that
            # KeyboardInterrupt/SystemExit are not converted into ValueError.
            raise ValueError(
                'Failed to install mmdet, '
                'please refer to https://github.com/open-mmlab/mmdetection to install.'
            ) from e

    def fix_conflicts(self):
        """Drop hook names from mmcv's ``HOOKS`` registry that would clash."""
        # mmdet and easycv both register
        if MMDET in self.mmtype_list:
            mmcv_conflict_list = ['YOLOXLrUpdaterHook']
            from mmcv.runner.hooks import HOOKS
            for conflict in mmcv_conflict_list:
                HOOKS._module_dict.pop(conflict, None)

    def adapt_mmlab_modules(self):
        """Register and wrap every configured module, then sync registries."""
        for module_cfg in self.modules_config:
            # Fall back to the default type exactly like ``__init__`` does, so
            # an entry without ``type`` no longer raises KeyError here.
            mmtype = module_cfg.get('type', self.default_mmtype)
            module_name, module_type = module_cfg['name'], module_cfg['module']
            self._merge_mmlab_module_to_easycv(mmtype, module_type,
                                               module_name)
            self.wrap_module(mmtype, module_type, module_name)

        for mmtype in self.mmtype_list:
            self._merge_all_easycv_modules_to_mmlab(mmtype)

    def wrap_module(self, mmtype, module_type, module_name):
        """Apply the library-specific wrapper to the looked-up mmlab class."""
        module_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        if mmtype == MMDET:
            MMDetWrapper().wrap_module(module_obj, module_type)

    def _merge_all_easycv_modules_to_mmlab(self, mmtype):
        """Copy every EasyCV registry entry into the matching mmlab registry."""
        # Add all my module to mmlab module registry, if duplicated, replace with my module.
        # To handle: if MaskRCNN use mmdet's api, but the backbone also uses the backbone registered in mmdet
        # In order to support our backbone, register our modules into mmdet.
        # If not specified mmdet type, use our modules by default.
        for key, registry_type in self.MMTYPE_REGISTRY_MAP[mmtype].items():
            registry_type._module_dict.update(
                EASYCV_REGISTRY_MAP[key]._module_dict)

    def _merge_mmlab_module_to_easycv(self,
                                      mmtype,
                                      module_type,
                                      module_name,
                                      force=True):
        """Register a subclass of the mmlab class in the EasyCV registry.

        Args:
            mmtype: Library key of ``MMTYPE_REGISTRY_MAP``.
            module_type: Registry level, a key of ``EASYCV_REGISTRY_MAP``.
            module_name: Name of the class in the mmlab registry.
            force: Forwarded to ``register_module``.
        """
        model_obj = self._get_mm_module_obj(mmtype, module_type, module_name)
        # Add mmlab module to my module registry.
        easycv_registry_type = EASYCV_REGISTRY_MAP[module_type]
        # Copy a duplicate to avoid directly modifying the properties of the original object
        _MMLAB_COPIES[module_name] = type(module_name, (model_obj, ), dict())
        easycv_registry_type.register_module(
            _MMLAB_COPIES[module_name], force=force)

    def _get_mm_module_obj(self, mmtype, module_type, module_name):
        """Resolve ``module_name`` to a class.

        Args:
            mmtype: Library key of ``MMTYPE_REGISTRY_MAP``.
            module_type: Registry level inside that library.
            module_name: Either a registered name (``str``) or a class, which
                is returned unchanged.

        Raises:
            ValueError: The name is not registered, or ``module_name`` is
                neither a ``str`` nor a class.
        """
        if isinstance(module_name, str):
            mm_registry_type = self.MMTYPE_REGISTRY_MAP[mmtype][module_type]
            mm_module_dict = mm_registry_type._module_dict
            if module_name in mm_module_dict:
                module_obj = mm_module_dict[module_name]
            else:
                raise ValueError('Not find {} object in {}'.format(
                    module_name, mmtype))
        elif inspect.isclass(module_name):
            module_obj = module_name
        else:
            raise ValueError(
                'Only support type `str` and `class` object, but get type {}'.
                format(type(module_name)))
        return module_obj

    def _get_mmtype_registry_map(self):
        """Return ``{library: {level: mmlab registry}}`` for supported types."""
        from mmdet.models.builder import MODELS as MMMODELS
        from mmdet.models.builder import BACKBONES as MMBACKBONES
        from mmdet.models.builder import NECKS as MMNECKS
        from mmdet.models.builder import HEADS as MMHEADS
        registry_map = {
            MMDET: {
                'model': MMMODELS,
                'backbone': MMBACKBONES,
                'neck': MMNECKS,
                'head': MMHEADS
            }
        }
        return registry_map
class MMDetWrapper:
    """Patches mmdet model classes in place to match EasyCV's calling style."""

    def wrap_module(self, cls, module_type):
        """Patch ``cls``; classes of any level other than 'model' are left alone."""
        if module_type != 'model':
            return
        self._wrap_model_forward(cls)
        self._wrap_model_forward_test(cls)

    def _wrap_model_forward(self, cls):
        """Replace ``cls.forward`` with a ``mode``-driven entry point.

        The replacement pops ``img_metas`` from the keyword arguments and
        calls the original ``forward`` with ``return_loss`` set to whether
        ``mode`` equals ``'train'``.
        """
        wrapped_forward = cls.forward

        def _new_forward(self, img, mode='train', **kwargs):
            img_metas = kwargs.pop('img_metas', None)
            return wrapped_forward(
                self,
                img,
                img_metas,
                return_loss=(mode == 'train'),
                **kwargs)

        cls.forward = _new_forward

    def _wrap_model_forward_test(self, cls):
        """Replace ``cls.forward_test`` so it returns a dict of per-image lists.

        The replacement forces ``rescale=True``, encodes mask results with
        mmdet's ``encode_mask_results`` and returns the keys
        ``detection_boxes``, ``detection_scores``, ``detection_classes``,
        ``detection_masks`` and ``img_metas``.
        """
        from mmdet.core import encode_mask_results

        wrapped_forward_test = cls.forward_test

        def _new_forward_test(self, img, img_metas=None, **kwargs):
            kwargs['rescale'] = True  # move from single_gpu_test
            logging.info('Set rescale to True for `model.forward_test`!')
            result = wrapped_forward_test(self, img, img_metas, **kwargs)

            # ============result process to adapt to easycv============
            # encode mask results
            head = result[0]
            if isinstance(head, tuple):
                result = [(boxes, encode_mask_results(masks))
                          for boxes, masks in result]
            # This logic is only used in panoptic segmentation test.
            elif isinstance(head, dict) and 'ins_results' in head:
                for entry in result:
                    boxes, masks = entry['ins_results']
                    entry['ins_results'] = (boxes,
                                            encode_mask_results(masks))

            all_boxes = []
            all_scores = []
            all_classes = []
            all_masks = []
            for per_image in result:
                if isinstance(per_image, tuple):
                    bbox_result, segm_result = per_image
                    if isinstance(segm_result, tuple):
                        segm_result = segm_result[0]  # ms rcnn
                else:
                    bbox_result, segm_result = per_image, None

                stacked = np.vstack(bbox_result)
                # one label per box: the index of the per-class array it came from
                labels = np.concatenate([
                    np.full(class_boxes.shape[0], class_idx, dtype=np.int32)
                    for class_idx, class_boxes in enumerate(bbox_result)
                ])

                # flatten the per-class mask lists into one stacked array
                masks = []
                if segm_result is not None and len(labels) > 0:  # non empty
                    masks = mmcv.concat_list(segm_result)
                    if isinstance(masks[0], torch.Tensor):
                        masks = torch.stack(
                            masks, dim=0).detach().cpu().numpy()
                    else:
                        masks = np.stack(masks, axis=0)

                # a fifth column, when present, is split off as the score
                if stacked.shape[1] == 5:
                    scores = stacked[:, 4]
                    boxes = stacked[:, 0:4]
                else:
                    scores = None
                    boxes = stacked
                assert boxes.shape[1] == 4

                all_boxes.append(boxes)
                all_scores.append(scores)
                all_classes.append(labels)
                all_masks.append(masks)

            assert len(img_metas) == 1
            return {
                'detection_boxes': all_boxes,
                'detection_scores': all_scores,
                'detection_classes': all_classes,
                'detection_masks': all_masks,
                'img_metas': img_metas[0]
            }

        cls.forward_test = _new_forward_test
def dynamic_adapt_for_mmlab(cfg):
    """Adapt the mmlab modules listed in ``cfg['mmlab_modules']``, if any.

    Args:
        cfg: Config object supporting ``.get``; the optional key
            ``mmlab_modules`` holds a list of module description dicts.

    Nothing happens when the key is missing or the list is empty.
    """
    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    # Any non-empty list must be adapted; the previous ``len(...) > 1`` check
    # silently ignored configs that declare exactly one mmlab module.
    if mmlab_modules_cfg:
        adapter = MMAdapter(mmlab_modules_cfg)
        adapter.adapt_mmlab_modules()

View File

@ -0,0 +1,133 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from mmcv.parallel import scatter_kwargs
from tests.ut_config import (COCO_CLASSES, DET_DATA_SMALL_COCO_LOCAL,
IMG_NORM_CFG_255)
from easycv.apis.test import single_gpu_test
from easycv.datasets import build_dataloader, build_dataset
from easycv.models.builder import build_model
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
class MMLabUtilTest(unittest.TestCase):
    """Builds the Mask R-CNN config through the mmlab adapter and runs it."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _get_model(self):
        """Load the Mask R-CNN config, adapt the mmlab modules, build the model."""
        cfg = mmcv_config_fromfile(
            'configs/detection/mask_rcnn/mask_rcnn_r50_fpn.py')
        dynamic_adapt_for_mmlab(cfg)
        return build_model(cfg.model)

    def _get_dataset(self, mode='train'):
        """Build the small COCO dataset with the train or the test pipeline."""
        meta_keys = ('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
                     'img_shape', 'pad_shape', 'scale_factor', 'flip',
                     'flip_direction', 'img_norm_cfg')

        if mode != 'train':
            pipeline = [
                dict(
                    type='MMMultiScaleFlipAug',
                    img_scale=(1333, 800),
                    flip=False,
                    transforms=[
                        dict(type='MMResize', keep_ratio=True),
                        dict(type='MMRandomFlip'),
                        dict(type='MMNormalize', **IMG_NORM_CFG_255),
                        dict(type='MMPad', size_divisor=32),
                        dict(type='ImageToTensor', keys=['img']),
                        dict(
                            type='Collect', keys=['img'],
                            meta_keys=meta_keys),
                    ])
            ]
        else:
            pipeline = [
                dict(
                    type='MMResize',
                    img_scale=[(1333, 640), (1333, 672), (1333, 704),
                               (1333, 736), (1333, 768), (1333, 800)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(type='MMRandomFlip', flip_ratio=0.5),
                dict(type='MMNormalize', **IMG_NORM_CFG_255),
                dict(type='MMPad', size_divisor=32),
                dict(type='DefaultFormatBundle'),
                dict(
                    type='Collect',
                    keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
                    meta_keys=meta_keys)
            ]

        data_root = DET_DATA_SMALL_COCO_LOCAL
        data_source = dict(
            type='DetSourceCoco',
            ann_file=os.path.join(data_root, 'instances_train2017_20.json'),
            img_prefix=os.path.join(data_root, 'train2017'),
            pipeline=[
                dict(type='LoadImageFromFile', to_float32=True),
                dict(type='LoadAnnotations', with_bbox=True, with_mask=True)
            ],
            classes=COCO_CLASSES,
            filter_empty_gt=False,
            iscrowd=False)
        return build_dataset(
            dict(type='DetDataset', data_source=data_source,
                 pipeline=pipeline))

    def xxtest_model_train(self):
        """Training forward pass; the ``xx`` prefix keeps unittest from collecting it."""
        model = self._get_model().cuda()
        model.train()
        loader = build_dataloader(
            self._get_dataset(),
            imgs_per_gpu=3,
            workers_per_gpu=1,
            num_gpus=1,
            dist=False)

        for data_batch in loader:
            _, scattered = scatter_kwargs(None, data_batch, [-1])
            batch = scattered[0]
            # move the tensor inputs to the GPU, element-wise for sequences
            for key in ['img', 'gt_bboxes', 'gt_labels']:
                value = batch[key]
                if isinstance(value, (list, tuple)):
                    batch[key] = [item.cuda() for item in value]
                else:
                    batch[key] = value.cuda()

            output = model(**batch, mode='train')
            self.assertEqual(len(output['loss_rpn_cls']), 5)
            self.assertEqual(len(output['loss_rpn_bbox']), 5)
            self.assertEqual(output['loss_cls'].shape, torch.Size([]))
            self.assertEqual(output['acc'].shape, torch.Size([1]))
            self.assertEqual(output['loss_bbox'].shape, torch.Size([]))
            self.assertEqual(output['loss_mask'].shape, torch.Size([1]))

    def test_model_test(self):
        """Inference over the dataset must yield 20 entries for every result key."""
        model = self._get_model().cuda()
        loader = build_dataloader(
            self._get_dataset(mode='test'),
            imgs_per_gpu=1,
            workers_per_gpu=1,
            num_gpus=1,
            dist=False)
        results = single_gpu_test(model, loader, mode='test')
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks', 'img_metas'):
            self.assertEqual(len(results[key]), 20)
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()

View File

@ -28,6 +28,7 @@ from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
mmcv_config_fromfile, rebuild_config)
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
# from tools.fuse_conv_bn import fuse_module
@ -143,6 +144,9 @@ def main():
if cfg.get('oss_io_config', None) is not None:
io.access_oss(**cfg.oss_io_config)
# dynamic adapt mmdet models
dynamic_adapt_for_mmlab(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

View File

@ -1,145 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import importlib
import os
import os.path as osp
import sys
import time
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from easycv.core.evaluation.builder import build_evaluator
from easycv.datasets import build_dataloader, build_dataset
from easycv.models import build_model
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
# from mmcv import Config
from easycv.utils.config_tools import mmcv_config_fromfile, traverse_replace
from easycv.utils.logger import get_root_logger
# Append the parent of this file's directory, and that directory's own parent,
# to the import search path.
# NOTE(review): these statements run after the `easycv` imports above, so they
# cannot influence how those imports are resolved -- confirm this is intended.
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(
    os.path.abspath(
        osp.join(os.path.dirname(os.path.dirname(__file__)), '../')))
def single_gpu_test(model, data_loader):
    """Collect test-mode outputs of ``model`` over ``data_loader``.

    The model is switched to eval mode, then every batch is forwarded with
    ``mode='test'`` through ``nondist_forward_collect``.

    Returns:
        Whatever ``nondist_forward_collect`` returns.
    """
    model.eval()

    def func(**x):
        return model(mode='test', **x)

    return nondist_forward_collect(func, data_loader,
                                   len(data_loader.dataset))
def multi_gpu_test(model, data_loader):
    """Collect test-mode outputs of ``model`` in a distributed run.

    The model is switched to eval mode, then every batch is forwarded with
    ``mode='test'`` through ``dist_forward_collect`` using the current rank.

    Returns:
        Whatever ``dist_forward_collect`` returns.
    """
    model.eval()

    def func(**x):
        return model(mode='test', **x)

    # only the rank is needed; the world size is read but not used
    rank, _world_size = get_dist_info()
    return dist_forward_collect(func, data_loader, rank,
                                len(data_loader.dataset))
def parse_args():
    """Parse the command line of the test script.

    Positional arguments are the config path and the checkpoint path. As a
    side effect the environment variable ``LOCAL_RANK`` is set from
    ``--local_rank`` when it is not already defined.

    Returns:
        The ``argparse.Namespace`` with the parsed values.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')

    positionals = (
        ('config', 'test config file path'),
        ('checkpoint', 'checkpoint file'),
    )
    for arg_name, help_text in positionals:
        parser.add_argument(arg_name, help=help_text)

    options = (
        ('--work_dir',
         dict(type=str, default=None, help='the dir to save logs and models')),
        ('--launcher',
         dict(
             choices=['none', 'pytorch', 'slurm', 'mpi'],
             default='none',
             help='job launcher')),
        ('--local_rank', dict(type=int, default=0)),
        ('--port',
         dict(
             type=int,
             default=29500,
             help='port only works when launcher=="slurm"')),
        ('--model_type',
         dict(
             choices=['classification', 'pose'],
             default='classification',
             help='model type')),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    # keep an already exported LOCAL_RANK, otherwise publish the CLI value
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args
def main():
    """Evaluate a checkpoint on ``cfg.data.val``.

    Steps: parse the CLI, load the config, optionally initialise the
    distributed environment, build dataloader and model, load the
    checkpoint, run inference and, on rank 0, evaluate the outputs.
    """
    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.model.pretrained = None  # ensure to use checkpoint rather than pretraining

    # check memcached package exists
    if importlib.util.find_spec('mc') is None:
        traverse_replace(cfg, 'memcached', False)

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(
        log_file=osp.join(cfg.work_dir, 'test_{}.log'.format(timestamp)),
        log_level=cfg.log_level)

    # build the dataloader
    dataset = build_dataset(cfg.data.val)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=cfg.data.imgs_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_model(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if distributed:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader)  # dict{key: np.ndarray}
    else:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)

    # only rank 0 evaluates
    rank, _ = get_dist_info()
    if rank != 0:
        return

    if args.model_type == 'pose':
        evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
        dataset.evaluate(outputs, evaluators)
    else:
        for name, val in outputs.items():
            dataset.evaluate(
                torch.from_numpy(val), name, logger, topk=(1, 5))
# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()

View File

@ -29,6 +29,7 @@ from easycv.models import build_model
from easycv.utils.collect_env import collect_env
from easycv.utils.flops_counter import get_model_info
from easycv.utils.logger import get_root_logger
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils.config_tools import traverse_replace
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
mmcv_config_fromfile, rebuild_config)
@ -149,6 +150,9 @@ def main():
if args.load_from is not None:
cfg.load_from = args.load_from
# dynamic adapt mmdet models
dynamic_adapt_for_mmlab(cfg)
cfg.gpus = args.gpus
# check memcached package exists