[Feature] Add ResNetV1c. (#692)

* add ResNetV1c

* add unit tests

* fix lint

* update docstring

* fix lint
This commit is contained in:
Ezra-Yu 2022-02-23 11:36:33 +08:00 committed by GitHub
parent 43024cda73
commit 7fcaedcbfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 137 additions and 3 deletions

View File

@ -0,0 +1,17 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -40,6 +40,9 @@ The depth of representations is of central importance for many visual recognitio
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) | | ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) |
| ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) | | ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) |
| ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) | | ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) |
| ResNetV1C-50 | 25.58 | 4.36 | 77.01 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.log.json) |
| ResNetV1C-101 | 44.57 | 8.09 | 78.30 | 94.27 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.log.json) |
| ResNetV1C-152 | 60.21 | 11.82 | 78.76 | 94.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.log.json) |
| ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) | | ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) |
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) | | ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) | | ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |

View File

@ -298,3 +298,42 @@ Models:
Top 5 Accuracy: 93.80 Top 5 Accuracy: 93.80
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
- Name: resnetv1c50_8xb32_in1k
Metadata:
FLOPs: 4360000000
Parameters: 25580000
In Collection: ResNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 77.01
Top 5 Accuracy: 93.58
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth
Config: configs/resnet/resnetv1c50_8xb32_in1k.py
- Name: resnetv1c101_8xb32_in1k
Metadata:
FLOPs: 8090000000
Parameters: 44570000
In Collection: ResNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.30
Top 5 Accuracy: 94.27
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth
Config: configs/resnet/resnetv1c101_8xb32_in1k.py
- Name: resnetv1c152_8xb32_in1k
Metadata:
FLOPs: 11820000000
Parameters: 60210000
In Collection: ResNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.76
Top 5 Accuracy: 94.41
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
Config: configs/resnet/resnetv1c152_8xb32_in1k.py

View File

@ -0,0 +1,7 @@
_base_ = [
'../_base_/models/resnetv1c50.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(backbone=dict(depth=101))

View File

@ -0,0 +1,7 @@
_base_ = [
'../_base_/models/resnetv1c50.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(backbone=dict(depth=152))

View File

@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/resnetv1c50.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]

View File

@ -13,7 +13,7 @@ from .regnet import RegNet
from .repvgg import RepVGG from .repvgg import RepVGG
from .res2net import Res2Net from .res2net import Res2Net
from .resnest import ResNeSt from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnet_cifar import ResNet_CIFAR from .resnet_cifar import ResNet_CIFAR
from .resnext import ResNeXt from .resnext import ResNeXt
from .seresnet import SEResNet from .seresnet import SEResNet
@ -34,5 +34,5 @@ __all__ = [
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG', 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
'EfficientNet', 'ConvNeXt', 'HRNet' 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c'
] ]

View File

@ -653,6 +653,22 @@ class ResNet(BaseBackbone):
m.eval() m.eval()
@BACKBONES.register_module()
class ResNetV1c(ResNet):
"""ResNetV1c backbone.
This variant is described in `Bag of Tricks.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
in the input stem with three 3x3 convs.
"""
def __init__(self, **kwargs):
super(ResNetV1c, self).__init__(
deep_stem=True, avg_down=False, **kwargs)
@BACKBONES.register_module() @BACKBONES.register_module()
class ResNetV1d(ResNet): class ResNetV1d(ResNet):
"""ResNetV1d backbone. """ResNetV1d backbone.

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import ResNet, ResNetV1d from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer, from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
get_expansion) get_expansion)
@ -526,6 +526,45 @@ def test_resnet():
assert not all_zeros(m.norm2) assert not all_zeros(m.norm2)
def test_resnet_v1c():
model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
assert len(model.stem) == 3
for i in range(3):
assert isinstance(model.stem[i], ConvModule)
imgs = torch.randn(1, 3, 224, 224)
feat = model.stem(imgs)
assert feat.shape == (1, 64, 112, 112)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# Test ResNet50V1d with first stage frozen
frozen_stages = 1
model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
assert len(model.stem) == 3
for i in range(3):
assert isinstance(model.stem[i], ConvModule)
model.init_weights()
model.train()
check_norm_state(model.stem, False)
for param in model.stem.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
def test_resnet_v1d(): def test_resnet_v1d():
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3)) model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights() model.init_weights()

View File

@ -4,6 +4,7 @@ import copy
import os import os
import os.path as osp import os.path as osp
import time import time
import warnings
import mmcv import mmcv
import torch import torch