Refactor SingleStageTextDetector and update model classes

This commit is contained in:
gaotongxiao 2022-05-16 03:28:10 +00:00
parent 26da038d49
commit 2c23098b29
4 changed files with 103 additions and 38 deletions

View File

@ -11,3 +11,10 @@ mmocr/models/textdet/dense_heads/head_mixin.py
# Will be covered by det head tests
mmocr/models/textdet/dense_heads/base_textdet_head.py
# They will be removed later all det models have been refactored
mmocr/models/common/detectors/single_stage.py
mmocr/models/textdet/detectors/text_detector_mixin.py
# It will be covered by tests of any det model implemented in future
mmocr/models/textdet/detectors/single_stage_text_detector.py

View File

@ -15,6 +15,7 @@ class SingleStageDetector(MMDET_SingleStageDetector):
output features of the backbone+neck.
"""
# TODO: Remove this class as SDGMR has been refactored
def __init__(self,
backbone,
neck=None,

View File

@ -1,61 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from typing import Dict, Optional, Sequence
from mmocr.models.common.detectors import SingleStageDetector
import torch
from mmcv.runner import auto_fp16
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
from mmocr.core.data_structures import TextDetDataSample
from mmocr.registry import MODELS
@MODELS.register_module()
class SingleStageTextDetector(SingleStageDetector):
"""The class for implementing single stage text detector."""
class SingleStageTextDetector(MMDET_BaseDetector):
"""The class for implementing single stage text detector.
Single-stage text detectors directly and densely predict bounding boxes or
polygons on the output features of the backbone + neck (optional).
Args:
backbone (dict): Backbone config.
neck (dict, optional): Neck config. If None, the output from backbone
will be directly fed into ``det_head``.
det_head (dict): Head config.
preprocess_cfg (dict, optional): Model preprocessing config
for processing the input image data. Keys allowed are
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
float), ``mean``(int or float) and ``std``(int or float).
Preprcessing order: 1. to rgb; 2. normalization 3. pad.
Defaults to None.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
SingleStageDetector.__init__(self, backbone, neck, bbox_head,
train_cfg, test_cfg, pretrained, init_cfg)
backbone: Dict,
det_head: Dict,
neck: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert det_head is not None, 'det_head cannot be None!'
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
self.det_head = MODELS.build(det_head)
def forward_train(self, img, img_metas, **kwargs):
def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self, img: torch.Tensor,
data_samples: Sequence[TextDetDataSample]) -> Dict:
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
img (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img)
preds = self.bbox_head(x)
losses = self.bbox_head.loss(preds, **kwargs)
preds = self.det_head(x, data_samples)
losses = self.det_head.loss(preds, data_samples)
return losses
def simple_test(self, img, img_metas, rescale=False):
def simple_test(self, img: torch.Tensor,
data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Test function without test-time augmentation.
Args:
img (torch.Tensor): Images of shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
Returns:
list[TextDetDataSample]: A list of N datasamples of prediction
results. Results are stored in ``pred_instances``.
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
preds = self.det_head(x, data_samples)
return self.det_head.postprocessor(preds, data_samples)
# early return to avoid post processing
if torch.onnx.is_in_onnx_export():
return outs
def aug_test(
self, imgs: Sequence[torch.Tensor],
data_samples: Sequence[Sequence[TextDetDataSample]]
) -> Sequence[Sequence[TextDetDataSample]]:
"""Test function with test time augmentation."""
raise NotImplementedError
if len(img_metas) > 1:
boundaries = [
self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)),
[img_metas[i]], rescale)
for i in range(len(img_metas))
]
@auto_fp16(apply_to=('imgs', ))
def forward_simple_test(self, imgs: torch.Tensor,
data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Test forward function called by self.forward() when running in test
mode without test time augmentation.
else:
boundaries = [
self.bbox_head.get_boundary(*outs, img_metas, rescale)
]
Though not useful in MMOCR, it has been kept to maintain the maximum
compatibility with MMDetection's BaseDetector.
return boundaries
Args:
img (torch.Tensor): Images of shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
Returns:
list[TextDetDataSample]: A list of N datasamples of prediction
results. Results are stored in ``pred_instances``.
"""
return self.simple_test(imgs, data_samples)

View File

@ -6,6 +6,7 @@ import mmcv
from mmocr.core import imshow_pred_boundary
# TODO: delete this
class TextDetectorMixin:
"""Base class for text detector, only to show results.