[Feature] Add ResNetV1c. (#692)
* add ResNetV1c * add unit tests * fix lint * update docstring * fix lintpull/710/head
parent
43024cda73
commit
7fcaedcbfb
|
@ -0,0 +1,17 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -40,6 +40,9 @@ The depth of representations is of central importance for many visual recognitio
|
|||
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) |
|
||||
| ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) |
|
||||
| ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) |
|
||||
| ResNetV1C-50 | 25.58 | 4.36 | 77.01 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.log.json) |
|
||||
| ResNetV1C-101 | 44.57 | 8.09 | 78.30 | 94.27 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.log.json) |
|
||||
| ResNetV1C-152 | 60.21 | 11.82 | 78.76 | 94.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.log.json) |
|
||||
| ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) |
|
||||
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
|
||||
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |
|
||||
|
|
|
@ -298,3 +298,42 @@ Models:
|
|||
Top 5 Accuracy: 93.80
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
|
||||
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
|
||||
- Name: resnetv1c50_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 4360000000
|
||||
Parameters: 25580000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 77.01
|
||||
Top 5 Accuracy: 93.58
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth
|
||||
Config: configs/resnet/resnetv1c50_8xb32_in1k.py
|
||||
- Name: resnetv1c101_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 8090000000
|
||||
Parameters: 44570000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.30
|
||||
Top 5 Accuracy: 94.27
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth
|
||||
Config: configs/resnet/resnetv1c101_8xb32_in1k.py
|
||||
- Name: resnetv1c152_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 11820000000
|
||||
Parameters: 60210000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.76
|
||||
Top 5 Accuracy: 94.41
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
|
||||
Config: configs/resnet/resnetv1c152_8xb32_in1k.py
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(backbone=dict(depth=101))
|
|
@ -0,0 +1,7 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(backbone=dict(depth=152))
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
|
@ -13,7 +13,7 @@ from .regnet import RegNet
|
|||
from .repvgg import RepVGG
|
||||
from .res2net import Res2Net
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1d
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnet_cifar import ResNet_CIFAR
|
||||
from .resnext import ResNeXt
|
||||
from .seresnet import SEResNet
|
||||
|
@ -34,5 +34,5 @@ __all__ = [
|
|||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet'
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c'
|
||||
]
|
||||
|
|
|
@ -653,6 +653,22 @@ class ResNet(BaseBackbone):
|
|||
m.eval()
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNetV1c(ResNet):
|
||||
"""ResNetV1c backbone.
|
||||
|
||||
This variant is described in `Bag of Tricks.
|
||||
<https://arxiv.org/pdf/1812.01187.pdf>`_.
|
||||
|
||||
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
|
||||
in the input stem with three 3x3 convs.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ResNetV1c, self).__init__(
|
||||
deep_stem=True, avg_down=False, **kwargs)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNetV1d(ResNet):
|
||||
"""ResNetV1d backbone.
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet, ResNetV1d
|
||||
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
|
||||
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
||||
get_expansion)
|
||||
|
||||
|
@ -526,6 +526,45 @@ def test_resnet():
|
|||
assert not all_zeros(m.norm2)
|
||||
|
||||
|
||||
def test_resnet_v1c():
|
||||
model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert len(model.stem) == 3
|
||||
for i in range(3):
|
||||
assert isinstance(model.stem[i], ConvModule)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model.stem(imgs)
|
||||
assert feat.shape == (1, 64, 112, 112)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, 256, 56, 56)
|
||||
assert feat[1].shape == (1, 512, 28, 28)
|
||||
assert feat[2].shape == (1, 1024, 14, 14)
|
||||
assert feat[3].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50V1d with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
|
||||
assert len(model.stem) == 3
|
||||
for i in range(3):
|
||||
assert isinstance(model.stem[i], ConvModule)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
check_norm_state(model.stem, False)
|
||||
for param in model.stem.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
|
||||
def test_resnet_v1d():
|
||||
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
|
|
|
@ -4,6 +4,7 @@ import copy
|
|||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
|
|
Loading…
Reference in New Issue