mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
Refactor SingleStageTextDetector and update model classes
This commit is contained in:
parent
26da038d49
commit
2c23098b29
@ -11,3 +11,10 @@ mmocr/models/textdet/dense_heads/head_mixin.py
|
|||||||
|
|
||||||
# Will be covered by det head tests
|
# Will be covered by det head tests
|
||||||
mmocr/models/textdet/dense_heads/base_textdet_head.py
|
mmocr/models/textdet/dense_heads/base_textdet_head.py
|
||||||
|
|
||||||
|
# They will be removed later, once all det models have been refactored
|
||||||
|
mmocr/models/common/detectors/single_stage.py
|
||||||
|
mmocr/models/textdet/detectors/text_detector_mixin.py
|
||||||
|
|
||||||
|
# It will be covered by tests of any det model implemented in the future
|
||||||
|
mmocr/models/textdet/detectors/single_stage_text_detector.py
|
||||||
|
@ -15,6 +15,7 @@ class SingleStageDetector(MMDET_SingleStageDetector):
|
|||||||
output features of the backbone+neck.
|
output features of the backbone+neck.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO: Remove this class as SDMGR has been refactored
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone,
|
backbone,
|
||||||
neck=None,
|
neck=None,
|
||||||
|
@ -1,61 +1,117 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
from typing import Dict, Optional, Sequence
|
||||||
|
|
||||||
from mmocr.models.common.detectors import SingleStageDetector
|
import torch
|
||||||
|
from mmcv.runner import auto_fp16
|
||||||
|
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
|
||||||
|
|
||||||
|
from mmocr.core.data_structures import TextDetDataSample
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class SingleStageTextDetector(SingleStageDetector):
|
class SingleStageTextDetector(MMDET_BaseDetector):
|
||||||
"""The class for implementing single stage text detector."""
|
"""The class for implementing single stage text detector.
|
||||||
|
|
||||||
|
Single-stage text detectors directly and densely predict bounding boxes or
|
||||||
|
polygons on the output features of the backbone + neck (optional).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backbone (dict): Backbone config.
|
||||||
|
neck (dict, optional): Neck config. If None, the output from backbone
|
||||||
|
will be directly fed into ``det_head``.
|
||||||
|
det_head (dict): Head config.
|
||||||
|
preprocess_cfg (dict, optional): Model preprocessing config
|
||||||
|
for processing the input image data. Keys allowed are
|
||||||
|
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
|
||||||
|
float), ``mean``(int or float) and ``std``(int or float).
|
||||||
|
Preprocessing order: 1. to rgb; 2. normalization; 3. pad.
|
||||||
|
Defaults to None.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone,
|
backbone: Dict,
|
||||||
neck,
|
det_head: Dict,
|
||||||
bbox_head,
|
neck: Optional[Dict] = None,
|
||||||
train_cfg=None,
|
preprocess_cfg: Optional[Dict] = None,
|
||||||
test_cfg=None,
|
init_cfg: Optional[Dict] = None) -> None:
|
||||||
pretrained=None,
|
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||||
init_cfg=None):
|
assert det_head is not None, 'det_head cannot be None!'
|
||||||
SingleStageDetector.__init__(self, backbone, neck, bbox_head,
|
self.backbone = MODELS.build(backbone)
|
||||||
train_cfg, test_cfg, pretrained, init_cfg)
|
if neck is not None:
|
||||||
|
self.neck = MODELS.build(neck)
|
||||||
|
self.det_head = MODELS.build(det_head)
|
||||||
|
|
||||||
def forward_train(self, img, img_metas, **kwargs):
|
def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Directly extract features from the backbone+neck."""
|
||||||
|
x = self.backbone(img)
|
||||||
|
if self.with_neck:
|
||||||
|
x = self.neck(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward_train(self, img: torch.Tensor,
|
||||||
|
data_samples: Sequence[TextDetDataSample]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
img (Tensor): Input images of shape (N, C, H, W).
|
img (torch.Tensor): Input images of shape (N, C, H, W).
|
||||||
Typically these should be mean centered and std scaled.
|
Typically these should be mean centered and std scaled.
|
||||||
img_metas (list[dict]): A list of image info dict where each dict
|
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
containing meta information and gold annotations for each of
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
the images.
|
||||||
For details on the values of these keys, see
|
|
||||||
:class:`mmdet.datasets.pipelines.Collect`.
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Tensor]: A dictionary of loss components.
|
dict[str, Tensor]: A dictionary of loss components.
|
||||||
"""
|
"""
|
||||||
x = self.extract_feat(img)
|
x = self.extract_feat(img)
|
||||||
preds = self.bbox_head(x)
|
preds = self.det_head(x, data_samples)
|
||||||
losses = self.bbox_head.loss(preds, **kwargs)
|
losses = self.det_head.loss(preds, data_samples)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def simple_test(self, img, img_metas, rescale=False):
|
def simple_test(self, img: torch.Tensor,
|
||||||
|
data_samples: Sequence[TextDetDataSample]
|
||||||
|
) -> Sequence[TextDetDataSample]:
|
||||||
|
"""Test function without test-time augmentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (torch.Tensor): Images of shape (N, C, H, W).
|
||||||
|
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||||
|
containing meta information and gold annotations for each of
|
||||||
|
the images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||||
|
results. Results are stored in ``pred_instances``.
|
||||||
|
"""
|
||||||
x = self.extract_feat(img)
|
x = self.extract_feat(img)
|
||||||
outs = self.bbox_head(x)
|
preds = self.det_head(x, data_samples)
|
||||||
|
return self.det_head.postprocessor(preds, data_samples)
|
||||||
|
|
||||||
# early return to avoid post processing
|
def aug_test(
|
||||||
if torch.onnx.is_in_onnx_export():
|
self, imgs: Sequence[torch.Tensor],
|
||||||
return outs
|
data_samples: Sequence[Sequence[TextDetDataSample]]
|
||||||
|
) -> Sequence[Sequence[TextDetDataSample]]:
|
||||||
|
"""Test function with test time augmentation."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
if len(img_metas) > 1:
|
@auto_fp16(apply_to=('imgs', ))
|
||||||
boundaries = [
|
def forward_simple_test(self, imgs: torch.Tensor,
|
||||||
self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)),
|
data_samples: Sequence[TextDetDataSample]
|
||||||
[img_metas[i]], rescale)
|
) -> Sequence[TextDetDataSample]:
|
||||||
for i in range(len(img_metas))
|
"""Test forward function called by self.forward() when running in test
|
||||||
]
|
mode without test time augmentation.
|
||||||
|
|
||||||
else:
|
Though not useful in MMOCR, it has been kept to maintain the maximum
|
||||||
boundaries = [
|
compatibility with MMDetection's BaseDetector.
|
||||||
self.bbox_head.get_boundary(*outs, img_metas, rescale)
|
|
||||||
]
|
|
||||||
|
|
||||||
return boundaries
|
Args:
|
||||||
|
imgs (torch.Tensor): Images of shape (N, C, H, W).
|
||||||
|
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||||
|
containing meta information and gold annotations for each of
|
||||||
|
the images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||||
|
results. Results are stored in ``pred_instances``.
|
||||||
|
"""
|
||||||
|
return self.simple_test(imgs, data_samples)
|
||||||
|
@ -6,6 +6,7 @@ import mmcv
|
|||||||
from mmocr.core import imshow_pred_boundary
|
from mmocr.core import imshow_pred_boundary
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: delete this
|
||||||
class TextDetectorMixin:
|
class TextDetectorMixin:
|
||||||
"""Base class for text detector, only to show results.
|
"""Base class for text detector, only to show results.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user