mirror of https://github.com/open-mmlab/mmocr.git
[MMDet] DetWrapper
parent
dae4c9ca8c
commit
ae4ba012a8
.dev_scripts
mmocr/models/textdet/detectors
tests/test_models/test_textdet/test_wrappers.py
|
@ -41,6 +41,8 @@ mmocr/datasets/utils/loader.py
|
|||
# It will be removed after TTA refactor
|
||||
mmocr/datasets/pipelines/test_time_aug.py
|
||||
|
||||
# Major part is coverd, however, it's hard to cover model's output.
|
||||
mmocr/models/textdet/detectors/mmdet_wrapper.py
|
||||
# Cover it by tests seems like an impossible mission
|
||||
mmocr/models/textdet/postprocessors/drrg_postprocessor.py
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from .dbnet import DBNet
|
||||
from .drrg import DRRG
|
||||
from .fcenet import FCENet
|
||||
from .mmdet_wrapper import MMDetWrapper
|
||||
from .panet import PANet
|
||||
from .psenet import PSENet
|
||||
from .single_stage_text_detector import SingleStageTextDetector
|
||||
|
@ -9,5 +10,5 @@ from .textsnake import TextSnake
|
|||
|
||||
__all__ = [
|
||||
'SingleStageTextDetector', 'DBNet', 'PANet', 'PSENet', 'TextSnake',
|
||||
'FCENet', 'DRRG'
|
||||
'FCENet', 'DRRG', 'MMDetWrapper'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from mmdet.core import DetDataSample
|
||||
from mmdet.core.mask.structures import bitmap_to_polygon
|
||||
from mmdet.core.utils import ForwardResults, OptSampleList
|
||||
from mmengine import InstanceData
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils.bbox_utils import bbox2poly
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MMDetWrapper(BaseModel):
|
||||
"""A wrapper of MMDet's model.
|
||||
|
||||
Args:
|
||||
cfg (dict): The config of the model.
|
||||
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
|
||||
Defaults to 'poly'.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: Dict, text_repr_type: str = 'poly') -> None:
|
||||
data_preprocessor = cfg.pop('data_preprocessor')
|
||||
data_preprocessor.update(_scope_='mmdet')
|
||||
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
|
||||
cfg['_scope_'] = 'mmdet'
|
||||
self.wrapped_model = MODELS.build(cfg)
|
||||
self.text_repr_type = text_repr_type
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
**kwargs) -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
The method should accept three modes: "tensor", "predict" and "loss":
|
||||
|
||||
- "tensor": Forward the whole network and return tensor or tuple of
|
||||
tensor without any post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`DetDataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
Note that this method doesn't handle neither back propagation nor
|
||||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
The return type depends on ``mode``.
|
||||
|
||||
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
|
||||
- If ``mode="predict"``, return a list of :obj:`TextDetDataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
results = self.wrapped_model.forward(batch_inputs, batch_data_samples,
|
||||
mode, **kwargs)
|
||||
if mode == 'predict':
|
||||
results = self.adapt_predictions(results)
|
||||
|
||||
return results
|
||||
|
||||
def adapt_predictions(self, data: List[DetDataSample]
|
||||
) -> List[TextDetDataSample]:
|
||||
"""Convert Instance datas from MMDet into MMOCR's format.
|
||||
|
||||
Args:
|
||||
data: (list[DetDataSample]): Detection results of the
|
||||
input images. Each DetDataSample usually contain
|
||||
'pred_instances'. And the ``pred_instances`` usually
|
||||
contains following keys.
|
||||
- scores (Tensor): Classification scores, has a shape
|
||||
(num_instance, )
|
||||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
- bboxes (Tensor): Has a shape (num_instances, 4),
|
||||
the last dimension 4 arrange as (x1, y1, x2, y2).
|
||||
- masks (Tensor, Optional): Has a shape (num_instances, H, W).
|
||||
|
||||
Returns:
|
||||
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||
results.
|
||||
The polygon results are saved in
|
||||
``TextDetDataSample.pred_instances.polygons``
|
||||
The confidence scores are saved in
|
||||
``TextDetDataSample.pred_instances.scores``.
|
||||
"""
|
||||
results = []
|
||||
for data_sample in data:
|
||||
result = TextDetDataSample()
|
||||
result.pred_instances = InstanceData()
|
||||
# convert mask to polygons if mask exists
|
||||
if 'masks' in data_sample.pred_instances.keys():
|
||||
masks = data_sample.pred_instances.masks.cpu().numpy()
|
||||
polygons = []
|
||||
scores = []
|
||||
for mask_idx, mask in enumerate(masks):
|
||||
contours, _ = bitmap_to_polygon(mask)
|
||||
polygons += [contour.reshape(-1) for contour in contours]
|
||||
scores += [
|
||||
data_sample.pred_instances.scores[mask_idx].cpu()
|
||||
] * len(contours)
|
||||
# filter invalid polygons
|
||||
filterd_polygons = []
|
||||
keep_idx = []
|
||||
for poly_idx, polygon in enumerate(polygons):
|
||||
if len(polygon) < 6:
|
||||
continue
|
||||
filterd_polygons.append(polygon)
|
||||
keep_idx.append(poly_idx)
|
||||
# convert by text_repr_type
|
||||
if self.text_repr_type == 'quad':
|
||||
for i, poly in enumerate(filterd_polygons):
|
||||
rect = cv2.minAreaRect(poly)
|
||||
vertices = cv2.boxPoints(rect)
|
||||
poly = vertices.flatten()
|
||||
filterd_polygons[i] = poly
|
||||
|
||||
result.pred_instances.polygons = filterd_polygons
|
||||
result.pred_instances.scores = torch.FloatTensor(
|
||||
scores)[keep_idx]
|
||||
else:
|
||||
bboxes = data_sample.pred_instances.bboxes.cpu().numpy()
|
||||
polygons = [bbox2poly(bbox) for bbox in bboxes]
|
||||
result.pred_instances.polygons = polygons
|
||||
result.pred_instances.scores = torch.FloatTensor(
|
||||
data_sample.pred_instances.scores.cpu())
|
||||
results.append(result)
|
||||
|
||||
return results
|
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from mmdet.core import DetDataSample
|
||||
from mmdet.testing import demo_mm_inputs
|
||||
from mmengine.config import Config
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
class TestMMDetWrapper(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
model_cfg_fcos = dict(
|
||||
type='MMDetWrapper',
|
||||
cfg=dict(
|
||||
type='FCOS',
|
||||
data_preprocessor=dict(
|
||||
type='DetDataPreprocessor',
|
||||
mean=[102.9801, 115.9465, 122.7717],
|
||||
std=[1.0, 1.0, 1.0],
|
||||
bgr_to_rgb=False,
|
||||
pad_size_divisor=32),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=1,
|
||||
norm_cfg=dict(type='BN', requires_grad=False),
|
||||
norm_eval=True,
|
||||
style='caffe',
|
||||
init_cfg=dict(
|
||||
type='Pretrained',
|
||||
checkpoint='open-mmlab://detectron/resnet50_caffe')),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
start_level=1,
|
||||
add_extra_convs='on_output', # use P5
|
||||
num_outs=5,
|
||||
relu_before_extra_convs=True),
|
||||
bbox_head=dict(
|
||||
type='FCOSHead',
|
||||
num_classes=2,
|
||||
in_channels=256,
|
||||
stacked_convs=4,
|
||||
feat_channels=256,
|
||||
strides=[8, 16, 32, 64, 128],
|
||||
loss_cls=dict(
|
||||
type='FocalLoss',
|
||||
use_sigmoid=True,
|
||||
gamma=2.0,
|
||||
alpha=0.25,
|
||||
loss_weight=1.0),
|
||||
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
|
||||
loss_centerness=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
loss_weight=1.0)),
|
||||
# testing settings
|
||||
test_cfg=dict(
|
||||
nms_pre=1000,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.5),
|
||||
max_per_img=100)))
|
||||
model_cfg_maskrcnn = dict(
|
||||
type='MMDetWrapper',
|
||||
text_repr_type='quad',
|
||||
cfg=dict(
|
||||
type='MaskRCNN',
|
||||
data_preprocessor=dict(
|
||||
type='DetDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True,
|
||||
pad_size_divisor=32),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=True,
|
||||
style='pytorch',
|
||||
init_cfg=dict(
|
||||
type='Pretrained',
|
||||
checkpoint='torchvision://resnet50')),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
num_outs=5),
|
||||
rpn_head=dict(
|
||||
type='RPNHead',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
anchor_generator=dict(
|
||||
type='AnchorGenerator',
|
||||
scales=[8],
|
||||
ratios=[0.5, 1.0, 2.0],
|
||||
strides=[4, 8, 16, 32, 64]),
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[.0, .0, .0, .0],
|
||||
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
roi_head=dict(
|
||||
type='StandardRoIHead',
|
||||
bbox_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(
|
||||
type='RoIAlign', output_size=7, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
bbox_head=dict(
|
||||
type='Shared2FCBBoxHead',
|
||||
in_channels=256,
|
||||
fc_out_channels=1024,
|
||||
roi_feat_size=7,
|
||||
num_classes=80,
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[0., 0., 0., 0.],
|
||||
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
||||
reg_class_agnostic=False,
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
mask_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(
|
||||
type='RoIAlign', output_size=14, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
mask_head=dict(
|
||||
type='FCNMaskHead',
|
||||
num_convs=4,
|
||||
in_channels=256,
|
||||
conv_out_channels=256,
|
||||
num_classes=80,
|
||||
loss_mask=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_mask=True,
|
||||
loss_weight=1.0))),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(
|
||||
rpn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.7,
|
||||
neg_iou_thr=0.3,
|
||||
min_pos_iou=0.3,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=256,
|
||||
pos_fraction=0.5,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=False),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1,
|
||||
debug=False),
|
||||
rpn_proposal=dict(
|
||||
nms_pre=2000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.5,
|
||||
neg_iou_thr=0.5,
|
||||
min_pos_iou=0.5,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=512,
|
||||
pos_fraction=0.25,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=True),
|
||||
mask_size=28,
|
||||
pos_weight=-1,
|
||||
debug=False)),
|
||||
test_cfg=dict(
|
||||
rpn=dict(
|
||||
nms_pre=1000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.5),
|
||||
max_per_img=100,
|
||||
mask_thr_binary=0.5))))
|
||||
|
||||
self.FCOS = MODELS.build(Config(model_cfg_fcos))
|
||||
self.MRCNN = MODELS.build(Config(model_cfg_maskrcnn))
|
||||
|
||||
def test_one_stage_wrapper(self):
|
||||
packed_inputs = demo_mm_inputs(
|
||||
2, [[3, 128, 128], [3, 128, 128]], num_classes=2)
|
||||
# Test forward train
|
||||
bi, ds = self.FCOS.data_preprocessor(packed_inputs, True)
|
||||
losses = self.FCOS.forward(bi, ds, mode='loss')
|
||||
assert isinstance(losses, dict)
|
||||
# Test forward test
|
||||
self.FCOS.eval()
|
||||
with torch.no_grad():
|
||||
batch_results = self.FCOS.forward(bi, ds, mode='predict')
|
||||
self.assertEqual(len(batch_results), 2)
|
||||
self.assertIsInstance(batch_results[0], TextDetDataSample)
|
||||
|
||||
def test_mask_two_stage_wrapper(self):
|
||||
packed_inputs = demo_mm_inputs(
|
||||
2, [[3, 128, 128], [3, 128, 128]], num_classes=2, with_mask=True)
|
||||
# Test forward train
|
||||
bi, ds = self.MRCNN.data_preprocessor(packed_inputs, True)
|
||||
losses = self.MRCNN.forward(bi, ds, mode='loss')
|
||||
assert isinstance(losses, dict)
|
||||
# Test forward test
|
||||
self.MRCNN.eval()
|
||||
with torch.no_grad():
|
||||
batch_results = self.MRCNN.forward(bi, ds, mode='predict')
|
||||
self.assertEqual(len(batch_results), 2)
|
||||
self.assertIsInstance(batch_results[0], TextDetDataSample)
|
||||
|
||||
def test_adapt_predictions(self):
|
||||
data_sample = DetDataSample()
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.scores = torch.randn(1)
|
||||
pred_instances.labels = torch.Tensor([1])
|
||||
pred_instances.bboxes = torch.Tensor([[0, 0, 2, 2]])
|
||||
pred_instances.masks = torch.rand(1, 10, 10)
|
||||
data_sample.pred_instances = pred_instances
|
||||
results = self.MRCNN.adapt_predictions([data_sample])
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertIsInstance(results[0], TextDetDataSample)
|
||||
self.assertTrue('polygons' in results[0].pred_instances.keys())
|
||||
|
||||
data_sample = DetDataSample()
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.scores = torch.randn(1)
|
||||
pred_instances.labels = torch.Tensor([1])
|
||||
pred_instances.bboxes = torch.Tensor([[0, 0, 2, 2]])
|
||||
data_sample.pred_instances = pred_instances
|
||||
results = self.FCOS.adapt_predictions([data_sample])
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertIsInstance(results[0], TextDetDataSample)
|
||||
self.assertTrue('polygons' in results[0].pred_instances.keys())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test = TestMMDetWrapper()
|
||||
test.setUp()
|
||||
test.test_mask_two_stage_wrapper()
|
Loading…
Reference in New Issue