mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature] Add ResNetV1c. (#692)
* add ResNetV1c * add unit tests * fix lint * update docstring * fix lint
This commit is contained in:
parent
43024cda73
commit
7fcaedcbfb
17
configs/_base_/models/resnetv1c50.py
Normal file
17
configs/_base_/models/resnetv1c50.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='ResNetV1c',
|
||||||
|
depth=50,
|
||||||
|
num_stages=4,
|
||||||
|
out_indices=(3, ),
|
||||||
|
style='pytorch'),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='LinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=2048,
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5),
|
||||||
|
))
|
@ -40,6 +40,9 @@ The depth of representations is of central importance for many visual recognitio
|
|||||||
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) |
|
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) |
|
||||||
| ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) |
|
| ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) |
|
||||||
| ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) |
|
| ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) |
|
||||||
|
| ResNetV1C-50 | 25.58 | 4.36 | 77.01 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.log.json) |
|
||||||
|
| ResNetV1C-101 | 44.57 | 8.09 | 78.30 | 94.27 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.log.json) |
|
||||||
|
| ResNetV1C-152 | 60.21 | 11.82 | 78.76 | 94.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.log.json) |
|
||||||
| ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) |
|
| ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) |
|
||||||
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
|
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
|
||||||
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |
|
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |
|
||||||
|
@ -298,3 +298,42 @@ Models:
|
|||||||
Top 5 Accuracy: 93.80
|
Top 5 Accuracy: 93.80
|
||||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
|
||||||
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
|
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
|
||||||
|
- Name: resnetv1c50_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 4360000000
|
||||||
|
Parameters: 25580000
|
||||||
|
In Collection: ResNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 77.01
|
||||||
|
Top 5 Accuracy: 93.58
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth
|
||||||
|
Config: configs/resnet/resnetv1c50_8xb32_in1k.py
|
||||||
|
- Name: resnetv1c101_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 8090000000
|
||||||
|
Parameters: 44570000
|
||||||
|
In Collection: ResNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.30
|
||||||
|
Top 5 Accuracy: 94.27
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth
|
||||||
|
Config: configs/resnet/resnetv1c101_8xb32_in1k.py
|
||||||
|
- Name: resnetv1c152_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 11820000000
|
||||||
|
Parameters: 60210000
|
||||||
|
In Collection: ResNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.76
|
||||||
|
Top 5 Accuracy: 94.41
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
|
||||||
|
Config: configs/resnet/resnetv1c152_8xb32_in1k.py
|
||||||
|
7
configs/resnet/resnetv1c101_8xb32_in1k.py
Normal file
7
configs/resnet/resnetv1c101_8xb32_in1k.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/resnetv1c50.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
model = dict(backbone=dict(depth=101))
|
7
configs/resnet/resnetv1c152_8xb32_in1k.py
Normal file
7
configs/resnet/resnetv1c152_8xb32_in1k.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/resnetv1c50.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
model = dict(backbone=dict(depth=152))
|
5
configs/resnet/resnetv1c50_8xb32_in1k.py
Normal file
5
configs/resnet/resnetv1c50_8xb32_in1k.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/resnetv1c50.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||||
|
]
|
@ -13,7 +13,7 @@ from .regnet import RegNet
|
|||||||
from .repvgg import RepVGG
|
from .repvgg import RepVGG
|
||||||
from .res2net import Res2Net
|
from .res2net import Res2Net
|
||||||
from .resnest import ResNeSt
|
from .resnest import ResNeSt
|
||||||
from .resnet import ResNet, ResNetV1d
|
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||||
from .resnet_cifar import ResNet_CIFAR
|
from .resnet_cifar import ResNet_CIFAR
|
||||||
from .resnext import ResNeXt
|
from .resnext import ResNeXt
|
||||||
from .seresnet import SEResNet
|
from .seresnet import SEResNet
|
||||||
@ -34,5 +34,5 @@ __all__ = [
|
|||||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||||
'EfficientNet', 'ConvNeXt', 'HRNet'
|
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c'
|
||||||
]
|
]
|
||||||
|
@ -653,6 +653,22 @@ class ResNet(BaseBackbone):
|
|||||||
m.eval()
|
m.eval()
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
|
||||||
|
class ResNetV1c(ResNet):
|
||||||
|
"""ResNetV1c backbone.
|
||||||
|
|
||||||
|
This variant is described in `Bag of Tricks.
|
||||||
|
<https://arxiv.org/pdf/1812.01187.pdf>`_.
|
||||||
|
|
||||||
|
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
|
||||||
|
in the input stem with three 3x3 convs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(ResNetV1c, self).__init__(
|
||||||
|
deep_stem=True, avg_down=False, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class ResNetV1d(ResNet):
|
class ResNetV1d(ResNet):
|
||||||
"""ResNetV1d backbone.
|
"""ResNetV1d backbone.
|
||||||
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.backbones import ResNet, ResNetV1d
|
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
|
||||||
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
||||||
get_expansion)
|
get_expansion)
|
||||||
|
|
||||||
@ -526,6 +526,45 @@ def test_resnet():
|
|||||||
assert not all_zeros(m.norm2)
|
assert not all_zeros(m.norm2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resnet_v1c():
|
||||||
|
model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
assert len(model.stem) == 3
|
||||||
|
for i in range(3):
|
||||||
|
assert isinstance(model.stem[i], ConvModule)
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model.stem(imgs)
|
||||||
|
assert feat.shape == (1, 64, 112, 112)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 4
|
||||||
|
assert feat[0].shape == (1, 256, 56, 56)
|
||||||
|
assert feat[1].shape == (1, 512, 28, 28)
|
||||||
|
assert feat[2].shape == (1, 1024, 14, 14)
|
||||||
|
assert feat[3].shape == (1, 2048, 7, 7)
|
||||||
|
|
||||||
|
# Test ResNet50V1d with first stage frozen
|
||||||
|
frozen_stages = 1
|
||||||
|
model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
|
||||||
|
assert len(model.stem) == 3
|
||||||
|
for i in range(3):
|
||||||
|
assert isinstance(model.stem[i], ConvModule)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
check_norm_state(model.stem, False)
|
||||||
|
for param in model.stem.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
for i in range(1, frozen_stages + 1):
|
||||||
|
layer = getattr(model, f'layer{i}')
|
||||||
|
for mod in layer.modules():
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
assert mod.training is False
|
||||||
|
for param in layer.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
|
||||||
|
|
||||||
def test_resnet_v1d():
|
def test_resnet_v1d():
|
||||||
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
|
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
@ -4,6 +4,7 @@ import copy
|
|||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user