mirror of https://github.com/open-mmlab/mmocr.git
[Feature] support modified resnet structure used in oCLIP (#1458)
* support modified ResNet in CLIP and oCLIP * update unit test for TestCLIPBottleneck; update docs * Apply suggestions from code review * fix Co-authored-by: Tong Gao <gaotongxiao@gmail.com>pull/1517/head
parent
1c06edc68f
commit
f1dd437d8d
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .clip_resnet import CLIPResNet
|
||||
from .unet import UNet
|
||||
|
||||
__all__ = ['UNet']
|
||||
__all__ = ['UNet', 'CLIPResNet']
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch.nn as nn
|
||||
from mmdet.models.backbones import ResNet
|
||||
from mmdet.models.backbones.resnet import Bottleneck
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
class CLIPBottleneck(Bottleneck):
|
||||
"""Bottleneck for CLIPResNet.
|
||||
|
||||
It is a Bottleneck variant used in the ResNet variant of CLIP. After the
|
||||
second convolution layer, there is an additional average pooling layer with
|
||||
kernel_size 2 and stride 2, which is added as a plugin when the
|
||||
input stride > 1. The stride of each convolution layer is always set to 1.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments for
|
||||
:class:``mmdet.models.backbones.resnet.Bottleneck``.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
stride = kwargs.get('stride', 1)
|
||||
kwargs['stride'] = 1
|
||||
plugins = kwargs.get('plugins', None)
|
||||
if stride > 1:
|
||||
if plugins is None:
|
||||
plugins = []
|
||||
|
||||
plugins.insert(
|
||||
0,
|
||||
dict(
|
||||
cfg=dict(type='mmocr.AvgPool2d', kernel_size=2),
|
||||
position='after_conv2'))
|
||||
kwargs['plugins'] = plugins
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPResNet(ResNet):
|
||||
"""Implement the ResNet variant used in `oCLIP.
|
||||
|
||||
<https://github.com/bytedance/oclip>`_.
|
||||
|
||||
It is also the official structure in
|
||||
`CLIP <https://github.com/openai/CLIP>`_.
|
||||
|
||||
Compared with ResNetV1d structure, CLIPResNet replaces the
|
||||
max pooling layer with an average pooling layer at the end
|
||||
of the input stem.
|
||||
|
||||
In the Bottleneck of CLIPResNet, after the second convolution
|
||||
layer, there is an additional average pooling layer with
|
||||
kernel_size 2 and stride 2, which is added as a plugin
|
||||
when the input stride > 1.
|
||||
The stride of each convolution layer is always set to 1.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of resnet, options are [50]. Defaults to 50.
|
||||
strides (sequence(int)): Strides of the first block of each stage.
|
||||
Defaults to (1, 2, 2, 2).
|
||||
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
|
||||
Defaults to True.
|
||||
avg_down (bool): Use AvgPool instead of stride conv at
|
||||
the downsampling stage in the bottleneck. Defaults to True.
|
||||
**kwargs: Keyword arguments for
|
||||
:class:``mmdet.models.backbones.resnet.ResNet``.
|
||||
"""
|
||||
arch_settings = {
|
||||
50: (CLIPBottleneck, (3, 4, 6, 3)),
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
depth=50,
|
||||
strides=(1, 2, 2, 2),
|
||||
deep_stem=True,
|
||||
avg_down=True,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
depth=depth,
|
||||
strides=strides,
|
||||
deep_stem=deep_stem,
|
||||
avg_down=avg_down,
|
||||
**kwargs)
|
||||
|
||||
def _make_stem_layer(self, in_channels: int, stem_channels: int):
|
||||
"""Build stem layer for CLIPResNet used in `CLIP
|
||||
https://github.com/openai/CLIP>`_.
|
||||
|
||||
It uses an average pooling layer rather than a max pooling
|
||||
layer at the end of the input stem.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
stem_channels (int): Number of output channels.
|
||||
"""
|
||||
super()._make_stem_layer(in_channels, stem_channels)
|
||||
if self.deep_stem:
|
||||
self.maxpool = nn.AvgPool2d(kernel_size=2)
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .common import AvgPool2d
|
||||
|
||||
__all__ = ['AvgPool2d']
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class AvgPool2d(nn.Module):
|
||||
"""Applies a 2D average pooling over an input signal composed of several
|
||||
input planes.
|
||||
|
||||
It can also be used as a network plugin.
|
||||
|
||||
Args:
|
||||
kernel_size (int or tuple(int)): the size of the window.
|
||||
stride (int or tuple(int), optional): the stride of the window.
|
||||
Defaults to None.
|
||||
padding (int or tuple(int)): implicit zero padding. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Optional[Union[int, Tuple[int]]] = None,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.model = nn.AvgPool2d(kernel_size, stride, padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
x (Tensor): Input feature map.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor after Avgpooling layer.
|
||||
"""
|
||||
return self.model(x)
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmocr.models.common.backbones import CLIPResNet
|
||||
from mmocr.models.common.backbones.clip_resnet import CLIPBottleneck
|
||||
|
||||
|
||||
class TestCLIPResNet(TestCase):
|
||||
|
||||
def test_forward(self):
|
||||
model = CLIPResNet()
|
||||
model.eval()
|
||||
|
||||
imgs = torch.randn(1, 3, 32, 32)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 256, 8, 8])
|
||||
assert feat[1].shape == torch.Size([1, 512, 4, 4])
|
||||
assert feat[2].shape == torch.Size([1, 1024, 2, 2])
|
||||
assert feat[3].shape == torch.Size([1, 2048, 1, 1])
|
||||
|
||||
|
||||
class TestCLIPBottleneck(TestCase):
|
||||
|
||||
def test_forward(self):
|
||||
stride = 2
|
||||
inplanes = 256
|
||||
planes = 128
|
||||
conv_cfg = None
|
||||
norm_cfg = {'type': 'BN', 'requires_grad': True}
|
||||
|
||||
downsample = []
|
||||
downsample.append(
|
||||
nn.AvgPool2d(
|
||||
kernel_size=stride,
|
||||
stride=stride,
|
||||
ceil_mode=True,
|
||||
count_include_pad=False))
|
||||
downsample.extend([
|
||||
build_conv_layer(
|
||||
conv_cfg,
|
||||
inplanes,
|
||||
planes * CLIPBottleneck.expansion,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=False),
|
||||
build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1]
|
||||
])
|
||||
downsample = nn.Sequential(*downsample)
|
||||
|
||||
model = CLIPBottleneck(
|
||||
inplanes=inplanes,
|
||||
planes=planes,
|
||||
stride=stride,
|
||||
downsample=downsample,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
model.eval()
|
||||
|
||||
input_feat = torch.randn(1, 256, 8, 8)
|
||||
output_feat = model(input_feat)
|
||||
assert output_feat.shape == torch.Size([1, 512, 4, 4])
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.common.plugins import AvgPool2d
|
||||
|
||||
|
||||
class TestAvgPool2d(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.img = torch.rand(1, 3, 32, 100)
|
||||
|
||||
def test_avgpool2d(self):
|
||||
avgpool2d = AvgPool2d(kernel_size=2, stride=2)
|
||||
self.assertEqual(avgpool2d(self.img).shape, torch.Size([1, 3, 16, 50]))
|
Loading…
Reference in New Issue