diff --git a/.dev_scripts/covignore.cfg b/.dev_scripts/covignore.cfg
index 4f516a7c..94aa8c42 100644
--- a/.dev_scripts/covignore.cfg
+++ b/.dev_scripts/covignore.cfg
@@ -11,3 +11,10 @@ mmocr/models/textdet/dense_heads/head_mixin.py
 
 # Will be covered by det head tests
 mmocr/models/textdet/dense_heads/base_textdet_head.py
+
+# They will be removed once all det models have been refactored
+mmocr/models/common/detectors/single_stage.py
+mmocr/models/textdet/detectors/text_detector_mixin.py
+
+# It will be covered by tests of any det model implemented in the future
+mmocr/models/textdet/detectors/single_stage_text_detector.py
diff --git a/mmocr/models/common/detectors/single_stage.py b/mmocr/models/common/detectors/single_stage.py
index b5336523..c315386f 100644
--- a/mmocr/models/common/detectors/single_stage.py
+++ b/mmocr/models/common/detectors/single_stage.py
@@ -15,6 +15,7 @@ class SingleStageDetector(MMDET_SingleStageDetector):
         output features of the backbone+neck.
     """
 
+    # TODO: Remove this class once SDMGR has been refactored
     def __init__(self,
                  backbone,
                  neck=None,
diff --git a/mmocr/models/textdet/detectors/single_stage_text_detector.py b/mmocr/models/textdet/detectors/single_stage_text_detector.py
index 0b0e1a33..9849b52c 100644
--- a/mmocr/models/textdet/detectors/single_stage_text_detector.py
+++ b/mmocr/models/textdet/detectors/single_stage_text_detector.py
@@ -1,61 +1,117 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import Dict, Optional, Sequence
 
-from mmocr.models.common.detectors import SingleStageDetector
+import torch
+from mmcv.runner import auto_fp16
+from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
+
+from mmocr.core.data_structures import TextDetDataSample
 from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class SingleStageTextDetector(SingleStageDetector):
-    """The class for implementing single stage text detector."""
+class SingleStageTextDetector(MMDET_BaseDetector):
+    """The class for implementing single stage text detector.
+
+    Single-stage text detectors directly and densely predict bounding boxes or
+    polygons on the output features of the backbone + neck (optional).
+
+    Args:
+        backbone (dict): Backbone config.
+        neck (dict, optional): Neck config. If None, the output from backbone
+            will be directly fed into ``det_head``.
+        det_head (dict): Head config.
+        preprocess_cfg (dict, optional): Model preprocessing config
+            for processing the input image data. Keys allowed are
+            ``to_rgb`` (bool), ``pad_size_divisor`` (int), ``pad_value`` (int
+            or float), ``mean`` (int or float) and ``std`` (int or float).
+            Preprocessing order: 1. to rgb; 2. normalization; 3. pad.
+            Defaults to None.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
+    """
 
     def __init__(self,
-                 backbone,
-                 neck,
-                 bbox_head,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
-        SingleStageDetector.__init__(self, backbone, neck, bbox_head,
-                                     train_cfg, test_cfg, pretrained, init_cfg)
+                 backbone: Dict,
+                 det_head: Dict,
+                 neck: Optional[Dict] = None,
+                 preprocess_cfg: Optional[Dict] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
+        assert det_head is not None, 'det_head cannot be None!'
+        self.backbone = MODELS.build(backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+        self.det_head = MODELS.build(det_head)
 
-    def forward_train(self, img, img_metas, **kwargs):
+    def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
+        """Directly extract features from the backbone+neck."""
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def forward_train(self, img: torch.Tensor,
+                      data_samples: Sequence[TextDetDataSample]) -> Dict:
         """
         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
+            img (torch.Tensor): Input images of shape (N, C, H, W).
                 Typically these should be mean centered and std scaled.
-            img_metas (list[dict]): A list of image info dict where each dict
-                has: 'img_shape', 'scale_factor', 'flip', and may also contain
-                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
-                For details on the values of these keys, see
-                :class:`mmdet.datasets.pipelines.Collect`.
+            data_samples (list[TextDetDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.
 
         Returns:
             dict[str, Tensor]: A dictionary of loss components.
         """
         x = self.extract_feat(img)
-        preds = self.bbox_head(x)
-        losses = self.bbox_head.loss(preds, **kwargs)
+        preds = self.det_head(x, data_samples)
+        losses = self.det_head.loss(preds, data_samples)
         return losses
 
-    def simple_test(self, img, img_metas, rescale=False):
+    def simple_test(self, img: torch.Tensor,
+                    data_samples: Sequence[TextDetDataSample]
+                    ) -> Sequence[TextDetDataSample]:
+        """Test function without test-time augmentation.
+
+        Args:
+            img (torch.Tensor): Images of shape (N, C, H, W).
+            data_samples (list[TextDetDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.
+
+        Returns:
+            list[TextDetDataSample]: A list of N datasamples of prediction
+            results. Results are stored in ``pred_instances``.
+        """
         x = self.extract_feat(img)
-        outs = self.bbox_head(x)
+        preds = self.det_head(x, data_samples)
+        return self.det_head.postprocessor(preds, data_samples)
 
-        # early return to avoid post processing
-        if torch.onnx.is_in_onnx_export():
-            return outs
+    def aug_test(
+        self, imgs: Sequence[torch.Tensor],
+        data_samples: Sequence[Sequence[TextDetDataSample]]
+    ) -> Sequence[Sequence[TextDetDataSample]]:
+        """Test function with test-time augmentation."""
+        raise NotImplementedError
 
-        if len(img_metas) > 1:
-            boundaries = [
-                self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)),
-                                            [img_metas[i]], rescale)
-                for i in range(len(img_metas))
-            ]
+    @auto_fp16(apply_to=('imgs', ))
+    def forward_simple_test(self, imgs: torch.Tensor,
+                            data_samples: Sequence[TextDetDataSample]
+                            ) -> Sequence[TextDetDataSample]:
+        """Test forward function called by self.forward() when running in test
+        mode without test-time augmentation.
 
-        else:
-            boundaries = [
-                self.bbox_head.get_boundary(*outs, img_metas, rescale)
-            ]
+        Though not useful in MMOCR, it is kept to maintain maximum
+        compatibility with MMDetection's BaseDetector.
 
-        return boundaries
+        Args:
+            imgs (torch.Tensor): Images of shape (N, C, H, W).
+            data_samples (list[TextDetDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.
+
+        Returns:
+            list[TextDetDataSample]: A list of N datasamples of prediction
+            results. Results are stored in ``pred_instances``.
+        """
+        return self.simple_test(imgs, data_samples)
diff --git a/mmocr/models/textdet/detectors/text_detector_mixin.py b/mmocr/models/textdet/detectors/text_detector_mixin.py
index e779b266..b243f853 100644
--- a/mmocr/models/textdet/detectors/text_detector_mixin.py
+++ b/mmocr/models/textdet/detectors/text_detector_mixin.py
@@ -6,6 +6,7 @@ import mmcv
 from mmocr.core import imshow_pred_boundary
 
 
+# TODO: delete this
 class TextDetectorMixin:
     """Base class for text detector, only to show results.
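
Note for reviewers: the sketch below is a minimal illustration of the contract the refactored SingleStageTextDetector expects from its sub-modules, namely a head exposing forward(feat, data_samples), loss(preds, data_samples) and a postprocessor callable. The DummyBackbone/DummyDetHead modules, their registration and the import path mmocr.models.textdet.detectors are assumptions made for this sketch only; no refactored det head ships with this PR.

# Illustrative sketch only: DummyBackbone / DummyDetHead are hypothetical
# stand-ins used to demonstrate the interface; they are not part of MMOCR.
from typing import Dict, Sequence

import torch
import torch.nn as nn

from mmocr.core.data_structures import TextDetDataSample
from mmocr.models.textdet.detectors import SingleStageTextDetector
from mmocr.registry import MODELS


@MODELS.register_module()
class DummyBackbone(nn.Module):
    """Toy backbone producing a single feature map."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.conv(img)


@MODELS.register_module()
class DummyDetHead(nn.Module):
    """Toy head exposing the three entry points the detector calls."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 1, 1)
        # A real head would build a proper postprocessor module; this one
        # simply hands the data samples back unchanged.
        self.postprocessor = lambda preds, data_samples: data_samples

    def forward(self, feat: torch.Tensor,
                data_samples: Sequence[TextDetDataSample]) -> torch.Tensor:
        return self.conv(feat)

    def loss(self, preds: torch.Tensor,
             data_samples: Sequence[TextDetDataSample]) -> Dict:
        return dict(loss_dummy=preds.mean())


detector = SingleStageTextDetector(
    backbone=dict(type='DummyBackbone'), det_head=dict(type='DummyDetHead'))
imgs = torch.rand(2, 3, 64, 64)                        # (N, C, H, W) batch
data_samples = [TextDetDataSample() for _ in range(2)]
losses = detector.forward_train(imgs, data_samples)    # dict of loss tensors
results = detector.simple_test(imgs, data_samples)     # list of data samples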