mirror of https://github.com/open-mmlab/mmocr.git
[PSE] PSE Neck FPNF
parent
05990c58d9
commit
00ba46b5b9
|
@ -1,8 +1,11 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import BaseModule, ModuleList, auto_fp16
|
from mmcv.runner import BaseModule, ModuleList, auto_fp16
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
@ -14,18 +17,24 @@ class FPNF(BaseModule):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (list[int]): A list of number of input channels.
|
in_channels (list[int]): A list of number of input channels.
|
||||||
|
Defaults to [256, 512, 1024, 2048].
|
||||||
out_channels (int): The number of output channels.
|
out_channels (int): The number of output channels.
|
||||||
|
Defaults to 256.
|
||||||
fusion_type (str): Type of the final feature fusion layer. Available
|
fusion_type (str): Type of the final feature fusion layer. Available
|
||||||
options are "concat" and "add".
|
options are "concat" and "add". Defaults to "concat".
|
||||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||||
|
Defaults to
|
||||||
|
dict(type='Xavier', layer='Conv2d', distribution='uniform')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
in_channels=[256, 512, 1024, 2048],
|
self,
|
||||||
out_channels=256,
|
in_channels: List[int] = [256, 512, 1024, 2048],
|
||||||
fusion_type='concat',
|
out_channels: int = 256,
|
||||||
init_cfg=dict(
|
fusion_type: str = 'concat',
|
||||||
type='Xavier', layer='Conv2d', distribution='uniform')):
|
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
|
||||||
|
type='Xavier', layer='Conv2d', distribution='uniform')
|
||||||
|
) -> None:
|
||||||
super().__init__(init_cfg=init_cfg)
|
super().__init__(init_cfg=init_cfg)
|
||||||
conv_cfg = None
|
conv_cfg = None
|
||||||
norm_cfg = dict(type='BN')
|
norm_cfg = dict(type='BN')
|
||||||
|
@ -80,7 +89,7 @@ class FPNF(BaseModule):
|
||||||
inplace=False)
|
inplace=False)
|
||||||
|
|
||||||
@auto_fp16()
|
@auto_fp16()
|
||||||
def forward(self, inputs):
|
def forward(self, inputs: List[Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inputs (list[Tensor]): Each tensor has the shape of
|
inputs (list[Tensor]): Each tensor has the shape of
|
||||||
|
@ -109,7 +118,7 @@ class FPNF(BaseModule):
|
||||||
# step 2: smooth level i-1
|
# step 2: smooth level i-1
|
||||||
laterals[i - 1] = self.fpn_convs[i - 1](laterals[i - 1])
|
laterals[i - 1] = self.fpn_convs[i - 1](laterals[i - 1])
|
||||||
|
|
||||||
# upsample and cont
|
# upsample and cat
|
||||||
bottom_shape = laterals[0].shape[2:]
|
bottom_shape = laterals[0].shape[2:]
|
||||||
for i in range(1, used_backbone_levels):
|
for i in range(1, used_backbone_levels):
|
||||||
laterals[i] = F.interpolate(
|
laterals[i] = F.interpolate(
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from nose_parameterized import parameterized
|
||||||
|
|
||||||
|
from mmocr.models.textdet.necks import FPNF
|
||||||
|
|
||||||
|
|
||||||
|
class TestFPNF(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
in_channels = [256, 512, 1024, 2048]
|
||||||
|
size = [112, 56, 28, 14]
|
||||||
|
inputs = []
|
||||||
|
for i in range(4):
|
||||||
|
inputs.append(torch.rand(1, in_channels[i], size[i], size[i]))
|
||||||
|
self.inputs = inputs
|
||||||
|
|
||||||
|
@parameterized.expand([('concat'), ('add')])
|
||||||
|
def test_forward(self, fusion_type):
|
||||||
|
fpnf = FPNF(fusion_type=fusion_type)
|
||||||
|
outputs = fpnf.forward(self.inputs)
|
||||||
|
self.assertListEqual(list(outputs.size()), [1, 256, 112, 112])
|
Loading…
Reference in New Issue