mirror of https://github.com/open-mmlab/mmocr.git
parent
7e9f7756bc
commit
65e746eb3d
|
@ -1,4 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from torch import nn
|
||||
|
@ -14,7 +17,9 @@ class FPEM(BaseModule):
|
|||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels=128, init_cfg=None):
|
||||
def __init__(self,
|
||||
in_channels: int = 128,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
|
||||
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
|
||||
|
@ -23,7 +28,8 @@ class FPEM(BaseModule):
|
|||
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
|
||||
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
|
||||
|
||||
def forward(self, c2, c3, c4, c5):
|
||||
def forward(self, c2: torch.Tensor, c3: torch.Tensor, c4: torch.Tensor,
|
||||
c5: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
c2, c3, c4, c5 (Tensor): Each has the shape of
|
||||
|
@ -48,8 +54,21 @@ class FPEM(BaseModule):
|
|||
|
||||
|
||||
class SeparableConv2d(BaseModule):
|
||||
"""Implementation of separable convolution, which is consisted of depthwise
|
||||
convolution and pointwise convolution.
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, init_cfg=None):
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
stride (int): Stride of the depthwise convolution.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 1,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
|
@ -64,7 +83,15 @@ class SeparableConv2d(BaseModule):
|
|||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor.
|
||||
"""
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv(x)
|
||||
x = self.bn(x)
|
||||
|
@ -85,13 +112,15 @@ class FPEM_FFM(BaseModule):
|
|||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
conv_out=128,
|
||||
fpem_repeat=2,
|
||||
align_corners=False,
|
||||
init_cfg=dict(
|
||||
type='Xavier', layer='Conv2d', distribution='uniform')):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: List[int],
|
||||
conv_out: int = 128,
|
||||
fpem_repeat: int = 2,
|
||||
align_corners: bool = False,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||
type='Xavier', layer='Conv2d', distribution='uniform')
|
||||
) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
# reduce layers
|
||||
self.reduce_conv_c2 = nn.Sequential(
|
||||
|
@ -119,7 +148,7 @@ class FPEM_FFM(BaseModule):
|
|||
for _ in range(fpem_repeat):
|
||||
self.fpems.append(FPEM(conv_out))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x (list[Tensor]): A list of four tensors of shape
|
||||
|
@ -128,7 +157,7 @@ class FPEM_FFM(BaseModule):
|
|||
``in_channels``.
|
||||
|
||||
Returns:
|
||||
list[Tensor]: Four tensors of shape
|
||||
tuple[Tensor]: Four tensors of shape
|
||||
:math:`(N, C_{out}, H_0, W_0)` where :math:`C_{out}` is
|
||||
``conv_out``.
|
||||
"""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
|
||||
bezier2polygon, is_on_same_line, rescale_bboxes,
|
||||
stitch_boxes_into_lines)
|
||||
bezier2polygon, is_on_same_line, rescale_bbox,
|
||||
rescale_bboxes, stitch_boxes_into_lines)
|
||||
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
|
||||
is_type_list, valid_boundary)
|
||||
from .collect_env import collect_env
|
||||
|
@ -34,17 +34,18 @@ __all__ = [
|
|||
'is_2dlist', 'valid_boundary', 'list_to_file', 'list_from_file',
|
||||
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStripper',
|
||||
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
|
||||
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
|
||||
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
|
||||
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
|
||||
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
|
||||
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
|
||||
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
|
||||
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
|
||||
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
|
||||
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
|
||||
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
|
||||
'ColorType', 'OptKIESampleList', 'KIESampleList', 'is_archive',
|
||||
'check_integrity', 'list_files', 'get_md5', 'InstanceList', 'LabelList',
|
||||
'OptInstanceList', 'OptLabelList', 'RangeType', 'remove_pipeline_elements'
|
||||
'rescale_polygons', 'rescale_polygon', 'rescale_bbox', 'rescale_bboxes',
|
||||
'bbox2poly', 'crop_polygon', 'is_poly_inside_rect', 'poly2bbox',
|
||||
'poly_intersection', 'poly_iou', 'poly_make_valid', 'poly_union',
|
||||
'poly2shapely', 'polys2shapely', 'register_all_modules', 'offset_polygon',
|
||||
'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
|
||||
'bbox_diag_distance', 'boundary_iou', 'point_distance', 'points_center',
|
||||
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
|
||||
'warp_img', 'ConfigType', 'DetSampleList', 'RecForwardResults',
|
||||
'InitConfigType', 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType',
|
||||
'OptMultiConfig', 'OptRecSampleList', 'RecSampleList', 'MultiConfig',
|
||||
'OptTensor', 'ColorType', 'OptKIESampleList', 'KIESampleList',
|
||||
'is_archive', 'check_integrity', 'list_files', 'get_md5', 'InstanceList',
|
||||
'LabelList', 'OptInstanceList', 'OptLabelList', 'RangeType',
|
||||
'remove_pipeline_elements'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.necks.fpem_ffm import FPEM, FPEM_FFM
|
||||
|
||||
|
||||
class TestFPEM(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.c2 = torch.Tensor(1, 8, 64, 64)
|
||||
self.c3 = torch.Tensor(1, 8, 32, 32)
|
||||
self.c4 = torch.Tensor(1, 8, 16, 16)
|
||||
self.c5 = torch.Tensor(1, 8, 8, 8)
|
||||
self.fpem = FPEM(in_channels=8)
|
||||
|
||||
def test_forward(self):
|
||||
neck = FPEM(in_channels=8)
|
||||
neck.init_weights()
|
||||
out = neck(self.c2, self.c3, self.c4, self.c5)
|
||||
self.assertTrue(out[0].shape == self.c2.shape)
|
||||
self.assertTrue(out[1].shape == self.c3.shape)
|
||||
self.assertTrue(out[2].shape == self.c4.shape)
|
||||
self.assertTrue(out[3].shape == self.c5.shape)
|
||||
|
||||
|
||||
class TestFPEM_FFM(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.c2 = torch.Tensor(1, 8, 64, 64)
|
||||
self.c3 = torch.Tensor(1, 16, 32, 32)
|
||||
self.c4 = torch.Tensor(1, 32, 16, 16)
|
||||
self.c5 = torch.Tensor(1, 64, 8, 8)
|
||||
self.in_channels = [8, 16, 32, 64]
|
||||
self.conv_out = 8
|
||||
self.features = [self.c2, self.c3, self.c4, self.c5]
|
||||
|
||||
def test_forward(self):
|
||||
neck = FPEM_FFM(in_channels=self.in_channels, conv_out=self.conv_out)
|
||||
neck.init_weights()
|
||||
out = neck(self.features)
|
||||
self.assertTrue(out[0].shape == torch.Size([1, 8, 64, 64]))
|
||||
self.assertTrue(out[1].shape == out[0].shape)
|
||||
self.assertTrue(out[2].shape == out[0].shape)
|
||||
self.assertTrue(out[3].shape == out[0].shape)
|
|
@ -5,8 +5,8 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from mmocr.utils import (bbox2poly, bbox_center_distance, bbox_diag_distance,
|
||||
bezier2polygon, is_on_same_line,
|
||||
stitch_boxes_into_lines)
|
||||
bezier2polygon, is_on_same_line, rescale_bbox,
|
||||
rescale_bboxes, stitch_boxes_into_lines)
|
||||
from mmocr.utils.bbox_utils import bbox_jitter
|
||||
|
||||
|
||||
|
@ -236,3 +236,31 @@ class TestStitchBoxesIntoLines(unittest.TestCase):
|
|||
result.sort(key=lambda x: x['box'][0])
|
||||
expected_result.sort(key=lambda x: x['box'][0])
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
class TestRescaleBbox(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.bbox = np.array([0, 0, 1, 1])
|
||||
self.bboxes = np.array([[0, 0, 1, 1], [1, 1, 2, 2]])
|
||||
self.scale = 2
|
||||
|
||||
def test_rescale_bbox(self):
|
||||
# mul
|
||||
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='mul')
|
||||
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 2, 2])))
|
||||
# div
|
||||
rescaled_bbox = rescale_bbox(self.bbox, self.scale, mode='div')
|
||||
self.assertTrue(np.allclose(rescaled_bbox, np.array([0, 0, 0.5, 0.5])))
|
||||
|
||||
def test_rescale_bboxes(self):
|
||||
# mul
|
||||
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='mul')
|
||||
self.assertTrue(
|
||||
np.allclose(rescaled_bboxes, np.array([[0, 0, 2, 2], [2, 2, 4,
|
||||
4]])))
|
||||
# div
|
||||
rescaled_bboxes = rescale_bboxes(self.bboxes, self.scale, mode='div')
|
||||
self.assertTrue(
|
||||
np.allclose(rescaled_bboxes,
|
||||
np.array([[0, 0, 0.5, 0.5], [0.5, 0.5, 1, 1]])))
|
||||
|
|
|
@ -46,3 +46,9 @@ def test_valid_boundary():
|
|||
assert utils.valid_boundary(x, False)
|
||||
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
||||
assert utils.valid_boundary(x, True)
|
||||
|
||||
|
||||
def test_equal_len():
|
||||
|
||||
assert utils.equal_len([1, 2, 3], [1, 2, 3])
|
||||
assert not utils.equal_len([1, 2, 3], [1, 2, 3, 4])
|
||||
|
|
Loading…
Reference in New Issue