[Feature] support modified resnet structure used in oCLIP (#1458)

* support modified ResNet in CLIP and oCLIP

* update unit test for TestCLIPBottleneck; update docs

* Apply suggestions from code review

* fix

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
Wenqing Zhang 2022-11-03 17:54:15 +08:00 committed by GitHub
parent 1c06edc68f
commit f1dd437d8d
6 changed files with 228 additions and 1 deletion
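
For orientation before the diff: CLIPResNet is registered in MMOCR's MODELS registry below, so a config can reference it by type name. A minimal, hypothetical build sketch — the arguments shown match the defaults added in this commit, but any surrounding detector config would be an assumption, not part of the change:

from mmocr.registry import MODELS

# Hypothetical usage sketch: build the new backbone through the registry.
# 'CLIPResNet' is the name registered by @MODELS.register_module() below;
# depth, strides, deep_stem and avg_down default to the values in this commit.
backbone_cfg = dict(type='CLIPResNet', depth=50, strides=(1, 2, 2, 2))
backbone = MODELS.build(backbone_cfg)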


@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .clip_resnet import CLIPResNet
 from .unet import UNet
 
-__all__ = ['UNet']
+__all__ = ['UNet', 'CLIPResNet']


@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmdet.models.backbones import ResNet
from mmdet.models.backbones.resnet import Bottleneck

from mmocr.registry import MODELS


class CLIPBottleneck(Bottleneck):
    """Bottleneck for CLIPResNet.

    It is a Bottleneck variant used in the ResNet variant of CLIP. After the
    second convolution layer, there is an additional average pooling layer
    with kernel_size 2 and stride 2, which is added as a plugin when the
    input stride > 1. The stride of each convolution layer is always set
    to 1.

    Args:
        **kwargs: Keyword arguments for
            :class:`mmdet.models.backbones.resnet.Bottleneck`.
    """

    def __init__(self, **kwargs):
        stride = kwargs.get('stride', 1)
        kwargs['stride'] = 1
        plugins = kwargs.get('plugins', None)
        if stride > 1:
            if plugins is None:
                plugins = []
            plugins.insert(
                0,
                dict(
                    cfg=dict(type='mmocr.AvgPool2d', kernel_size=2),
                    position='after_conv2'))
            kwargs['plugins'] = plugins
        super().__init__(**kwargs)


@MODELS.register_module()
class CLIPResNet(ResNet):
    """Implement the ResNet variant used in `oCLIP
    <https://github.com/bytedance/oclip>`_.

    It is also the official structure in
    `CLIP <https://github.com/openai/CLIP>`_.

    Compared with the ResNetV1d structure, CLIPResNet replaces the max
    pooling layer at the end of the input stem with an average pooling
    layer.

    In the Bottleneck of CLIPResNet, after the second convolution layer,
    there is an additional average pooling layer with kernel_size 2 and
    stride 2, which is added as a plugin when the input stride > 1. The
    stride of each convolution layer is always set to 1.

    Args:
        depth (int): Depth of ResNet, options are [50]. Defaults to 50.
        strides (sequence(int)): Strides of the first block of each stage.
            Defaults to (1, 2, 2, 2).
        deep_stem (bool): Replace the 7x7 conv in the input stem with three
            3x3 convs. Defaults to True.
        avg_down (bool): Use AvgPool instead of stride conv at the
            downsampling stage in the bottleneck. Defaults to True.
        **kwargs: Keyword arguments for
            :class:`mmdet.models.backbones.resnet.ResNet`.
    """
    arch_settings = {
        50: (CLIPBottleneck, (3, 4, 6, 3)),
    }

    def __init__(self,
                 depth=50,
                 strides=(1, 2, 2, 2),
                 deep_stem=True,
                 avg_down=True,
                 **kwargs):
        super().__init__(
            depth=depth,
            strides=strides,
            deep_stem=deep_stem,
            avg_down=avg_down,
            **kwargs)

    def _make_stem_layer(self, in_channels: int, stem_channels: int):
        """Build the stem layer for CLIPResNet used in `CLIP
        <https://github.com/openai/CLIP>`_.

        It uses an average pooling layer rather than a max pooling layer
        at the end of the input stem.

        Args:
            in_channels (int): Number of input channels.
            stem_channels (int): Number of output channels.
        """
        super()._make_stem_layer(in_channels, stem_channels)
        if self.deep_stem:
            self.maxpool = nn.AvgPool2d(kernel_size=2)
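
A quick sketch of how the new backbone behaves end to end; the input and output shapes mirror those asserted in TestCLIPResNet further down in this diff (a 32x32 input yields four feature maps at strides 4, 8, 16 and 32):

import torch

from mmocr.models.common.backbones import CLIPResNet

model = CLIPResNet()  # depth=50, deep_stem=True, avg_down=True by default
model.eval()
feats = model(torch.randn(1, 3, 32, 32))
# feats is a tuple of 4 feature maps with shapes
# (1, 256, 8, 8), (1, 512, 4, 4), (1, 1024, 2, 2), (1, 2048, 1, 1)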


@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .common import AvgPool2d

__all__ = ['AvgPool2d']


@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from mmocr.registry import MODELS


@MODELS.register_module()
class AvgPool2d(nn.Module):
    """Applies a 2D average pooling over an input signal composed of
    several input planes.

    It can also be used as a network plugin.

    Args:
        kernel_size (int or tuple(int)): the size of the window.
        stride (int or tuple(int), optional): the stride of the window.
            Defaults to None.
        padding (int or tuple(int)): implicit zero padding. Defaults to 0.
    """

    def __init__(self,
                 kernel_size: Union[int, Tuple[int]],
                 stride: Optional[Union[int, Tuple[int]]] = None,
                 padding: Union[int, Tuple[int]] = 0,
                 **kwargs) -> None:
        super().__init__()
        self.model = nn.AvgPool2d(kernel_size, stride, padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output tensor after the average pooling layer.
        """
        return self.model(x)
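
Two ways this module can be used, matching how CLIPBottleneck wires it in above: directly as an nn.Module, or as a plugin cfg referenced by the registered 'mmocr.AvgPool2d' type string. A minimal sketch; the shapes mirror the plugin unit test at the end of this diff:

import torch

from mmocr.models.common.plugins import AvgPool2d

# Direct use: halves the spatial resolution of a feature map.
pool = AvgPool2d(kernel_size=2, stride=2)
out = pool(torch.rand(1, 3, 32, 100))
assert out.shape == torch.Size([1, 3, 16, 50])

# As a plugin config, as done in CLIPBottleneck above; the 'mmocr.' scope
# prefix points the registry lookup at the class registered in this file.
plugin = dict(
    cfg=dict(type='mmocr.AvgPool2d', kernel_size=2),
    position='after_conv2')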


@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer

from mmocr.models.common.backbones import CLIPResNet
from mmocr.models.common.backbones.clip_resnet import CLIPBottleneck


class TestCLIPResNet(TestCase):

    def test_forward(self):
        model = CLIPResNet()
        model.eval()

        imgs = torch.randn(1, 3, 32, 32)
        feat = model(imgs)
        assert len(feat) == 4
        assert feat[0].shape == torch.Size([1, 256, 8, 8])
        assert feat[1].shape == torch.Size([1, 512, 4, 4])
        assert feat[2].shape == torch.Size([1, 1024, 2, 2])
        assert feat[3].shape == torch.Size([1, 2048, 1, 1])


class TestCLIPBottleneck(TestCase):

    def test_forward(self):
        stride = 2
        inplanes = 256
        planes = 128
        conv_cfg = None
        norm_cfg = {'type': 'BN', 'requires_grad': True}

        downsample = []
        downsample.append(
            nn.AvgPool2d(
                kernel_size=stride,
                stride=stride,
                ceil_mode=True,
                count_include_pad=False))
        downsample.extend([
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * CLIPBottleneck.expansion,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1]
        ])
        downsample = nn.Sequential(*downsample)

        model = CLIPBottleneck(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            downsample=downsample,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        model.eval()

        input_feat = torch.randn(1, 256, 8, 8)
        output_feat = model(input_feat)
        assert output_feat.shape == torch.Size([1, 512, 4, 4])


@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmocr.models.common.plugins import AvgPool2d


class TestAvgPool2d(TestCase):

    def setUp(self) -> None:
        self.img = torch.rand(1, 3, 32, 100)

    def test_avgpool2d(self):
        avgpool2d = AvgPool2d(kernel_size=2, stride=2)
        self.assertEqual(avgpool2d(self.img).shape, torch.Size([1, 3, 16, 50]))