[MMDet] DetWrapper

pull/1178/head
jiangqing.vendor 2022-07-13 11:25:37 +00:00 committed by gaotongxiao
parent dae4c9ca8c
commit ae4ba012a8
4 changed files with 416 additions and 1 deletions
.dev_scripts
mmocr/models/textdet/detectors
tests/test_models/test_textdet/test_wrappers.py

View File

@ -41,6 +41,8 @@ mmocr/datasets/utils/loader.py
# It will be removed after TTA refactor
mmocr/datasets/pipelines/test_time_aug.py
# Major part is covered; however, it's hard to cover the model's output.
mmocr/models/textdet/detectors/mmdet_wrapper.py
# Covering it with tests seems like an impossible mission
mmocr/models/textdet/postprocessors/drrg_postprocessor.py

View File

@ -2,6 +2,7 @@
from .dbnet import DBNet
from .drrg import DRRG
from .fcenet import FCENet
from .mmdet_wrapper import MMDetWrapper
from .panet import PANet
from .psenet import PSENet
from .single_stage_text_detector import SingleStageTextDetector
@ -9,5 +10,5 @@ from .textsnake import TextSnake
__all__ = [
'SingleStageTextDetector', 'DBNet', 'PANet', 'PSENet', 'TextSnake',
'FCENet', 'DRRG'
'FCENet', 'DRRG', 'MMDetWrapper'
]

View File

@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
import cv2
import torch
from mmdet.core import DetDataSample
from mmdet.core.mask.structures import bitmap_to_polygon
from mmdet.core.utils import ForwardResults, OptSampleList
from mmengine import InstanceData
from mmengine.model import BaseModel
from mmocr.data import TextDetDataSample
from mmocr.registry import MODELS
from mmocr.utils.bbox_utils import bbox2poly
@MODELS.register_module()
class MMDetWrapper(BaseModel):
    """A wrapper of MMDet's model.

    The wrapped detector is built under the ``mmdet`` scope and its
    predictions are converted into MMOCR's text detection format.

    Args:
        cfg (dict): The config of the model. It must contain a
            ``data_preprocessor`` entry. The config is copied before use, so
            the object passed in is left untouched.
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.

    Raises:
        KeyError: If ``cfg`` has no ``data_preprocessor`` entry.
    """

    def __init__(self, cfg: Dict, text_repr_type: str = 'poly') -> None:
        import copy

        # Work on a private copy: the keys popped/added below must not leak
        # into the caller's config, otherwise building a second wrapper from
        # the same config would fail on the missing ``data_preprocessor``.
        cfg = copy.deepcopy(cfg)
        data_preprocessor = cfg.pop('data_preprocessor')
        data_preprocessor.update(_scope_='mmdet')
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
        cfg['_scope_'] = 'mmdet'
        self.wrapped_model = MODELS.build(cfg)
        self.text_repr_type = text_repr_type

    def forward(self,
                batch_inputs: torch.Tensor,
                batch_data_samples: OptSampleList = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`TextDetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            batch_inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                annotation data of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`TextDetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        results = self.wrapped_model.forward(batch_inputs, batch_data_samples,
                                             mode, **kwargs)
        # Only predictions need converting; losses and raw tensors are
        # returned exactly as the wrapped model produced them.
        if mode == 'predict':
            results = self.adapt_predictions(results)
        return results

    def adapt_predictions(self, data: List[DetDataSample]
                          ) -> List[TextDetDataSample]:
        """Convert Instance datas from MMDet into MMOCR's format.

        Args:
            data: (list[DetDataSample]): Detection results of the
                input images. Each DetDataSample usually contain
                'pred_instances'. And the ``pred_instances`` usually
                contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor, Optional): Has a shape
                  (num_instances, H, W).

        Returns:
            list[TextDetDataSample]: A list of N datasamples of prediction
            results.
            The polygon results are saved in
            ``TextDetDataSample.pred_instances.polygons``
            The confidence scores are saved in
            ``TextDetDataSample.pred_instances.scores``.
        """
        results = []
        for data_sample in data:
            pred_instances = data_sample.pred_instances
            result = TextDetDataSample()
            result.pred_instances = InstanceData()
            # convert mask to polygons if mask exists
            if 'masks' in pred_instances.keys():
                masks = pred_instances.masks.cpu().numpy()
                mask_scores = pred_instances.scores.cpu()
                polygons = []
                scores = []
                for mask_idx, mask in enumerate(masks):
                    contours, _ = bitmap_to_polygon(mask)
                    # Every contour of a mask inherits that mask's score.
                    score = float(mask_scores[mask_idx])
                    for contour in contours:
                        points = contour.reshape(-1, 2)
                        # filter invalid polygons: fewer than 3 vertices
                        # (i.e. fewer than 6 coordinates) is not a polygon
                        if len(points) < 3:
                            continue
                        # convert by text_repr_type
                        if self.text_repr_type == 'quad':
                            # minAreaRect needs an (N, 2) point set, not
                            # the flattened coordinate vector.
                            rect = cv2.minAreaRect(points)
                            polygon = cv2.boxPoints(rect).flatten()
                        else:
                            polygon = points.reshape(-1)
                        polygons.append(polygon)
                        scores.append(score)
                result.pred_instances.polygons = polygons
                result.pred_instances.scores = torch.FloatTensor(scores)
            else:
                # No masks: fall back to the boxes, one polygon per bbox.
                bboxes = pred_instances.bboxes.cpu().numpy()
                result.pred_instances.polygons = [
                    bbox2poly(bbox) for bbox in bboxes
                ]
                # ``.float()`` also accepts non-float32 scores.
                result.pred_instances.scores = pred_instances.scores.cpu(
                ).float()
            results.append(result)
        return results

View File

@ -0,0 +1,270 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmdet.core import DetDataSample
from mmdet.testing import demo_mm_inputs
from mmengine.config import Config
from mmengine.data import InstanceData
from mmocr.data import TextDetDataSample
from mmocr.registry import MODELS
class TestMMDetWrapper(unittest.TestCase):
    """Tests for ``MMDetWrapper`` with a one-stage (FCOS) and a two-stage
    mask-based (Mask R-CNN) MMDet detector."""

    def setUp(self):
        """Build one wrapper per detector type from inline configs."""
        model_cfg_fcos = dict(
            type='MMDetWrapper',
            cfg=dict(
                type='FCOS',
                data_preprocessor=dict(
                    type='DetDataPreprocessor',
                    mean=[102.9801, 115.9465, 122.7717],
                    std=[1.0, 1.0, 1.0],
                    bgr_to_rgb=False,
                    pad_size_divisor=32),
                backbone=dict(
                    type='ResNet',
                    depth=50,
                    num_stages=4,
                    out_indices=(0, 1, 2, 3),
                    frozen_stages=1,
                    norm_cfg=dict(type='BN', requires_grad=False),
                    norm_eval=True,
                    style='caffe',
                    init_cfg=dict(
                        type='Pretrained',
                        checkpoint='open-mmlab://detectron/resnet50_caffe')),
                neck=dict(
                    type='FPN',
                    in_channels=[256, 512, 1024, 2048],
                    out_channels=256,
                    start_level=1,
                    add_extra_convs='on_output',  # use P5
                    num_outs=5,
                    relu_before_extra_convs=True),
                bbox_head=dict(
                    type='FCOSHead',
                    num_classes=2,
                    in_channels=256,
                    stacked_convs=4,
                    feat_channels=256,
                    strides=[8, 16, 32, 64, 128],
                    loss_cls=dict(
                        type='FocalLoss',
                        use_sigmoid=True,
                        gamma=2.0,
                        alpha=0.25,
                        loss_weight=1.0),
                    loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                    loss_centerness=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=True,
                        loss_weight=1.0)),
                # testing settings
                test_cfg=dict(
                    nms_pre=1000,
                    min_bbox_size=0,
                    score_thr=0.05,
                    nms=dict(type='nms', iou_threshold=0.5),
                    max_per_img=100)))
        model_cfg_maskrcnn = dict(
            type='MMDetWrapper',
            text_repr_type='quad',
            cfg=dict(
                type='MaskRCNN',
                data_preprocessor=dict(
                    type='DetDataPreprocessor',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    bgr_to_rgb=True,
                    pad_size_divisor=32),
                backbone=dict(
                    type='ResNet',
                    depth=50,
                    num_stages=4,
                    out_indices=(0, 1, 2, 3),
                    frozen_stages=1,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    norm_eval=True,
                    style='pytorch',
                    init_cfg=dict(
                        type='Pretrained',
                        checkpoint='torchvision://resnet50')),
                neck=dict(
                    type='FPN',
                    in_channels=[256, 512, 1024, 2048],
                    out_channels=256,
                    num_outs=5),
                rpn_head=dict(
                    type='RPNHead',
                    in_channels=256,
                    feat_channels=256,
                    anchor_generator=dict(
                        type='AnchorGenerator',
                        scales=[8],
                        ratios=[0.5, 1.0, 2.0],
                        strides=[4, 8, 16, 32, 64]),
                    bbox_coder=dict(
                        type='DeltaXYWHBBoxCoder',
                        target_means=[.0, .0, .0, .0],
                        target_stds=[1.0, 1.0, 1.0, 1.0]),
                    loss_cls=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=True,
                        loss_weight=1.0),
                    loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
                roi_head=dict(
                    type='StandardRoIHead',
                    bbox_roi_extractor=dict(
                        type='SingleRoIExtractor',
                        roi_layer=dict(
                            type='RoIAlign', output_size=7, sampling_ratio=0),
                        out_channels=256,
                        featmap_strides=[4, 8, 16, 32]),
                    bbox_head=dict(
                        type='Shared2FCBBoxHead',
                        in_channels=256,
                        fc_out_channels=1024,
                        roi_feat_size=7,
                        num_classes=80,
                        bbox_coder=dict(
                            type='DeltaXYWHBBoxCoder',
                            target_means=[0., 0., 0., 0.],
                            target_stds=[0.1, 0.1, 0.2, 0.2]),
                        reg_class_agnostic=False,
                        loss_cls=dict(
                            type='CrossEntropyLoss',
                            use_sigmoid=False,
                            loss_weight=1.0),
                        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
                    mask_roi_extractor=dict(
                        type='SingleRoIExtractor',
                        roi_layer=dict(
                            type='RoIAlign', output_size=14, sampling_ratio=0),
                        out_channels=256,
                        featmap_strides=[4, 8, 16, 32]),
                    mask_head=dict(
                        type='FCNMaskHead',
                        num_convs=4,
                        in_channels=256,
                        conv_out_channels=256,
                        num_classes=80,
                        loss_mask=dict(
                            type='CrossEntropyLoss',
                            use_mask=True,
                            loss_weight=1.0))),
                # model training and testing settings
                train_cfg=dict(
                    rpn=dict(
                        assigner=dict(
                            type='MaxIoUAssigner',
                            pos_iou_thr=0.7,
                            neg_iou_thr=0.3,
                            min_pos_iou=0.3,
                            match_low_quality=True,
                            ignore_iof_thr=-1),
                        sampler=dict(
                            type='RandomSampler',
                            num=256,
                            pos_fraction=0.5,
                            neg_pos_ub=-1,
                            add_gt_as_proposals=False),
                        allowed_border=-1,
                        pos_weight=-1,
                        debug=False),
                    rpn_proposal=dict(
                        nms_pre=2000,
                        max_per_img=1000,
                        nms=dict(type='nms', iou_threshold=0.7),
                        min_bbox_size=0),
                    rcnn=dict(
                        assigner=dict(
                            type='MaxIoUAssigner',
                            pos_iou_thr=0.5,
                            neg_iou_thr=0.5,
                            min_pos_iou=0.5,
                            match_low_quality=True,
                            ignore_iof_thr=-1),
                        sampler=dict(
                            type='RandomSampler',
                            num=512,
                            pos_fraction=0.25,
                            neg_pos_ub=-1,
                            add_gt_as_proposals=True),
                        mask_size=28,
                        pos_weight=-1,
                        debug=False)),
                test_cfg=dict(
                    rpn=dict(
                        nms_pre=1000,
                        max_per_img=1000,
                        nms=dict(type='nms', iou_threshold=0.7),
                        min_bbox_size=0),
                    rcnn=dict(
                        score_thr=0.05,
                        nms=dict(type='nms', iou_threshold=0.5),
                        max_per_img=100,
                        mask_thr_binary=0.5))))
        self.FCOS = MODELS.build(Config(model_cfg_fcos))
        self.MRCNN = MODELS.build(Config(model_cfg_maskrcnn))

    def test_one_stage_wrapper(self):
        """Loss and predict modes work for the wrapped FCOS detector."""
        packed_inputs = demo_mm_inputs(
            2, [[3, 128, 128], [3, 128, 128]], num_classes=2)
        # Test forward train
        bi, ds = self.FCOS.data_preprocessor(packed_inputs, True)
        losses = self.FCOS.forward(bi, ds, mode='loss')
        self.assertIsInstance(losses, dict)
        # Test forward test
        self.FCOS.eval()
        with torch.no_grad():
            batch_results = self.FCOS.forward(bi, ds, mode='predict')
            self.assertEqual(len(batch_results), 2)
            self.assertIsInstance(batch_results[0], TextDetDataSample)

    def test_mask_two_stage_wrapper(self):
        """Loss and predict modes work for the wrapped Mask R-CNN."""
        packed_inputs = demo_mm_inputs(
            2, [[3, 128, 128], [3, 128, 128]], num_classes=2, with_mask=True)
        # Test forward train
        bi, ds = self.MRCNN.data_preprocessor(packed_inputs, True)
        losses = self.MRCNN.forward(bi, ds, mode='loss')
        self.assertIsInstance(losses, dict)
        # Test forward test
        self.MRCNN.eval()
        with torch.no_grad():
            batch_results = self.MRCNN.forward(bi, ds, mode='predict')
            self.assertEqual(len(batch_results), 2)
            self.assertIsInstance(batch_results[0], TextDetDataSample)

    def test_adapt_predictions(self):
        """``adapt_predictions`` yields polygons with and without masks."""
        # Mask branch (wrapper configured with text_repr_type='quad').
        data_sample = DetDataSample()
        pred_instances = InstanceData()
        pred_instances.scores = torch.randn(1)
        pred_instances.labels = torch.Tensor([1])
        pred_instances.bboxes = torch.Tensor([[0, 0, 2, 2]])
        # NOTE(review): torch.rand draws values in [0, 1); if the mask is
        # truncated to an integer bitmap downstream, no contour may ever be
        # found here -- confirm this test really exercises polygon
        # extraction.
        pred_instances.masks = torch.rand(1, 10, 10)
        data_sample.pred_instances = pred_instances
        results = self.MRCNN.adapt_predictions([data_sample])
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], TextDetDataSample)
        self.assertIn('polygons', results[0].pred_instances.keys())

        # Bbox-only branch: no ``masks`` key in the predictions.
        data_sample = DetDataSample()
        pred_instances = InstanceData()
        pred_instances.scores = torch.randn(1)
        pred_instances.labels = torch.Tensor([1])
        pred_instances.bboxes = torch.Tensor([[0, 0, 2, 2]])
        data_sample.pred_instances = pred_instances
        results = self.FCOS.adapt_predictions([data_sample])
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], TextDetDataSample)
        self.assertIn('polygons', results[0].pred_instances.keys())
if __name__ == '__main__':
    # Run every test in this module through the standard unittest runner
    # (proper setUp/tearDown handling and reporting) instead of invoking a
    # single test method by hand.
    unittest.main()