mirror of https://github.com/open-mmlab/mmocr.git
[Detector] refactor basedetector
parent
d2e8e79df1
commit
b955df9904
|
@ -21,6 +21,6 @@ model = dict(
|
|||
det_head=dict(
|
||||
type='DBHead',
|
||||
in_channels=256,
|
||||
loss=dict(type='DBLoss'),
|
||||
loss_module=dict(type='DBLoss'),
|
||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
||||
preprocess_cfg=preprocess_cfg)
|
||||
|
|
|
@ -23,6 +23,6 @@ model = dict(
|
|||
det_head=dict(
|
||||
type='DBHead',
|
||||
in_channels=256,
|
||||
loss=dict(type='DBLoss'),
|
||||
loss_module=dict(type='DBLoss'),
|
||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
|
||||
preprocess_cfg=preprocess_cfg)
|
||||
|
|
|
@ -20,7 +20,7 @@ model = dict(
|
|||
bbox_head=dict(
|
||||
type='DBHead',
|
||||
in_channels=256,
|
||||
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
|
||||
loss_module=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
|
||||
postprocessor=dict(
|
||||
type='DBPostprocessor', text_repr_type='quad',
|
||||
epsilon_ratio=0.002)),
|
||||
|
|
|
@ -17,5 +17,5 @@ model = dict(
|
|||
in_channels=32,
|
||||
text_region_thr=0.3,
|
||||
center_region_thr=0.4,
|
||||
loss=dict(type='DRRGLoss'),
|
||||
loss_module=dict(type='DRRGLoss'),
|
||||
postprocessor=dict(type='DRRGPostprocessor', link_thr=0.80)))
|
||||
|
|
|
@ -23,7 +23,7 @@ model = dict(
|
|||
in_channels=256,
|
||||
scales=(8, 16, 32),
|
||||
fourier_degree=5,
|
||||
loss=dict(type='FCELoss', num_sample=50),
|
||||
loss_module=dict(type='FCELoss', num_sample=50),
|
||||
postprocessor=dict(
|
||||
type='FCEPostprocessor',
|
||||
text_repr_type='quad',
|
||||
|
|
|
@ -25,7 +25,7 @@ model = dict(
|
|||
in_channels=256,
|
||||
scales=(8, 16, 32),
|
||||
fourier_degree=5,
|
||||
loss=dict(type='FCELoss', num_sample=50),
|
||||
loss_module=dict(type='FCELoss', num_sample=50),
|
||||
postprocessor=dict(
|
||||
type='FCEPostprocessor',
|
||||
text_repr_type='poly',
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
preprocess_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True,
|
||||
pad_size_divisor=32)
|
||||
|
||||
# BasicBlock has a little difference from official PANet
|
||||
# BasicBlock in mmdet lacks RELU in the last convolution.
|
||||
model = dict(
|
||||
type='PANet',
|
||||
data_preprocessor=dict(
|
||||
type='TextDetDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True,
|
||||
pad_size_divisor=32),
|
||||
backbone=dict(
|
||||
type='mmdet.ResNet',
|
||||
depth=18,
|
||||
|
@ -25,10 +25,9 @@ model = dict(
|
|||
in_channels=[128, 128, 128, 128],
|
||||
hidden_dim=128,
|
||||
out_channel=6,
|
||||
loss=dict(
|
||||
loss_module=dict(
|
||||
type='PANLoss',
|
||||
loss_text=dict(type='MaskedSquareDiceLoss'),
|
||||
loss_kernel=dict(type='MaskedSquareDiceLoss'),
|
||||
),
|
||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')),
|
||||
preprocess_cfg=preprocess_cfg)
|
||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')))
|
||||
|
|
|
@ -15,7 +15,7 @@ model = dict(
|
|||
type='PANHead',
|
||||
in_channels=[128, 128, 128, 128],
|
||||
out_channels=6,
|
||||
loss=dict(type='PANLoss', speedup_bbox_thr=32),
|
||||
loss_module=dict(type='PANLoss', speedup_bbox_thr=32),
|
||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')),
|
||||
train_cfg=None,
|
||||
test_cfg=None)
|
||||
|
|
|
@ -26,7 +26,7 @@ model_poly = dict(
|
|||
in_channels=[256],
|
||||
hidden_dim=256,
|
||||
out_channel=7,
|
||||
loss=dict(type='PSELoss'),
|
||||
loss_module=dict(type='PSELoss'),
|
||||
postprocessor=dict(type='PSEPostprocessor', text_repr_type='poly')),
|
||||
preprocess_cfg=preprocess_cfg)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ model = dict(
|
|||
det_head=dict(
|
||||
type='TextSnakeHead',
|
||||
in_channels=32,
|
||||
loss=dict(type='TextSnakeLoss'),
|
||||
loss_module=dict(type='TextSnakeLoss'),
|
||||
postprocessor=dict(
|
||||
type='TextSnakePostprocessor', text_repr_type='poly')),
|
||||
preprocess_cfg=preprocess_cfg)
|
||||
|
|
|
@ -12,7 +12,7 @@ model = dict(
|
|||
type='CRNNDecoder',
|
||||
in_channels=512,
|
||||
rnn_flag=True,
|
||||
loss=dict(type='CTCLoss', letter_case='lower'),
|
||||
loss_module=dict(type='CTCLoss', letter_case='lower'),
|
||||
postprocessor=dict(type='CTCPostProcessor')),
|
||||
dictionary=dictionary,
|
||||
data_preprocessor=dict(
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import detectors, heads, losses, necks, postprocessors
|
||||
from . import (data_preprocessors, detectors, heads, losses, necks,
|
||||
postprocessors)
|
||||
from .data_preprocessors import * # NOQA
|
||||
from .detectors import * # NOQA
|
||||
from .heads import * # NOQA
|
||||
from .losses import * # NOQA
|
||||
|
@ -8,4 +10,4 @@ from .postprocessors import * # NOQA
|
|||
|
||||
__all__ = (
|
||||
heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ +
|
||||
postprocessors.__all__)
|
||||
postprocessors.__all__ + data_preprocessors.__all__)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .data_preprocessor import TextDetDataPreprocessor
|
||||
|
||||
__all__ = ['TextDetDataPreprocessor']
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ImgDataPreprocessor
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TextDetDataPreprocessor(ImgDataPreprocessor):
|
||||
"""Image pre-processor for detection tasks.
|
||||
|
||||
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
|
||||
|
||||
1. It supports batch augmentations.
|
||||
2. It will additionally append batch_input_shape and pad_shape
|
||||
to data_samples considering the object detection task.
|
||||
|
||||
It provides the data pre-processing as follows
|
||||
|
||||
- Collate and move data to the target device.
|
||||
- Pad inputs to the maximum size of current batch with defined
|
||||
``pad_value``. The padding size can be divisible by a defined
|
||||
``pad_size_divisor``
|
||||
- Stack inputs to batch_inputs.
|
||||
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
|
||||
- Normalize image with defined std and mean.
|
||||
- Do batch augmentations during training.
|
||||
|
||||
Args:
|
||||
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
|
||||
Defaults to None.
|
||||
std (Sequence[Number], optional): The pixel standard deviation of
|
||||
R, G, B channels. Defaults to None.
|
||||
pad_size_divisor (int): The size of padded image should be
|
||||
divisible by ``pad_size_divisor``. Defaults to 1.
|
||||
pad_value (Number): The padded pixel value. Defaults to 0.
|
||||
pad_mask (bool): Whether to pad instance masks. Defaults to False.
|
||||
mask_pad_value (int): The padded pixel value for instance masks.
|
||||
Defaults to 0.
|
||||
pad_seg (bool): Whether to pad semantic segmentation maps.
|
||||
Defaults to False.
|
||||
seg_pad_value (int): The padded pixel value for semantic
|
||||
segmentation maps. Defaults to 255.
|
||||
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
|
||||
Defaults to False.
|
||||
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
|
||||
Defaults to False.
|
||||
batch_augments (list[dict], optional): Batch-level augmentations
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mean: Sequence[Number] = None,
|
||||
std: Sequence[Number] = None,
|
||||
pad_size_divisor: int = 1,
|
||||
pad_value: Union[float, int] = 0,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
batch_augments: Optional[List[dict]] = None):
|
||||
super().__init__(
|
||||
mean=mean,
|
||||
std=std,
|
||||
pad_size_divisor=pad_size_divisor,
|
||||
pad_value=pad_value,
|
||||
bgr_to_rgb=bgr_to_rgb,
|
||||
rgb_to_bgr=rgb_to_bgr)
|
||||
if batch_augments is not None:
|
||||
self.batch_augments = nn.ModuleList(
|
||||
[MODELS.build(aug) for aug in batch_augments])
|
||||
else:
|
||||
self.batch_augments = None
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
"""
|
||||
batch_inputs, batch_data_samples = super().forward(
|
||||
data=data, training=training)
|
||||
|
||||
if batch_data_samples is not None:
|
||||
batch_input_shape = tuple(batch_inputs[0].size()[-2:])
|
||||
for data_samples in batch_data_samples:
|
||||
data_samples.set_metainfo(
|
||||
{'batch_input_shape': batch_input_shape})
|
||||
|
||||
if training and self.batch_augments is not None:
|
||||
for batch_aug in self.batch_augments:
|
||||
batch_inputs, batch_data_samples = batch_aug(
|
||||
batch_inputs, batch_data_samples)
|
||||
|
||||
return batch_inputs, batch_data_samples
|
|
@ -2,7 +2,6 @@
|
|||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmcv.runner import auto_fp16
|
||||
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
|
||||
|
||||
from mmocr.core.data_structures import TextDetDataSample
|
||||
|
@ -21,7 +20,7 @@ class SingleStageTextDetector(MMDET_BaseDetector):
|
|||
neck (dict, optional): Neck config. If None, the output from backbone
|
||||
will be directly fed into ``det_head``.
|
||||
det_head (dict): Head config.
|
||||
preprocess_cfg (dict, optional): Model preprocessing config
|
||||
data_preprocessor (dict, optional): Model preprocessing config
|
||||
for processing the input image data. Keys allowed are
|
||||
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
|
||||
float), ``mean``(int or float) and ``std``(int or float).
|
||||
|
@ -35,83 +34,96 @@ class SingleStageTextDetector(MMDET_BaseDetector):
|
|||
backbone: Dict,
|
||||
det_head: Dict,
|
||||
neck: Optional[Dict] = None,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
data_preprocessor: Optional[Dict] = None,
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
assert det_head is not None, 'det_head cannot be None!'
|
||||
self.backbone = MODELS.build(backbone)
|
||||
if neck is not None:
|
||||
self.neck = MODELS.build(neck)
|
||||
self.det_head = MODELS.build(det_head)
|
||||
|
||||
def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
|
||||
"""Directly extract features from the backbone+neck."""
|
||||
x = self.backbone(img)
|
||||
if self.with_neck:
|
||||
x = self.neck(x)
|
||||
return x
|
||||
def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract features.
|
||||
|
||||
def forward_train(self, img: torch.Tensor,
|
||||
data_samples: Sequence[TextDetDataSample]) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
img (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).
|
||||
|
||||
Returns:
|
||||
Tensor or tuple[Tensor]: Multi-level features that may have
|
||||
different resolutions.
|
||||
"""
|
||||
batch_inputs = self.backbone(batch_inputs)
|
||||
if self.with_neck:
|
||||
batch_inputs = self.neck(batch_inputs)
|
||||
return batch_inputs
|
||||
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[TextDetDataSample]) -> Dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||
datasamples, containing meta information and gold annotations
|
||||
for each of the images.
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
x = self.extract_feat(img)
|
||||
preds = self.det_head(x, data_samples)
|
||||
losses = self.det_head.loss(preds, data_samples)
|
||||
return losses
|
||||
batch_inputs = self.extract_feat(batch_inputs)
|
||||
return self.det_head.loss(batch_inputs, batch_data_samples)
|
||||
|
||||
def simple_test(self, img: torch.Tensor,
|
||||
data_samples: Sequence[TextDetDataSample]
|
||||
) -> Sequence[TextDetDataSample]:
|
||||
"""Test function without test-time augmentation.
|
||||
def predict(
|
||||
self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[TextDetDataSample]
|
||||
) -> Sequence[TextDetDataSample]:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Images of shape (N, C, H, W).
|
||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
|
||||
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||
datasamples, containing meta information and gold annotations
|
||||
for each of the images.
|
||||
|
||||
Returns:
|
||||
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||
results. Results are stored in ``pred_instances``.
|
||||
results. Each DetDataSample usually contain
|
||||
'pred_instances'. And the ``pred_instances`` usually
|
||||
contains following keys.
|
||||
|
||||
- scores (Tensor): Classification scores, has a shape
|
||||
(num_instance, )
|
||||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
- bboxes (Tensor): Has a shape (num_instances, 4),
|
||||
the last dimension 4 arrange as (x1, y1, x2, y2).
|
||||
- polygons (list[np.ndarray]): The length is num_instances.
|
||||
Each element represents the polygon of the
|
||||
instance, in (xn, yn) order.
|
||||
"""
|
||||
x = self.extract_feat(img)
|
||||
preds = self.det_head(x, data_samples)
|
||||
return self.det_head.postprocessor(preds, data_samples)
|
||||
x = self.extract_feat(batch_inputs)
|
||||
return self.det_head.predict(x, batch_data_samples)
|
||||
|
||||
def aug_test(
|
||||
self, imgs: Sequence[torch.Tensor],
|
||||
data_samples: Sequence[Sequence[TextDetDataSample]]
|
||||
) -> Sequence[Sequence[TextDetDataSample]]:
|
||||
"""Test function with test time augmentation."""
|
||||
raise NotImplementedError
|
||||
def _forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Optional[
|
||||
Sequence[TextDetDataSample]] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Network forward process. Usually includes backbone, neck and head
|
||||
forward without any post-processing.
|
||||
|
||||
@auto_fp16(apply_to=('imgs', ))
|
||||
def forward_simple_test(self, imgs: torch.Tensor,
|
||||
data_samples: Sequence[TextDetDataSample]
|
||||
) -> Sequence[TextDetDataSample]:
|
||||
"""Test forward function called by self.forward() when running in test
|
||||
mode without test time augmentation.
|
||||
|
||||
Though not useful in MMOCR, it has been kept to maintain the maximum
|
||||
compatibility with MMDetection's BaseDetector.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Images of shape (N, C, H, W).
|
||||
data_samples (list[TextDetDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
Args:
|
||||
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||
datasamples, containing meta information and gold annotations
|
||||
for each of the images.
|
||||
|
||||
Returns:
|
||||
list[TextDetDataSample]: A list of N datasamples of prediction
|
||||
results. Results are stored in ``pred_instances``.
|
||||
Tensor or tuple[Tensor]: A tuple of features from ``det_head``
|
||||
forward.
|
||||
"""
|
||||
return self.simple_test(imgs, data_samples)
|
||||
x = self.extract_feat(batch_inputs)
|
||||
return self.det_head(x, batch_data_samples)
|
||||
|
|
|
@ -1,15 +1,52 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.core.data_structures import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
SampleList = List[TextDetDataSample]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BaseTextDetHead(BaseModule):
|
||||
"""Base head for text detection, build the loss and postprocessor.
|
||||
|
||||
1. The ``init_weights`` method is used to initialize head's
|
||||
model parameters. After detector initialization, ``init_weights``
|
||||
is triggered when ``detector.init_weights()`` is called externally.
|
||||
|
||||
2. The ``loss`` method is used to calculate the loss of head,
|
||||
which includes two steps: (1) the head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``loss_module`` method
|
||||
is called based on the feature maps to calculate the loss.
|
||||
|
||||
.. code:: text
|
||||
|
||||
loss(): forward() -> loss_module()
|
||||
|
||||
3. The ``predict`` method is used to predict detection results,
|
||||
which includes two steps: (1) the head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``postprocessor`` method
|
||||
is called based on the feature maps to predict detection results including
|
||||
post-processing.
|
||||
|
||||
.. code:: text
|
||||
|
||||
predict(): forward() -> postprocessor()
|
||||
|
||||
4. The ``loss_and_predict`` method is used to return loss and detection
|
||||
results at the same time. It will call head's ``forward``,
|
||||
``loss_module`` and ``postprocessor`` methods in order.
|
||||
|
||||
.. code:: text
|
||||
|
||||
loss_and_predict(): forward() -> loss_module() -> postprocessor()
|
||||
|
||||
|
||||
Args:
|
||||
loss (dict): Config to build loss.
|
||||
postprocessor (dict): Config to build postprocessor.
|
||||
|
@ -18,12 +55,75 @@ class BaseTextDetHead(BaseModule):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss: Dict,
|
||||
loss_module: Dict,
|
||||
postprocessor: Dict,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(loss, dict)
|
||||
assert isinstance(loss_module, dict)
|
||||
assert isinstance(postprocessor, dict)
|
||||
|
||||
self.loss = MODELS.build(loss)
|
||||
self.loss_module = MODELS.build(loss_module)
|
||||
self.postprocessor = MODELS.build(postprocessor)
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the detection
|
||||
head on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Features from the upstream network, each is
|
||||
a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of loss components.
|
||||
"""
|
||||
outs = self(x)
|
||||
losses = self.loss_module(outs, batch_data_samples)
|
||||
return losses
|
||||
|
||||
def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList
|
||||
) -> Tuple[dict, SampleList]:
|
||||
"""Perform forward propagation of the head, then calculate loss and
|
||||
predictions from the features and data samples.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Features from FPN.
|
||||
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
|
||||
the meta information of each image and corresponding
|
||||
annotations.
|
||||
|
||||
Returns:
|
||||
tuple: the return value is a tuple contains:
|
||||
|
||||
- losses: (dict[str, Tensor]): A dictionary of loss components.
|
||||
- predictions (list[:obj:`InstanceData`]): Detection
|
||||
results of each image after the post process.
|
||||
"""
|
||||
outs = self(x)
|
||||
losses = self.loss_module(outs, batch_data_samples)
|
||||
|
||||
predictions = self.postprocessor(outs, batch_data_samples)
|
||||
return losses, predictions
|
||||
|
||||
def predict(self, x: torch.Tensor,
|
||||
batch_data_samples: SampleList) -> SampleList:
|
||||
"""Perform forward propagation of the detection head and predict
|
||||
detection results on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
SampleList: Detection results of each image
|
||||
after the post process.
|
||||
"""
|
||||
outs = self(x)
|
||||
|
||||
predictions = self.postprocessor(outs, batch_data_samples)
|
||||
return predictions
|
||||
|
|
|
@ -28,7 +28,7 @@ class DBHead(BaseTextDetHead):
|
|||
self,
|
||||
in_channels: int,
|
||||
with_bias: bool = False,
|
||||
loss: Dict = dict(type='DBLoss'),
|
||||
loss_module: Dict = dict(type='DBLoss'),
|
||||
postprocessor: Dict = dict(
|
||||
type='DBPostprocessor', text_repr_type='quad'),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
||||
|
@ -37,7 +37,9 @@ class DBHead(BaseTextDetHead):
|
|||
]
|
||||
) -> None:
|
||||
super().__init__(
|
||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||
loss_module=loss_module,
|
||||
postprocessor=postprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(with_bias, bool)
|
||||
|
|
|
@ -30,7 +30,7 @@ class FCEHead(BaseTextDetHead):
|
|||
self,
|
||||
in_channels: int,
|
||||
fourier_degree: int = 5,
|
||||
loss: Dict = dict(type='FCELoss', num_sample=50),
|
||||
loss_module: Dict = dict(type='FCELoss', num_sample=50),
|
||||
postprocessor: Dict = dict(
|
||||
type='FCEPostprocessor',
|
||||
text_repr_type='poly',
|
||||
|
@ -45,10 +45,12 @@ class FCEHead(BaseTextDetHead):
|
|||
override=[dict(name='out_conv_cls'),
|
||||
dict(name='out_conv_reg')])
|
||||
) -> None:
|
||||
loss['fourier_degree'] = fourier_degree
|
||||
loss_module['fourier_degree'] = fourier_degree
|
||||
postprocessor['fourier_degree'] = fourier_degree
|
||||
super().__init__(
|
||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||
loss_module=loss_module,
|
||||
postprocessor=postprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(fourier_degree, int)
|
||||
|
|
|
@ -33,7 +33,7 @@ class PANHead(BaseTextDetHead):
|
|||
in_channels: List[int],
|
||||
hidden_dim: int,
|
||||
out_channel: int,
|
||||
loss=dict(type='PANLoss'),
|
||||
loss_module=dict(type='PANLoss'),
|
||||
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly'),
|
||||
init_cfg=[
|
||||
dict(type='Normal', mean=0, std=0.01, layer='Conv2d'),
|
||||
|
@ -41,7 +41,9 @@ class PANHead(BaseTextDetHead):
|
|||
]
|
||||
) -> None:
|
||||
super().__init__(
|
||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||
loss_module=loss_module,
|
||||
postprocessor=postprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
assert check_argument.is_type_list(in_channels, int)
|
||||
assert isinstance(out_channel, int)
|
||||
|
|
|
@ -24,7 +24,7 @@ class PSEHead(PANHead):
|
|||
in_channels: List[int],
|
||||
hidden_dim: int,
|
||||
out_channel: int,
|
||||
loss: Dict = dict(type='PSELoss'),
|
||||
loss_module: Dict = dict(type='PSELoss'),
|
||||
postprocessor: Dict = dict(
|
||||
type='PSEPostprocessor', text_repr_type='poly'),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
|
@ -33,6 +33,6 @@ class PSEHead(PANHead):
|
|||
in_channels=in_channels,
|
||||
hidden_dim=hidden_dim,
|
||||
out_channel=out_channel,
|
||||
loss=loss,
|
||||
loss_module=loss_module,
|
||||
postprocessor=postprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
|
|
@ -31,14 +31,16 @@ class TextSnakeHead(BaseTextDetHead):
|
|||
in_channels: int,
|
||||
out_channels: int = 5,
|
||||
downsample_ratio: float = 1.0,
|
||||
loss: Dict = dict(type='TextSnakeLoss'),
|
||||
loss_module: Dict = dict(type='TextSnakeLoss'),
|
||||
postprocessor: Dict = dict(
|
||||
type='TextSnakePostprocessor', text_repr_type='poly'),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||
type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
|
||||
) -> None:
|
||||
super().__init__(
|
||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||
loss_module=loss_module,
|
||||
postprocessor=postprocessor,
|
||||
init_cfg=init_cfg)
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(out_channels, int)
|
||||
self.in_channels = in_channels
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.models.textdet.data_preprocessors import TextDetDataPreprocessor
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TDAugment(torch.nn.Module):
|
||||
|
||||
def forward(self, batch_inputs, batch_data_samples):
|
||||
return batch_inputs, batch_data_samples
|
||||
|
||||
|
||||
class TestTextDetDataPreprocessor(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
# test mean is None
|
||||
processor = TextDetDataPreprocessor()
|
||||
self.assertTrue(not hasattr(processor, 'mean'))
|
||||
self.assertTrue(processor._enable_normalize is False)
|
||||
|
||||
# test mean is not None
|
||||
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
|
||||
self.assertTrue(hasattr(processor, 'mean'))
|
||||
self.assertTrue(hasattr(processor, 'std'))
|
||||
self.assertTrue(processor._enable_normalize)
|
||||
|
||||
# please specify both mean and std
|
||||
with self.assertRaises(AssertionError):
|
||||
TextDetDataPreprocessor(mean=[0, 0, 0])
|
||||
|
||||
# bgr2rgb and rgb2bgr cannot be set to True at the same time
|
||||
with self.assertRaises(AssertionError):
|
||||
TextDetDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)
|
||||
|
||||
aug_cfg = [dict(type='TDAugment')]
|
||||
processor = TextDetDataPreprocessor()
|
||||
self.assertIsNone(processor.batch_augments)
|
||||
processor = TextDetDataPreprocessor(batch_augments=aug_cfg)
|
||||
self.assertIsInstance(processor.batch_augments, torch.nn.ModuleList)
|
||||
self.assertIsInstance(processor.batch_augments[0], TDAugment)
|
||||
|
||||
def test_forward(self):
|
||||
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
|
||||
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
|
||||
}]
|
||||
inputs, data_samples = processor(data)
|
||||
print(inputs.dtype)
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
||||
# test channel_conversion
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
||||
# test padding
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 10, 11))
|
||||
}, {
|
||||
'inputs': torch.randint(0, 256, (3, 9, 14))
|
||||
}]
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 14))
|
||||
self.assertIsNone(data_samples)
|
||||
|
||||
# test pad_size_divisor
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
|
||||
}, {
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 9, 24)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
|
||||
}]
|
||||
aug_cfg = [dict(type='TDAugment')]
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.],
|
||||
std=[1., 1., 1.],
|
||||
pad_size_divisor=5,
|
||||
batch_augments=aug_cfg)
|
||||
inputs, data_samples = processor(data, training=True)
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 25))
|
||||
self.assertEqual(len(data_samples), 2)
|
||||
for data_sample, expected_shape in zip(data_samples, [(10, 25),
|
||||
(10, 25)]):
|
||||
self.assertEqual(data_sample.batch_input_shape, expected_shape)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase, mock
|
||||
|
||||
from mmocr.models.textdet import BaseTextDetHead
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FakeModule:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_targets(self, datasamples):
|
||||
return None
|
||||
|
||||
def __call__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class TestBaseTextDetHead(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
cfg = dict(type='FakeModule')
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
BaseTextDetHead([], cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
BaseTextDetHead(cfg, [])
|
||||
|
||||
decoder = BaseTextDetHead(cfg, cfg)
|
||||
self.assertIsInstance(decoder.loss_module, FakeModule)
|
||||
self.assertIsInstance(decoder.postprocessor, FakeModule)
|
||||
|
||||
@mock.patch(f'{__name__}.BaseTextDetHead.forward')
|
||||
def test_forward(self, mock_forward):
|
||||
|
||||
def mock_forward(feat, out_enc, datasamples):
|
||||
|
||||
return True
|
||||
|
||||
mock_forward.side_effect = mock_forward
|
||||
cfg = dict(type='FakeModule')
|
||||
decoder = BaseTextDetHead(cfg, cfg)
|
||||
# test loss
|
||||
loss = decoder.loss(None, None)
|
||||
self.assertIsNone(loss)
|
||||
|
||||
# test predict
|
||||
predict = decoder.predict(None, None)
|
||||
self.assertIsNone(predict)
|
||||
|
||||
# test forward
|
||||
tensor = decoder(None, None)
|
||||
self.assertTrue(tensor)
|
||||
|
||||
loss, predict = decoder.loss_and_predict(None, None)
|
||||
self.assertIsNone(loss)
|
||||
self.assertIsNone(predict)
|
Loading…
Reference in New Issue