Add VGG and pretained models (#27)
* add vgg * add vgg model coversion tool * fix out_indices and docstr * add vgg models in configs * add params, flops and accuracy in docs * add pretrained models url * use ConvModule and refine var names * update vgg conversion tool * modify bn config * add docs for arch_setting * add unit test for vgg * rm debug code * update vgg pretrained modelspull/52/head
parent
99115fddbc
commit
bc1b08ba41
|
@ -0,0 +1,10 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='VGG', depth=11, num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='VGG', depth=11, norm_cfg=dict(type='BN'), num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,10 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='VGG', depth=13, num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='VGG', depth=13, norm_cfg=dict(type='BN'), num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,10 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='VGG', depth=16, num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='VGG', depth=16, norm_cfg=dict(type='BN'), num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,10 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='VGG', depth=19, num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='VGG', depth=19, norm_cfg=dict(type='BN'), num_classes=1000),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg11.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg11bn.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg13.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg13bn.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg16.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg16bn.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg19.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/vgg19bn.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
|
@ -8,6 +8,14 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
|
||||
| VGG-11 | 132.86 | 7.63 | 69.03 | 88.63 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11-01ecd97e.pth)* |
|
||||
| VGG-13 | 133.05 | 11.34 | 69.93 | 89.26 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13-9ad3945d.pth)*|
|
||||
| VGG-16 | 138.36 | 15.5 | 71.59 | 90.39 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16-91b6d117.pth)*|
|
||||
| VGG-19 | 143.67 | 19.67 | 72.38 | 90.88 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19-fee352a8.pth)*|
|
||||
| VGG-11-BN | 132.87 | 7.64 | 70.37 | 89.81 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11_bn-6fbbbf3f.pth)*|
|
||||
| VGG-13-BN | 133.05 | 11.36 | 71.55 | 90.37 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13_bn-4b5f9390.pth)*|
|
||||
| VGG-16-BN | 138.37 | 15.53 | 73.36 | 91.5 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16_bn-3ac6d8fd.pth)*|
|
||||
| VGG-19-BN | 143.68 | 19.7 | 74.24 | 91.84 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19_bn-7c058385.pth)*|
|
||||
| ResNet-18 | 11.69 | 1.82 | 70.07 | 89.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.log.json) |
|
||||
| ResNet-34 | 21.8 | 3.68 | 73.85 | 91.53 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.log.json) |
|
||||
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.log.json) |
|
||||
|
|
|
@ -11,9 +11,10 @@ from .seresnet import SEResNet
|
|||
from .seresnext import SEResNeXt
|
||||
from .shufflenet_v1 import ShuffleNetV1
|
||||
from .shufflenet_v2 import ShuffleNetV2
|
||||
from .vgg import VGG
|
||||
|
||||
__all__ = [
|
||||
'LeNet5', 'AlexNet', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt',
|
||||
'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2',
|
||||
'MobileNetV2', 'MobileNetv3'
|
||||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
def make_vgg_layer(in_channels,
|
||||
out_channels,
|
||||
num_blocks,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dilation=1,
|
||||
with_norm=False,
|
||||
ceil_mode=False):
|
||||
layers = []
|
||||
for _ in range(num_blocks):
|
||||
layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
dilation=dilation,
|
||||
padding=dilation,
|
||||
bias=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
layers.append(layer)
|
||||
in_channels = out_channels
|
||||
layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
|
||||
|
||||
return layers
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class VGG(BaseBackbone):
|
||||
"""VGG backbone.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of vgg, from {11, 13, 16, 19}.
|
||||
with_norm (bool): Use BatchNorm or not.
|
||||
num_classes (int): number of classes for classification.
|
||||
num_stages (int): VGG stages, normally 5.
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
out_indices (Sequence[int]): Output from which stages. If only one
|
||||
stage is specified, a single tensor (feature map) is returned,
|
||||
otherwise multiple stages are specified, a tuple of tensors will
|
||||
be returned. When it is None, the default behavior depends on
|
||||
whether num_classes is specified. If num_classes <= 0, the default
|
||||
value is (4, ), outputing the last feature map before classifier.
|
||||
If num_classes > 0, the default value is (5, ), outputing the
|
||||
classification score. Default: None.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
|
||||
not freezing any parameters.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
|
||||
with_last_pool (bool): Whether to keep the last pooling before
|
||||
classifier. Default: True.
|
||||
"""
|
||||
|
||||
# Parameters to build layers. Each element specifies the number of conv in
|
||||
# each stage. For example, VGG11 contains 11 layers with learnable
|
||||
# parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
|
||||
# where 3 indicates the last three fully-connected layers.
|
||||
arch_settings = {
|
||||
11: (1, 1, 2, 2, 2),
|
||||
13: (2, 2, 2, 2, 2),
|
||||
16: (2, 2, 3, 3, 3),
|
||||
19: (2, 2, 4, 4, 4)
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
depth,
|
||||
num_classes=-1,
|
||||
num_stages=5,
|
||||
dilations=(1, 1, 1, 1, 1),
|
||||
out_indices=None,
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_eval=False,
|
||||
ceil_mode=False,
|
||||
with_last_pool=True):
|
||||
super(VGG, self).__init__()
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for vgg')
|
||||
assert num_stages >= 1 and num_stages <= 5
|
||||
stage_blocks = self.arch_settings[depth]
|
||||
self.stage_blocks = stage_blocks[:num_stages]
|
||||
assert len(dilations) == num_stages
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.frozen_stages = frozen_stages
|
||||
self.norm_eval = norm_eval
|
||||
with_norm = norm_cfg is not None
|
||||
|
||||
if out_indices is None:
|
||||
out_indices = (5, ) if num_classes > 0 else (4, )
|
||||
assert max(out_indices) <= num_stages
|
||||
self.out_indices = out_indices
|
||||
|
||||
self.in_channels = 3
|
||||
start_idx = 0
|
||||
vgg_layers = []
|
||||
self.range_sub_modules = []
|
||||
for i, num_blocks in enumerate(self.stage_blocks):
|
||||
num_modules = num_blocks + 1
|
||||
end_idx = start_idx + num_modules
|
||||
dilation = dilations[i]
|
||||
out_channels = 64 * 2**i if i < 4 else 512
|
||||
vgg_layer = make_vgg_layer(
|
||||
self.in_channels,
|
||||
out_channels,
|
||||
num_blocks,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
dilation=dilation,
|
||||
with_norm=with_norm,
|
||||
ceil_mode=ceil_mode)
|
||||
vgg_layers.extend(vgg_layer)
|
||||
self.in_channels = out_channels
|
||||
self.range_sub_modules.append([start_idx, end_idx])
|
||||
start_idx = end_idx
|
||||
if not with_last_pool:
|
||||
vgg_layers.pop(-1)
|
||||
self.range_sub_modules[-1][1] -= 1
|
||||
self.module_name = 'features'
|
||||
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
|
||||
|
||||
if self.num_classes > 0:
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 7 * 7, 4096),
|
||||
nn.ReLU(True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(4096, 4096),
|
||||
nn.ReLU(True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(4096, num_classes),
|
||||
)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
super(VGG, self).init_weights(pretrained)
|
||||
if pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m)
|
||||
elif isinstance(m, _BatchNorm):
|
||||
constant_init(m, 1)
|
||||
elif isinstance(m, nn.Linear):
|
||||
normal_init(m, std=0.01)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
vgg_layers = getattr(self, self.module_name)
|
||||
for i in range(len(self.stage_blocks)):
|
||||
for j in range(*self.range_sub_modules[i]):
|
||||
vgg_layer = vgg_layers[j]
|
||||
x = vgg_layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
if self.num_classes > 0:
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
outs.append(x)
|
||||
if len(outs) == 1:
|
||||
return outs[0]
|
||||
else:
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
vgg_layers = getattr(self, self.module_name)
|
||||
for i in range(self.frozen_stages):
|
||||
for j in range(*self.range_sub_modules[i]):
|
||||
m = vgg_layers[j]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super(VGG, self).train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
|
@ -0,0 +1,136 @@
|
|||
import pytest
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import VGG
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_vgg():
|
||||
"""Test VGG backbone"""
|
||||
with pytest.raises(KeyError):
|
||||
# VGG depth should be in [11, 13, 16, 19]
|
||||
VGG(18)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# In VGG: 1 <= num_stages <= 5
|
||||
VGG(11, num_stages=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# In VGG: 1 <= num_stages <= 5
|
||||
VGG(11, num_stages=6)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(dilations) == num_stages
|
||||
VGG(11, dilations=(1, 1), num_stages=3)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = VGG(11)
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
# Test VGG11 norm_eval=True
|
||||
model = VGG(11, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test VGG11 forward without classifiers
|
||||
model = VGG(11, out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 128, 56, 56)
|
||||
assert feat[2].shape == (1, 256, 28, 28)
|
||||
assert feat[3].shape == (1, 512, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
|
||||
# Test VGG11 forward with classifiers
|
||||
model = VGG(11, num_classes=10, out_indices=(0, 1, 2, 3, 4, 5))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 6
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 128, 56, 56)
|
||||
assert feat[2].shape == (1, 256, 28, 28)
|
||||
assert feat[3].shape == (1, 512, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
assert feat[5].shape == (1, 10)
|
||||
|
||||
# Test VGG11BN forward
|
||||
model = VGG(11, norm_cfg=dict(type='BN'), out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 128, 56, 56)
|
||||
assert feat[2].shape == (1, 256, 28, 28)
|
||||
assert feat[3].shape == (1, 512, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
|
||||
# Test VGG11BN forward with classifiers
|
||||
model = VGG(
|
||||
11,
|
||||
num_classes=10,
|
||||
norm_cfg=dict(type='BN'),
|
||||
out_indices=(0, 1, 2, 3, 4, 5))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 6
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 128, 56, 56)
|
||||
assert feat[2].shape == (1, 256, 28, 28)
|
||||
assert feat[3].shape == (1, 512, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
assert feat[5].shape == (1, 10)
|
||||
|
||||
# Test VGG13 with layers 1, 2, 3 out forward
|
||||
model = VGG(13, out_indices=(0, 1, 2))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 128, 56, 56)
|
||||
assert feat[2].shape == (1, 256, 28, 28)
|
||||
|
||||
# Test VGG16 with top feature maps out forward
|
||||
model = VGG(16)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == (1, 512, 7, 7)
|
||||
|
||||
# Test VGG19 with classification score out forward
|
||||
model = VGG(19, num_classes=10)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == (1, 10)
|
|
@ -0,0 +1,117 @@
|
|||
import argparse
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_layer_maps(layer_num, with_bn):
|
||||
layer_maps = {'conv': {}, 'bn': {}}
|
||||
if with_bn:
|
||||
if layer_num == 11:
|
||||
layer_idxs = [0, 4, 8, 11, 15, 18, 22, 25]
|
||||
elif layer_num == 13:
|
||||
layer_idxs = [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]
|
||||
elif layer_num == 16:
|
||||
layer_idxs = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
|
||||
elif layer_num == 19:
|
||||
layer_idxs = [
|
||||
0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49
|
||||
]
|
||||
else:
|
||||
raise ValueError(f'Invalid number of layers: {layer_num}')
|
||||
for i, layer_idx in enumerate(layer_idxs):
|
||||
if i == 0:
|
||||
new_layer_idx = layer_idx
|
||||
else:
|
||||
new_layer_idx += int((layer_idx - layer_idxs[i - 1]) / 2)
|
||||
layer_maps['conv'][layer_idx] = new_layer_idx
|
||||
layer_maps['bn'][layer_idx + 1] = new_layer_idx
|
||||
else:
|
||||
if layer_num == 11:
|
||||
layer_idxs = [0, 3, 6, 8, 11, 13, 16, 18]
|
||||
new_layer_idxs = [0, 2, 4, 5, 7, 8, 10, 11]
|
||||
elif layer_num == 13:
|
||||
layer_idxs = [0, 2, 5, 7, 10, 12, 15, 17, 20, 22]
|
||||
new_layer_idxs = [0, 1, 3, 4, 6, 7, 9, 10, 12, 13]
|
||||
elif layer_num == 16:
|
||||
layer_idxs = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
|
||||
new_layer_idxs = [0, 1, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16]
|
||||
elif layer_num == 19:
|
||||
layer_idxs = [
|
||||
0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34
|
||||
]
|
||||
new_layer_idxs = [
|
||||
0, 1, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19
|
||||
]
|
||||
else:
|
||||
raise ValueError(f'Invalid number of layers: {layer_num}')
|
||||
|
||||
layer_maps['conv'] = {
|
||||
layer_idx: new_layer_idx
|
||||
for layer_idx, new_layer_idx in zip(layer_idxs, new_layer_idxs)
|
||||
}
|
||||
|
||||
return layer_maps
|
||||
|
||||
|
||||
def convert(src, dst, layer_num, with_bn=False):
|
||||
"""Convert keys in torchvision pretrained VGG models to mmcls
|
||||
style."""
|
||||
|
||||
# load pytorch model
|
||||
assert os.path.isfile(src), f'no checkpoint found at {src}'
|
||||
blobs = torch.load(src, map_location='cpu')
|
||||
|
||||
# convert to pytorch style
|
||||
state_dict = OrderedDict()
|
||||
|
||||
layer_maps = get_layer_maps(layer_num, with_bn)
|
||||
|
||||
prefix = 'backbone'
|
||||
delimiter = '.'
|
||||
for key, weight in blobs.items():
|
||||
if 'features' in key:
|
||||
module, layer_idx, weight_type = key.split(delimiter)
|
||||
new_key = delimiter.join([prefix, key])
|
||||
layer_idx = int(layer_idx)
|
||||
for layer_key, maps in layer_maps.items():
|
||||
if layer_idx in maps:
|
||||
new_layer_idx = maps[layer_idx]
|
||||
new_key = delimiter.join([
|
||||
prefix, 'features',
|
||||
str(new_layer_idx), layer_key, weight_type
|
||||
])
|
||||
state_dict[new_key] = weight
|
||||
print(f'Convert {key} to {new_key}')
|
||||
elif 'classifier' in key:
|
||||
new_key = delimiter.join([prefix, key])
|
||||
state_dict[new_key] = weight
|
||||
print(f'Convert {key} to {new_key}')
|
||||
else:
|
||||
state_dict[key] = weight
|
||||
|
||||
# save checkpoint
|
||||
checkpoint = dict()
|
||||
checkpoint['state_dict'] = state_dict
|
||||
torch.save(checkpoint, dst)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Convert model keys')
|
||||
parser.add_argument('src', help='src torchvision model path')
|
||||
parser.add_argument('dst', help='save path')
|
||||
parser.add_argument(
|
||||
'--bn', action='store_true', help='whether original vgg has BN')
|
||||
parser.add_argument(
|
||||
'--layer_num',
|
||||
type=int,
|
||||
choices=[11, 13, 16, 19],
|
||||
default=11,
|
||||
help='number of VGG layers')
|
||||
args = parser.parse_args()
|
||||
convert(args.src, args.dst, layer_num=args.layer_num, with_bn=args.bn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue