mirror of https://github.com/open-mmlab/mmocr.git
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS
from mmocr.utils import ConfigType, MultiConfig, OptConfigType


@MODELS.register_module()
class BiFPN(BaseModule):
"""illustration of a minimal bifpn unit P7_0 ------------------------->
|
|
P7_2 -------->
|
|
|
|
|-------------| ↑ ↓ |
|
|
P6_0 ---------> P6_1 ---------> P6_2 -------->
|
|
|-------------|--------------↑ ↑ ↓ | P5_0
|
|
---------> P5_1 ---------> P5_2 --------> |-------------|--------------↑
|
|
↑ ↓ | P4_0 ---------> P4_1 ---------> P4_2
|
|
--------> |-------------|--------------↑ ↑
|
|
|--------------↓ | P3_0 -------------------------> P3_2 -------->
|
|
"""
|
|
|
|
    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 repeat_times: int = 2,
                 start_level: int = 0,
                 end_level: int = -1,
                 add_extra_convs: bool = False,
                 relu_before_extra_convs: bool = False,
                 no_norm_on_lateral: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 laterial_conv1x1: bool = False,  # currently unused
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 pool_cfg: ConfigType = dict(),
                 init_cfg: MultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.upsample_cfg = upsample_cfg.copy()
        self.repeat_times = repeat_times

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        self.lateral_convs = nn.ModuleList()
        self.extra_convs = nn.ModuleList()
        self.bifpn_convs = nn.ModuleList()
        # 1x1 lateral convs project backbone features to `out_channels`;
        # an identity is enough when the channel numbers already match
        for i in range(self.start_level, self.backbone_end_level):
            if in_channels[i] == out_channels:
                l_conv = nn.Identity()
            else:
                l_conv = ConvModule(
                    in_channels[i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    act_cfg=act_cfg,
                    inplace=False)
            self.lateral_convs.append(l_conv)

        # stack `repeat_times` BiFPN layers; each refines all pyramid levels
        for _ in range(repeat_times):
            self.bifpn_convs.append(
                BiFPNLayer(
                    channels=out_channels,
                    levels=num_outs,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    pool_cfg=pool_cfg))

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0:
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                if in_channels == out_channels:
                    extra_fpn_conv = nn.MaxPool2d(
                        kernel_size=3, stride=2, padding=1)
                else:
                    extra_fpn_conv = nn.Sequential(
                        ConvModule(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=1,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg),
                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
                self.extra_convs.append(extra_fpn_conv)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build extra levels on top of the last backbone feature; each extra
        # conv consumes the previous output, and every output is kept
        if self.num_outs > len(laterals) and self.add_extra_convs:
            extra_source = inputs[self.backbone_end_level - 1]
            for extra_conv in self.extra_convs:
                extra_source = extra_conv(extra_source)
                laterals.append(extra_source)

        # refine all pyramid levels with the stacked BiFPN layers
        for bifpn_module in self.bifpn_convs:
            laterals = bifpn_module(laterals)

        return tuple(laterals)
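

# A minimal construction sketch (illustrative, not part of the original
# file): once this module is registered, the neck can typically be built
# from a config dict in the usual MMEngine style. The channel numbers below
# are assumptions, not values taken from any shipped config.
#
#   neck = MODELS.build(
#       dict(
#           type='BiFPN',
#           in_channels=[256, 512, 1024, 2048],
#           out_channels=256,
#           num_outs=5,
#           add_extra_convs=True))
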
def swish(x):
    """Swish activation, ``x * sigmoid(x)`` (also known as SiLU)."""
    return x * x.sigmoid()


class BiFPNLayer(BaseModule):
    """A single BiFPN stage: one top-down pass followed by one bottom-up
    pass, with learnable non-negative fusion weights."""

    def __init__(self,
                 channels,
                 levels,
                 init=0.5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=None,
                 pool_cfg=None,
                 eps=0.0001,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.act_cfg = act_cfg
        self.upsample_cfg = upsample_cfg
        self.pool_cfg = pool_cfg
        self.eps = eps
        self.levels = levels
        self.bifpn_convs = nn.ModuleList()
        # learnable fusion weights: w1 holds one weight pair per level for
        # the two-input fusions; w2 holds one weight triple per intermediate
        # level for the three-input bottom-up fusions
        self.weight_two_nodes = nn.Parameter(
            torch.Tensor(2, levels).fill_(init))
        self.weight_three_nodes = nn.Parameter(
            torch.Tensor(3, levels - 2).fill_(init))
        self.relu = nn.ReLU()
        # one 3x3 conv per fusion node: (levels - 1) nodes on the top-down
        # pass plus (levels - 1) nodes on the bottom-up pass
        for _ in range(2):
            for _ in range(self.levels - 1):
                fpn_conv = nn.Sequential(
                    ConvModule(
                        channels,
                        channels,
                        3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False))
                self.bifpn_convs.append(fpn_conv)

    def forward(self, inputs):
        assert len(inputs) == self.levels
        levels = self.levels
        # keep fusion weights non-negative with ReLU; w1 is normalized across
        # its two inputs here, w2 is normalized per node inside the loop below
        w1 = self.relu(self.weight_two_nodes)
        w1 /= torch.sum(w1, dim=0) + self.eps
        w2 = self.relu(self.weight_three_nodes)

        # build the top-down path in place; keep clones of the original
        # inputs for the skip connections of the bottom-up path
        idx_bifpn = 0
        pathtd = inputs
        inputs_clone = []
        for in_tensor in inputs:
            inputs_clone.append(in_tensor.clone())

        for i in range(levels - 1, 0, -1):
            _, _, h, w = pathtd[i - 1].shape
            pathtd[i - 1] = (
                w1[0, i - 1] * pathtd[i - 1] +
                w1[1, i - 1] * F.interpolate(
                    pathtd[i], size=(h, w), mode='nearest'))
            pathtd[i - 1] = swish(pathtd[i - 1])
            pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
            idx_bifpn += 1

        # build the bottom-up path: each intermediate node fuses the original
        # input, the top-down feature and the downsampled lower-level feature
        for i in range(0, levels - 2, 1):
            tmp_path = torch.stack([
                inputs_clone[i + 1], pathtd[i + 1],
                F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
            ], dim=-1)
            norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
            pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
            pathtd[i + 1] = swish(pathtd[i + 1])
            pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
            idx_bifpn += 1

        # the topmost level fuses only two inputs
        pathtd[levels - 1] = (
            w1[0, levels - 1] * pathtd[levels - 1] +
            w1[1, levels - 1] * F.max_pool2d(
                pathtd[levels - 2], kernel_size=3, stride=2, padding=1))
        pathtd[levels - 1] = swish(pathtd[levels - 1])
        pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
        return pathtd
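

if __name__ == '__main__':
    # Minimal smoke test, not part of the original file. The input
    # resolutions and channel widths below are illustrative assumptions.
    in_channels = [256, 512, 1024, 2048]
    feats = [
        torch.rand(1, c, 64 // 2**i, 64 // 2**i)
        for i, c in enumerate(in_channels)
    ]
    neck = BiFPN(
        in_channels=in_channels,
        out_channels=256,
        num_outs=5,
        add_extra_convs=True)
    with torch.no_grad():
        outs = neck(feats)
    # expect 5 pyramid levels, each with 256 channels
    for out in outs:
        print(tuple(out.shape))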