[UT] Add missing unit tests (#1651)

* update

* remove code
pull/1652/head
Qing Jiang 2022-12-30 12:01:14 +08:00 committed by GitHub
parent 7e9f7756bc
commit 65e746eb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 140 additions and 30 deletions

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule, ModuleList
from torch import nn
@ -14,7 +17,9 @@ class FPEM(BaseModule):
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(self, in_channels=128, init_cfg=None):
def __init__(self,
in_channels: int = 128,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
@ -23,7 +28,8 @@ class FPEM(BaseModule):
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
def forward(self, c2, c3, c4, c5):
def forward(self, c2: torch.Tensor, c3: torch.Tensor, c4: torch.Tensor,
c5: torch.Tensor) -> List[torch.Tensor]:
"""
Args:
c2, c3, c4, c5 (Tensor): Each has the shape of
@ -48,8 +54,21 @@ class FPEM(BaseModule):
class SeparableConv2d(BaseModule):
"""Implementation of separable convolution, which is consisted of depthwise
convolution and pointwise convolution.
def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
stride (int): Stride of the depthwise convolution.
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(self,
in_channels: int,
out_channels: int,
stride: int = 1,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.depthwise_conv = nn.Conv2d(
@ -64,7 +83,15 @@ class SeparableConv2d(BaseModule):
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (Tensor): Input tensor.
Returns:
Tensor: Output tensor.
"""
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
x = self.bn(x)
@ -85,13 +112,15 @@ class FPEM_FFM(BaseModule):
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(self,
in_channels,
conv_out=128,
fpem_repeat=2,
align_corners=False,
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
def __init__(
self,
in_channels: List[int],
conv_out: int = 128,
fpem_repeat: int = 2,
align_corners: bool = False,
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Xavier', layer='Conv2d', distribution='uniform')
) -> None:
super().__init__(init_cfg=init_cfg)
# reduce layers
self.reduce_conv_c2 = nn.Sequential(
@ -119,7 +148,7 @@ class FPEM_FFM(BaseModule):
for _ in range(fpem_repeat):
self.fpems.append(FPEM(conv_out))
def forward(self, x):
def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]:
"""
Args:
x (list[Tensor]): A list of four tensors of shape
@ -128,7 +157,7 @@ class FPEM_FFM(BaseModule):
``in_channels``.
Returns:
list[Tensor]: Four tensors of shape
tuple[Tensor]: Four tensors of shape
:math:`(N, C_{out}, H_0, W_0)` where :math:`C_{out}` is
``conv_out``.
"""

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
bezier2polygon, is_on_same_line, rescale_bboxes,
stitch_boxes_into_lines)
bezier2polygon, is_on_same_line, rescale_bbox,
rescale_bboxes, stitch_boxes_into_lines)
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
is_type_list, valid_boundary)
from .collect_env import collect_env
@ -34,17 +34,18 @@ __all__ = [
'is_2dlist', 'valid_boundary', 'list_to_file', 'list_from_file',
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStripper',
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
'ColorType', 'OptKIESampleList', 'KIESampleList', 'is_archive',
'check_integrity', 'list_files', 'get_md5', 'InstanceList', 'LabelList',
'OptInstanceList', 'OptLabelList', 'RangeType', 'remove_pipeline_elements'
'rescale_polygons', 'rescale_polygon', 'rescale_bbox', 'rescale_bboxes',
'bbox2poly', 'crop_polygon', 'is_poly_inside_rect', 'poly2bbox',
'poly_intersection', 'poly_iou', 'poly_make_valid', 'poly_union',
'poly2shapely', 'polys2shapely', 'register_all_modules', 'offset_polygon',
'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
'bbox_diag_distance', 'boundary_iou', 'point_distance', 'points_center',
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
'warp_img', 'ConfigType', 'DetSampleList', 'RecForwardResults',
'InitConfigType', 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType',
'OptMultiConfig', 'OptRecSampleList', 'RecSampleList', 'MultiConfig',
'OptTensor', 'ColorType', 'OptKIESampleList', 'KIESampleList',
'is_archive', 'check_integrity', 'list_files', 'get_md5', 'InstanceList',
'LabelList', 'OptInstanceList', 'OptLabelList', 'RangeType',
'remove_pipeline_elements'
]

View File

@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmocr.models.textdet.necks.fpem_ffm import FPEM, FPEM_FFM
class TestFPEM(unittest.TestCase):
def setUp(self):
self.c2 = torch.Tensor(1, 8, 64, 64)
self.c3 = torch.Tensor(1, 8, 32, 32)
self.c4 = torch.Tensor(1, 8, 16, 16)
self.c5 = torch.Tensor(1, 8, 8, 8)
self.fpem = FPEM(in_channels=8)
def test_forward(self):
neck = FPEM(in_channels=8)
neck.init_weights()
out = neck(self.c2, self.c3, self.c4, self.c5)
self.assertTrue(out[0].shape == self.c2.shape)
self.assertTrue(out[1].shape == self.c3.shape)
self.assertTrue(out[2].shape == self.c4.shape)
self.assertTrue(out[3].shape == self.c5.shape)
class TestFPEM_FFM(unittest.TestCase):
def setUp(self):
self.c2 = torch.Tensor(1, 8, 64, 64)
self.c3 = torch.Tensor(1, 16, 32, 32)
self.c4 = torch.Tensor(1, 32, 16, 16)
self.c5 = torch.Tensor(1, 64, 8, 8)
self.in_channels = [8, 16, 32, 64]
self.conv_out = 8
self.features = [self.c2, self.c3, self.c4, self.c5]
def test_forward(self):
neck = FPEM_FFM(in_channels=self.in_channels, conv_out=self.conv_out)
neck.init_weights()
out = neck(self.features)
self.assertTrue(out[0].shape == torch.Size([1, 8, 64, 64]))
self.assertTrue(out[1].shape == out[0].shape)
self.assertTrue(out[2].shape == out[0].shape)
self.assertTrue(out[3].shape == out[0].shape)

View File

@ -5,8 +5,8 @@ import numpy as np
import torch
from mmocr.utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
bezier2polygon, is_on_same_line,
stitch_boxes_into_lines)
bezier2polygon, is_on_same_line, rescale_bbox,
rescale_bboxes, stitch_boxes_into_lines)
from mmocr.utils.bbox_utils import bbox_jitter
@ -236,3 +236,31 @@ class TestStitchBoxesIntoLines(unittest.TestCase):
result.sort(key=lambda x: x['box'][0])
expected_result.sort(key=lambda x: x['box'][0])
self.assertEqual(result, expected_result)
class TestRescaleBbox(unittest.TestCase):
def setUp(self) -> None:
self.bbox = np.array([0, 0, 1, 1])
self.bboxes = np.array([[0, 0, 1, 1], [1, 1, 2, 2]])
self.scale = 2
def test_rescale_bbox(self):
# mul
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='mul')
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 2, 2])))
# div
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='div')
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 0.5, 0.5])))
def test_rescale_bboxes(self):
# mul
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='mul')
self.assertTrue(
np.allclose(rescaled_bboxes, np.array([[0, 0, 2, 2], [2, 2, 4,
4]])))
# div
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='div')
self.assertTrue(
np.allclose(rescaled_bboxes,
np.array([[0, 0, 0.5, 0.5], [0.5, 0.5, 1, 1]])))

View File

@ -46,3 +46,9 @@ def test_valid_boundary():
assert utils.valid_boundary(x, False)
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
assert utils.valid_boundary(x, True)
def test_equal_len():
assert utils.equal_len([1, 2, 3], [1, 2, 3])
assert not utils.equal_len([1, 2, 3], [1, 2, 3, 4])