diff --git a/configs/_base_/models/resnetv1c50.py b/configs/_base_/models/resnetv1c50.py new file mode 100644 index 00000000..3b973e20 --- /dev/null +++ b/configs/_base_/models/resnetv1c50.py @@ -0,0 +1,17 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=2048, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) diff --git a/configs/resnet/README.md b/configs/resnet/README.md index 9811c00f..799c0a45 100644 --- a/configs/resnet/README.md +++ b/configs/resnet/README.md @@ -40,6 +40,9 @@ The depth of representations is of central importance for many visual recognitio | ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) | | ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) | | ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) | +| ResNetV1C-50 | 25.58 | 4.36 | 77.01 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.log.json) | +| ResNetV1C-101 | 44.57 | 8.09 | 78.30 | 94.27 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.log.json) | +| ResNetV1C-152 | 60.21 | 11.82 | 78.76 | 94.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.log.json) | | ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) | | ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) | | ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) | diff --git a/configs/resnet/metafile.yml b/configs/resnet/metafile.yml index 719b5d4e..61f52689 100644 --- a/configs/resnet/metafile.yml +++ b/configs/resnet/metafile.yml @@ -298,3 +298,42 @@ Models: Top 5 Accuracy: 93.80 Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py + - Name: resnetv1c50_8xb32_in1k + Metadata: + FLOPs: 4360000000 + Parameters: 25580000 + In Collection: ResNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 77.01 + Top 5 Accuracy: 93.58 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth + Config: configs/resnet/resnetv1c50_8xb32_in1k.py + - Name: resnetv1c101_8xb32_in1k + Metadata: + FLOPs: 8090000000 + Parameters: 44570000 + In Collection: ResNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 78.30 + Top 5 Accuracy: 94.27 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth + Config: configs/resnet/resnetv1c101_8xb32_in1k.py + - Name: resnetv1c152_8xb32_in1k + Metadata: + FLOPs: 11820000000 + Parameters: 60210000 + In Collection: ResNet + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 78.76 + Top 5 Accuracy: 94.41 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth + Config: configs/resnet/resnetv1c152_8xb32_in1k.py diff --git a/configs/resnet/resnetv1c101_8xb32_in1k.py b/configs/resnet/resnetv1c101_8xb32_in1k.py new file mode 100644 index 00000000..441aff59 --- /dev/null +++ b/configs/resnet/resnetv1c101_8xb32_in1k.py @@ -0,0 +1,7 @@ +_base_ = [ + '../_base_/models/resnetv1c50.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] + +model = dict(backbone=dict(depth=101)) diff --git a/configs/resnet/resnetv1c152_8xb32_in1k.py b/configs/resnet/resnetv1c152_8xb32_in1k.py new file mode 100644 index 00000000..b9f466f8 --- /dev/null +++ b/configs/resnet/resnetv1c152_8xb32_in1k.py @@ -0,0 +1,7 @@ +_base_ = [ + '../_base_/models/resnetv1c50.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] + +model = dict(backbone=dict(depth=152)) diff --git a/configs/resnet/resnetv1c50_8xb32_in1k.py b/configs/resnet/resnetv1c50_8xb32_in1k.py new file mode 100644 index 00000000..aa1c8b64 --- /dev/null +++ b/configs/resnet/resnetv1c50_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/resnetv1c50.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index bc5e5b01..05d20551 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -13,7 +13,7 @@ from .regnet import RegNet from .repvgg import RepVGG from .res2net import Res2Net from .resnest import ResNeSt -from .resnet import ResNet, ResNetV1d +from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnet_cifar import ResNet_CIFAR from .resnext import ResNeXt from .seresnet import SEResNet @@ -34,5 +34,5 @@ __all__ = [ 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT', - 'EfficientNet', 'ConvNeXt', 'HRNet' + 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c' ] diff --git a/mmcls/models/backbones/resnet.py b/mmcls/models/backbones/resnet.py index efb0e2b6..6e938c94 100644 --- a/mmcls/models/backbones/resnet.py +++ b/mmcls/models/backbones/resnet.py @@ -653,6 +653,22 @@ class ResNet(BaseBackbone): m.eval() +@BACKBONES.register_module() +class ResNetV1c(ResNet): + """ResNetV1c backbone. + + This variant is described in `Bag of Tricks. + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv + in the input stem with three 3x3 convs. + """ + + def __init__(self, **kwargs): + super(ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + @BACKBONES.register_module() class ResNetV1d(ResNet): """ResNetV1d backbone. diff --git a/tests/test_models/test_backbones/test_resnet.py b/tests/test_models/test_backbones/test_resnet.py index 44be2ff5..8ff8bc8f 100644 --- a/tests/test_models/test_backbones/test_resnet.py +++ b/tests/test_models/test_backbones/test_resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmcv.utils.parrots_wrapper import _BatchNorm -from mmcls.models.backbones import ResNet, ResNetV1d +from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer, get_expansion) @@ -526,6 +526,45 @@ def test_resnet(): assert not all_zeros(m.norm2) +def test_resnet_v1c(): + model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3)) + model.init_weights() + model.train() + + assert len(model.stem) == 3 + for i in range(3): + assert isinstance(model.stem[i], ConvModule) + + imgs = torch.randn(1, 3, 224, 224) + feat = model.stem(imgs) + assert feat.shape == (1, 64, 112, 112) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == (1, 256, 56, 56) + assert feat[1].shape == (1, 512, 28, 28) + assert feat[2].shape == (1, 1024, 14, 14) + assert feat[3].shape == (1, 2048, 7, 7) + + # Test ResNet50V1d with first stage frozen + frozen_stages = 1 + model = ResNetV1d(depth=50, frozen_stages=frozen_stages) + assert len(model.stem) == 3 + for i in range(3): + assert isinstance(model.stem[i], ConvModule) + model.init_weights() + model.train() + check_norm_state(model.stem, False) + for param in model.stem.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, f'layer{i}') + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False + + def test_resnet_v1d(): model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3)) model.init_weights() diff --git a/tools/train.py b/tools/train.py index bceb26de..b40a1df0 100644 --- a/tools/train.py +++ b/tools/train.py @@ -4,6 +4,7 @@ import copy import os import os.path as osp import time +import warnings import mmcv import torch