[PSE] PSE Neck FPNF

pull/1178/head
wangxinyu 2022-06-21 07:43:48 +00:00 committed by gaotongxiao
parent 05990c58d9
commit 00ba46b5b9
2 changed files with 42 additions and 9 deletions

View File

@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, auto_fp16
from torch import Tensor
from mmocr.registry import MODELS
@ -14,18 +17,24 @@ class FPNF(BaseModule):
Args:
in_channels (list[int]): A list of number of input channels.
Defaults to [256, 512, 1024, 2048].
out_channels (int): The number of output channels.
Defaults to 256.
fusion_type (str): Type of the final feature fusion layer. Available
options are "concat" and "add".
options are "concat" and "add". Defaults to "concat".
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to
dict(type='Xavier', layer='Conv2d', distribution='uniform')
"""
def __init__(self,
in_channels=[256, 512, 1024, 2048],
out_channels=256,
fusion_type='concat',
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
def __init__(
self,
in_channels: List[int] = [256, 512, 1024, 2048],
out_channels: int = 256,
fusion_type: str = 'concat',
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Xavier', layer='Conv2d', distribution='uniform')
) -> None:
super().__init__(init_cfg=init_cfg)
conv_cfg = None
norm_cfg = dict(type='BN')
@ -80,7 +89,7 @@ class FPNF(BaseModule):
inplace=False)
@auto_fp16()
def forward(self, inputs):
def forward(self, inputs: List[Tensor]) -> Tensor:
"""
Args:
inputs (list[Tensor]): Each tensor has the shape of
@ -109,7 +118,7 @@ class FPNF(BaseModule):
# step 2: smooth level i-1
laterals[i - 1] = self.fpn_convs[i - 1](laterals[i - 1])
# upsample and cont
# upsample and cat
bottom_shape = laterals[0].shape[2:]
for i in range(1, used_backbone_levels):
laterals[i] = F.interpolate(

View File

@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from nose_parameterized import parameterized
from mmocr.models.textdet.necks import FPNF
class TestFPNF(unittest.TestCase):
def setUp(self):
in_channels = [256, 512, 1024, 2048]
size = [112, 56, 28, 14]
inputs = []
for i in range(4):
inputs.append(torch.rand(1, in_channels[i], size[i], size[i]))
self.inputs = inputs
@parameterized.expand([('concat'), ('add')])
def test_forward(self, fusion_type):
fpnf = FPNF(fusion_type=fusion_type)
outputs = fpnf.forward(self.inputs)
self.assertListEqual(list(outputs.size()), [1, 256, 112, 112])