[Detector] refactor basedetector

pull/1178/head
liukuikun 2022-07-04 03:14:31 +00:00 committed by gaotongxiao
parent d2e8e79df1
commit b955df9904
23 changed files with 487 additions and 93 deletions

View File

@ -21,6 +21,6 @@ model = dict(
det_head=dict(
type='DBHead',
in_channels=256,
loss=dict(type='DBLoss'),
loss_module=dict(type='DBLoss'),
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
preprocess_cfg=preprocess_cfg)

View File

@ -23,6 +23,6 @@ model = dict(
det_head=dict(
type='DBHead',
in_channels=256,
loss=dict(type='DBLoss'),
loss_module=dict(type='DBLoss'),
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')),
preprocess_cfg=preprocess_cfg)

View File

@ -20,7 +20,7 @@ model = dict(
bbox_head=dict(
type='DBHead',
in_channels=256,
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
loss_module=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
postprocessor=dict(
type='DBPostprocessor', text_repr_type='quad',
epsilon_ratio=0.002)),

View File

@ -17,5 +17,5 @@ model = dict(
in_channels=32,
text_region_thr=0.3,
center_region_thr=0.4,
loss=dict(type='DRRGLoss'),
loss_module=dict(type='DRRGLoss'),
postprocessor=dict(type='DRRGPostprocessor', link_thr=0.80)))

View File

@ -23,7 +23,7 @@ model = dict(
in_channels=256,
scales=(8, 16, 32),
fourier_degree=5,
loss=dict(type='FCELoss', num_sample=50),
loss_module=dict(type='FCELoss', num_sample=50),
postprocessor=dict(
type='FCEPostprocessor',
text_repr_type='quad',

View File

@ -25,7 +25,7 @@ model = dict(
in_channels=256,
scales=(8, 16, 32),
fourier_degree=5,
loss=dict(type='FCELoss', num_sample=50),
loss_module=dict(type='FCELoss', num_sample=50),
postprocessor=dict(
type='FCEPostprocessor',
text_repr_type='poly',

View File

@ -1,13 +1,13 @@
preprocess_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
pad_size_divisor=32)
# BasicBlock has a little difference from official PANet
# BasicBlock in mmdet lacks RELU in the last convolution.
model = dict(
type='PANet',
data_preprocessor=dict(
type='TextDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32),
backbone=dict(
type='mmdet.ResNet',
depth=18,
@ -25,10 +25,9 @@ model = dict(
in_channels=[128, 128, 128, 128],
hidden_dim=128,
out_channel=6,
loss=dict(
loss_module=dict(
type='PANLoss',
loss_text=dict(type='MaskedSquareDiceLoss'),
loss_kernel=dict(type='MaskedSquareDiceLoss'),
),
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')),
preprocess_cfg=preprocess_cfg)
postprocessor=dict(type='PANPostprocessor', text_repr_type='quad')))

View File

@ -15,7 +15,7 @@ model = dict(
type='PANHead',
in_channels=[128, 128, 128, 128],
out_channels=6,
loss=dict(type='PANLoss', speedup_bbox_thr=32),
loss_module=dict(type='PANLoss', speedup_bbox_thr=32),
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')),
train_cfg=None,
test_cfg=None)

View File

@ -26,7 +26,7 @@ model_poly = dict(
in_channels=[256],
hidden_dim=256,
out_channel=7,
loss=dict(type='PSELoss'),
loss_module=dict(type='PSELoss'),
postprocessor=dict(type='PSEPostprocessor', text_repr_type='poly')),
preprocess_cfg=preprocess_cfg)

View File

@ -21,7 +21,7 @@ model = dict(
det_head=dict(
type='TextSnakeHead',
in_channels=32,
loss=dict(type='TextSnakeLoss'),
loss_module=dict(type='TextSnakeLoss'),
postprocessor=dict(
type='TextSnakePostprocessor', text_repr_type='poly')),
preprocess_cfg=preprocess_cfg)

View File

@ -12,7 +12,7 @@ model = dict(
type='CRNNDecoder',
in_channels=512,
rnn_flag=True,
loss=dict(type='CTCLoss', letter_case='lower'),
loss_module=dict(type='CTCLoss', letter_case='lower'),
postprocessor=dict(type='CTCPostProcessor')),
dictionary=dictionary,
data_preprocessor=dict(

View File

@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import detectors, heads, losses, necks, postprocessors
from . import (data_preprocessors, detectors, heads, losses, necks,
postprocessors)
from .data_preprocessors import * # NOQA
from .detectors import * # NOQA
from .heads import * # NOQA
from .losses import * # NOQA
@ -8,4 +10,4 @@ from .postprocessors import * # NOQA
__all__ = (
heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ +
postprocessors.__all__)
postprocessors.__all__ + data_preprocessors.__all__)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_preprocessor import TextDetDataPreprocessor
__all__ = ['TextDetDataPreprocessor']

View File

@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmengine.model import ImgDataPreprocessor
from mmocr.registry import MODELS
@MODELS.register_module()
class TextDetDataPreprocessor(ImgDataPreprocessor):
"""Image pre-processor for detection tasks.
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
1. It supports batch augmentations.
2. It will additionally append batch_input_shape and pad_shape
to data_samples considering the object detection task.
It provides the data pre-processing as follows
- Collate and move data to the target device.
- Pad inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Do batch augmentations during training.
Args:
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (Number): The padded pixel value. Defaults to 0.
pad_mask (bool): Whether to pad instance masks. Defaults to False.
mask_pad_value (int): The padded pixel value for instance masks.
Defaults to 0.
pad_seg (bool): Whether to pad semantic segmentation maps.
Defaults to False.
seg_pad_value (int): The padded pixel value for semantic
segmentation maps. Defaults to 255.
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
Defaults to False.
batch_augments (list[dict], optional): Batch-level augmentations
"""
def __init__(self,
mean: Sequence[Number] = None,
std: Sequence[Number] = None,
pad_size_divisor: int = 1,
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
batch_augments: Optional[List[dict]] = None):
super().__init__(
mean=mean,
std=std,
pad_size_divisor=pad_size_divisor,
pad_value=pad_value,
bgr_to_rgb=bgr_to_rgb,
rgb_to_bgr=rgb_to_bgr)
if batch_augments is not None:
self.batch_augments = nn.ModuleList(
[MODELS.build(aug) for aug in batch_augments])
else:
self.batch_augments = None
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
"""
batch_inputs, batch_data_samples = super().forward(
data=data, training=training)
if batch_data_samples is not None:
batch_input_shape = tuple(batch_inputs[0].size()[-2:])
for data_samples in batch_data_samples:
data_samples.set_metainfo(
{'batch_input_shape': batch_input_shape})
if training and self.batch_augments is not None:
for batch_aug in self.batch_augments:
batch_inputs, batch_data_samples = batch_aug(
batch_inputs, batch_data_samples)
return batch_inputs, batch_data_samples

View File

@ -2,7 +2,6 @@
from typing import Dict, Optional, Sequence
import torch
from mmcv.runner import auto_fp16
from mmdet.models.detectors.base import BaseDetector as MMDET_BaseDetector
from mmocr.core.data_structures import TextDetDataSample
@ -21,7 +20,7 @@ class SingleStageTextDetector(MMDET_BaseDetector):
neck (dict, optional): Neck config. If None, the output from backbone
will be directly fed into ``det_head``.
det_head (dict): Head config.
preprocess_cfg (dict, optional): Model preprocessing config
data_preprocessor (dict, optional): Model preprocessing config
for processing the input image data. Keys allowed are
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
float), ``mean``(int or float) and ``std``(int or float).
@ -35,83 +34,96 @@ class SingleStageTextDetector(MMDET_BaseDetector):
backbone: Dict,
det_head: Dict,
neck: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
data_preprocessor: Optional[Dict] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
assert det_head is not None, 'det_head cannot be None!'
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
self.det_head = MODELS.build(det_head)
def extract_feat(self, img: torch.Tensor) -> torch.Tensor:
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor:
"""Extract features.
def forward_train(self, img: torch.Tensor,
data_samples: Sequence[TextDetDataSample]) -> Dict:
"""
Args:
img (torch.Tensor): Input images of shape (N, C, H, W).
batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).
Returns:
Tensor or tuple[Tensor]: Multi-level features that may have
different resolutions.
"""
batch_inputs = self.backbone(batch_inputs)
if self.with_neck:
batch_inputs = self.neck(batch_inputs)
return batch_inputs
def loss(self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[TextDetDataSample]) -> Dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
batch_data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img)
preds = self.det_head(x, data_samples)
losses = self.det_head.loss(preds, data_samples)
return losses
batch_inputs = self.extract_feat(batch_inputs)
return self.det_head.loss(batch_inputs, batch_data_samples)
def simple_test(self, img: torch.Tensor,
data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Test function without test-time augmentation.
def predict(
self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
img (torch.Tensor): Images of shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
batch_data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
Returns:
list[TextDetDataSample]: A list of N datasamples of prediction
results. Results are stored in ``pred_instances``.
results. Each DetDataSample usually contain
'pred_instances'. And the ``pred_instances`` usually
contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- polygons (list[np.ndarray]): The length is num_instances.
Each element represents the polygon of the
instance, in (xn, yn) order.
"""
x = self.extract_feat(img)
preds = self.det_head(x, data_samples)
return self.det_head.postprocessor(preds, data_samples)
x = self.extract_feat(batch_inputs)
return self.det_head.predict(x, batch_data_samples)
def aug_test(
self, imgs: Sequence[torch.Tensor],
data_samples: Sequence[Sequence[TextDetDataSample]]
) -> Sequence[Sequence[TextDetDataSample]]:
"""Test function with test time augmentation."""
raise NotImplementedError
def _forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: Optional[
Sequence[TextDetDataSample]] = None,
**kwargs) -> torch.Tensor:
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
@auto_fp16(apply_to=('imgs', ))
def forward_simple_test(self, imgs: torch.Tensor,
data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Test forward function called by self.forward() when running in test
mode without test time augmentation.
Though not useful in MMOCR, it has been kept to maintain the maximum
compatibility with MMDetection's BaseDetector.
Args:
img (torch.Tensor): Images of shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
Returns:
list[TextDetDataSample]: A list of N datasamples of prediction
results. Results are stored in ``pred_instances``.
Tensor or tuple[Tensor]: A tuple of features from ``det_head``
forward.
"""
return self.simple_test(imgs, data_samples)
x = self.extract_feat(batch_inputs)
return self.det_head(x, batch_data_samples)

View File

@ -1,15 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmcv.runner import BaseModule
from torch import Tensor
from mmocr.core.data_structures import TextDetDataSample
from mmocr.registry import MODELS
SampleList = List[TextDetDataSample]
@MODELS.register_module()
class BaseTextDetHead(BaseModule):
"""Base head for text detection, build the loss and postprocessor.
1. The ``init_weights`` method is used to initialize head's
model parameters. After detector initialization, ``init_weights``
is triggered when ``detector.init_weights()`` is called externally.
2. The ``loss`` method is used to calculate the loss of head,
which includes two steps: (1) the head model performs forward
propagation to obtain the feature maps (2) The ``loss_module`` method
is called based on the feature maps to calculate the loss.
.. code:: text
loss(): forward() -> loss_module()
3. The ``predict`` method is used to predict detection results,
which includes two steps: (1) the head model performs forward
propagation to obtain the feature maps (2) The ``postprocessor`` method
is called based on the feature maps to predict detection results including
post-processing.
.. code:: text
predict(): forward() -> postprocessor()
4. The ``loss_and_predict`` method is used to return loss and detection
results at the same time. It will call head's ``forward``,
``loss_module`` and ``postprocessor`` methods in order.
.. code:: text
loss_and_predict(): forward() -> loss_module() -> postprocessor()
Args:
loss (dict): Config to build loss.
postprocessor (dict): Config to build postprocessor.
@ -18,12 +55,75 @@ class BaseTextDetHead(BaseModule):
"""
def __init__(self,
loss: Dict,
loss_module: Dict,
postprocessor: Dict,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(loss, dict)
assert isinstance(loss_module, dict)
assert isinstance(postprocessor, dict)
self.loss = MODELS.build(loss)
self.loss_module = MODELS.build(loss_module)
self.postprocessor = MODELS.build(postprocessor)
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
outs = self(x)
losses = self.loss_module(outs, batch_data_samples)
return losses
def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList
) -> Tuple[dict, SampleList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
Args:
x (tuple[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
the meta information of each image and corresponding
annotations.
Returns:
tuple: the return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- predictions (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
"""
outs = self(x)
losses = self.loss_module(outs, batch_data_samples)
predictions = self.postprocessor(outs, batch_data_samples)
return losses, predictions
def predict(self, x: torch.Tensor,
batch_data_samples: SampleList) -> SampleList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
SampleList: Detection results of each image
after the post process.
"""
outs = self(x)
predictions = self.postprocessor(outs, batch_data_samples)
return predictions

View File

@ -28,7 +28,7 @@ class DBHead(BaseTextDetHead):
self,
in_channels: int,
with_bias: bool = False,
loss: Dict = dict(type='DBLoss'),
loss_module: Dict = dict(type='DBLoss'),
postprocessor: Dict = dict(
type='DBPostprocessor', text_repr_type='quad'),
init_cfg: Optional[Union[Dict, List[Dict]]] = [
@ -37,7 +37,9 @@ class DBHead(BaseTextDetHead):
]
) -> None:
super().__init__(
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
loss_module=loss_module,
postprocessor=postprocessor,
init_cfg=init_cfg)
assert isinstance(in_channels, int)
assert isinstance(with_bias, bool)

View File

@ -30,7 +30,7 @@ class FCEHead(BaseTextDetHead):
self,
in_channels: int,
fourier_degree: int = 5,
loss: Dict = dict(type='FCELoss', num_sample=50),
loss_module: Dict = dict(type='FCELoss', num_sample=50),
postprocessor: Dict = dict(
type='FCEPostprocessor',
text_repr_type='poly',
@ -45,10 +45,12 @@ class FCEHead(BaseTextDetHead):
override=[dict(name='out_conv_cls'),
dict(name='out_conv_reg')])
) -> None:
loss['fourier_degree'] = fourier_degree
loss_module['fourier_degree'] = fourier_degree
postprocessor['fourier_degree'] = fourier_degree
super().__init__(
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
loss_module=loss_module,
postprocessor=postprocessor,
init_cfg=init_cfg)
assert isinstance(in_channels, int)
assert isinstance(fourier_degree, int)

View File

@ -33,7 +33,7 @@ class PANHead(BaseTextDetHead):
in_channels: List[int],
hidden_dim: int,
out_channel: int,
loss=dict(type='PANLoss'),
loss_module=dict(type='PANLoss'),
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly'),
init_cfg=[
dict(type='Normal', mean=0, std=0.01, layer='Conv2d'),
@ -41,7 +41,9 @@ class PANHead(BaseTextDetHead):
]
) -> None:
super().__init__(
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
loss_module=loss_module,
postprocessor=postprocessor,
init_cfg=init_cfg)
assert check_argument.is_type_list(in_channels, int)
assert isinstance(out_channel, int)

View File

@ -24,7 +24,7 @@ class PSEHead(PANHead):
in_channels: List[int],
hidden_dim: int,
out_channel: int,
loss: Dict = dict(type='PSELoss'),
loss_module: Dict = dict(type='PSELoss'),
postprocessor: Dict = dict(
type='PSEPostprocessor', text_repr_type='poly'),
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
@ -33,6 +33,6 @@ class PSEHead(PANHead):
in_channels=in_channels,
hidden_dim=hidden_dim,
out_channel=out_channel,
loss=loss,
loss_module=loss_module,
postprocessor=postprocessor,
init_cfg=init_cfg)

View File

@ -31,14 +31,16 @@ class TextSnakeHead(BaseTextDetHead):
in_channels: int,
out_channels: int = 5,
downsample_ratio: float = 1.0,
loss: Dict = dict(type='TextSnakeLoss'),
loss_module: Dict = dict(type='TextSnakeLoss'),
postprocessor: Dict = dict(
type='TextSnakePostprocessor', text_repr_type='poly'),
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
) -> None:
super().__init__(
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
loss_module=loss_module,
postprocessor=postprocessor,
init_cfg=init_cfg)
assert isinstance(in_channels, int)
assert isinstance(out_channels, int)
self.in_channels = in_channels

View File

@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.data_preprocessors import TextDetDataPreprocessor
from mmocr.registry import MODELS
@MODELS.register_module()
class TDAugment(torch.nn.Module):
def forward(self, batch_inputs, batch_data_samples):
return batch_inputs, batch_data_samples
class TestTextDetDataPreprocessor(TestCase):
def test_init(self):
# test mean is None
processor = TextDetDataPreprocessor()
self.assertTrue(not hasattr(processor, 'mean'))
self.assertTrue(processor._enable_normalize is False)
# test mean is not None
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
self.assertTrue(hasattr(processor, 'mean'))
self.assertTrue(hasattr(processor, 'std'))
self.assertTrue(processor._enable_normalize)
# please specify both mean and std
with self.assertRaises(AssertionError):
TextDetDataPreprocessor(mean=[0, 0, 0])
# bgr2rgb and rgb2bgr cannot be set to True at the same time
with self.assertRaises(AssertionError):
TextDetDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)
aug_cfg = [dict(type='TDAugment')]
processor = TextDetDataPreprocessor()
self.assertIsNone(processor.batch_augments)
processor = TextDetDataPreprocessor(batch_augments=aug_cfg)
self.assertIsInstance(processor.batch_augments, torch.nn.ModuleList)
self.assertIsInstance(processor.batch_augments[0], TDAugment)
def test_forward(self):
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
data = [{
'inputs':
torch.randint(0, 256, (3, 11, 10)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
}]
inputs, data_samples = processor(data)
print(inputs.dtype)
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
# test channel_conversion
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
# test padding
data = [{
'inputs': torch.randint(0, 256, (3, 10, 11))
}, {
'inputs': torch.randint(0, 256, (3, 9, 14))
}]
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
self.assertEqual(inputs.shape, (2, 3, 10, 14))
self.assertIsNone(data_samples)
# test pad_size_divisor
data = [{
'inputs':
torch.randint(0, 256, (3, 10, 11)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
}, {
'inputs':
torch.randint(0, 256, (3, 9, 24)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
}]
aug_cfg = [dict(type='TDAugment')]
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.],
std=[1., 1., 1.],
pad_size_divisor=5,
batch_augments=aug_cfg)
inputs, data_samples = processor(data, training=True)
self.assertEqual(inputs.shape, (2, 3, 10, 25))
self.assertEqual(len(data_samples), 2)
for data_sample, expected_shape in zip(data_samples, [(10, 25),
(10, 25)]):
self.assertEqual(data_sample.batch_input_shape, expected_shape)

View File

@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase, mock
from mmocr.models.textdet import BaseTextDetHead
from mmocr.registry import MODELS
@MODELS.register_module()
class FakeModule:
def __init__(self) -> None:
pass
def get_targets(self, datasamples):
return None
def __call__(self, *args):
return None
class TestBaseTextDetHead(TestCase):
def test_init(self):
cfg = dict(type='FakeModule')
with self.assertRaises(AssertionError):
BaseTextDetHead([], cfg)
with self.assertRaises(AssertionError):
BaseTextDetHead(cfg, [])
decoder = BaseTextDetHead(cfg, cfg)
self.assertIsInstance(decoder.loss_module, FakeModule)
self.assertIsInstance(decoder.postprocessor, FakeModule)
@mock.patch(f'{__name__}.BaseTextDetHead.forward')
def test_forward(self, mock_forward):
def mock_forward(feat, out_enc, datasamples):
return True
mock_forward.side_effect = mock_forward
cfg = dict(type='FakeModule')
decoder = BaseTextDetHead(cfg, cfg)
# test loss
loss = decoder.loss(None, None)
self.assertIsNone(loss)
# test predict
predict = decoder.predict(None, None)
self.assertIsNone(predict)
# test forward
tensor = decoder(None, None)
self.assertTrue(tensor)
loss, predict = decoder.loss_and_predict(None, None)
self.assertIsNone(loss)
self.assertIsNone(predict)