[Feature]Add Vit (#214)

* add imagenet bs 4096

* add vit_base_patch16_224_finetune

* add vit_base_patch16_224_pretrain

* add vit_base_patch16_384_finetune

* add vit_base_patch16_384_finetune

* add vit_b_p16_224_finetune_imagenet

* add vit_b_p16_224_pretrain_imagenet

* add vit_b_p16_384_finetune_imagenet

* add vit

* add vit

* add vit head

* vit unitest

* keep up with ClsHead

* test vit

* add flag to determiine whether to calculate acc during training

* Changes related to mmcv1.3.0

* change checkpoint saving interval to 10

* add label smooth

* default_runtime.py recovery

* docformatter

* docformatter

* delete 2 lines of comments

* delete configs/_base_/schedules/imagenet_bs4096.py

* add configs/_base_/schedules/imagenet_bs2048_AdamW.py

* rename imagenet_bs4096.py to imagenet_bs2048_AdamW.py

* add helpers.py

* test vit hybrid backbone

* fix HybridEmbed

* use to_2tuple instead
pull/220/head
whcao 2021-04-16 19:22:41 +08:00 committed by GitHub
parent 7d618e6606
commit affb39fe07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 811 additions and 3 deletions

View File

@ -0,0 +1,21 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=224,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.1),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -0,0 +1,24 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=224,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.1,
attn_drop_rate=0.),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
hidden_dim=3072,
loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1),
topk=(1, 5),
),
train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))

View File

@ -0,0 +1,21 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=384,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.1),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -0,0 +1,13 @@
# optimizer
# In ClassyVision, the lr is set to 0.003 for bs4096.
# In this implementation(bs2048), lr = 0.003 / 4096 * (32bs * 64gpus) = 0.0015
optimizer = dict(type='AdamW', lr=0.0015, weight_decay=0.3)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0,
warmup='linear',
warmup_iters=10000,
warmup_ratio=1e-4)
runner = dict(type='EpochBasedRunner', max_epochs=300)

View File

@ -0,0 +1,10 @@
# Refer to pytorch-image-models
_base_ = [
'../_base_/models/vit_base_patch16_224_finetune.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_epochstep.py',
'../_base_/default_runtime.py'
]
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/vit_base_patch16_224_pretrain.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]

View File

@ -0,0 +1,21 @@
# Refer to pytorch-image-models
_base_ = [
'../_base_/models/vit_base_patch16_384_finetune.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_epochstep.py',
'../_base_/default_runtime.py'
]
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(384, -1), backend='pillow'),
dict(type='CenterCrop', crop_size=384),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(test=dict(pipeline=test_pipeline))

View File

@ -12,9 +12,10 @@ from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .vgg import VGG
from .vision_transformer import VisionTransformer
__all__ = [
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3', 'VisionTransformer'
]

View File

@ -0,0 +1,493 @@
import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
build_norm_layer, kaiming_init)
from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone
# Modified from mmdet
class FFN(nn.Module):
"""Implements feed-forward networks (FFNs) with residual connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Defaluts to 2.
act_cfg (dict, optional): The activation config for FFNs.
dropout (float, optional): Probability of an element to be
zeroed. Default 0.0.
add_residual (bool, optional): Add resudual connection.
Defaults to False.
"""
def __init__(self,
embed_dims,
feedforward_channels,
num_fcs=2,
act_cfg=dict(type='GELU'),
dropout=0.0,
add_residual=True):
super(FFN, self).__init__()
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = nn.ModuleList()
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
nn.Linear(in_channels, feedforward_channels),
self.activate, nn.Dropout(dropout)))
in_channels = feedforward_channels
layers.append(nn.Linear(feedforward_channels, embed_dims))
self.layers = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
self.add_residual = add_residual
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
# xavier_init(m, distribution='uniform')
# Bias init is different from our API
# therefore initialize them separately
# The initialization is sync with ClassyVision
nn.init.xavier_normal_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x, residual=None):
"""Forward function for `FFN`."""
out = self.layers(x)
if not self.add_residual:
return out
if residual is None:
residual = x
return residual + self.dropout(out)
def __repr__(self):
"""str: a string that describes the module"""
repr_str = self.__class__.__name__
repr_str += f'(embed_dims={self.embed_dims}, '
repr_str += f'feedforward_channels={self.feedforward_channels}, '
repr_str += f'num_fcs={self.num_fcs}, '
repr_str += f'act_cfg={self.act_cfg}, '
repr_str += f'dropout={self.dropout}, '
repr_str += f'add_residual={self.add_residual})'
return repr_str
# Modified from mmdet
class MultiheadAttention(nn.Module):
"""A warpper for torch.nn.MultiheadAttention.
This module implements MultiheadAttention with residual connection.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Same as
`nn.MultiheadAttention`.
attn_drop (float): A Dropout layer on attn_output_weights. Default 0.0.
proj_drop (float): The drop out rate after attention. Default 0.0.
"""
def __init__(self, embed_dims, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MultiheadAttention, self).__init__()
assert embed_dims % num_heads == 0, 'embed_dims must be ' \
f'divisible by num_heads. got {embed_dims} and {num_heads}.'
self.embed_dims = embed_dims
self.num_heads = num_heads
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop)
self.dropout = nn.Dropout(proj_drop)
def forward(self,
x,
key=None,
value=None,
residual=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None):
"""Forward function for `MultiheadAttention`.
Args:
x (Tensor): The input query with shape [num_query, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
key (Tensor): The key tensor with shape [num_key, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
Default None. If None, the `query` will be used.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Default None.
If None, the `key` will be used.
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. Default None. If not None, it will
be added to `x` before forward function.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Default None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`.
attn_mask (Tensor): ByteTensor mask with shape [num_query,
num_key]. Same in `nn.MultiheadAttention.forward`.
Default None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
Same in `nn.MultiheadAttention.forward`. Default None.
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
query = x
if key is None:
key = query
if value is None:
value = key
if residual is None:
residual = x
if key_pos is None:
if query_pos is not None and key is not None:
if query_pos.shape == key.shape:
key_pos = query_pos
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
out = self.attn(
query,
key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
return residual + self.dropout(out)
# Modified from mmdet
class TransformerEncoderLayer(nn.Module):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension. Same as `FFN`.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
attn_drop (float): The drop out rate for attention layer.
Default 0.0.
proj_drop (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.0.
act_cfg (dict): The activation config for FFNs. Defalut GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization.
num_fcs (int): The number of fully-connected layers for FFNs.
Default 2.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
attn_drop=0.,
proj_drop=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
num_fcs=2):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
embed_dims,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.mlp = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
proj_drop)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
norm_x = self.norm1(x)
# Reason for permute: as the shape of input from pretrained weight
# from pytorch-image-models is [batch_size, num_query, embed_dim],
# but the one from nn.MultiheadAttention is
# [num_query, batch_size, embed_dim]
x = x.permute(1, 0, 2)
norm_x = norm_x.permute(1, 0, 2)
x = self.attn(norm_x, residual=x)
# Convert the shape back to [batch_size, num_query, embed_dim] in
# order to make use of the pretrained weight
x = x.permute(1, 0, 2)
x = self.mlp(self.norm2(x), residual=x)
return x
# Modified from pytorch-image-models
class PatchEmbed(nn.Module):
"""Image to Patch Embedding.
Args:
img_size (int | tuple): The size of input image.
patch_size (int): The size of one patch
in_channels (int): The num of input channels.
embed_dim (int): The dimensions of embedding.
conv_cfg (dict | None): The config dict for conv layers.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dim=768,
conv_cfg=None):
super(PatchEmbed, self).__init__()
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
# img_size = tuple(repeat(img_size, 2))
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
# img_size = tuple(repeat(img_size[0], 2))
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.patch_size = to_2tuple(patch_size)
num_patches = (self.img_size[1] // self.patch_size[1]) * (
self.img_size[0] // self.patch_size[0])
assert num_patches * self.patch_size[0] * self.patch_size[1] == \
self.img_size[0] * self.img_size[1], \
'The image size H*W must be divisible by patch size'
self.num_patches = num_patches
# Use conv layer to embed
self.projection = build_conv_layer(
conv_cfg,
in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size)
self.init_weights()
def init_weights(self):
# Lecun norm from ClassyVision
kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't " \
f'match model ({self.img_size[0]}*{self.img_size[1]}).'
# The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
x = self.projection(x).flatten(2).transpose(1, 2)
return x
class HybridEmbed(nn.Module):
"""CNN Feature Map Embedding.
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self,
backbone,
img_size=224,
feature_size=None,
in_channels=3,
embed_dim=768,
conv_cfg=None):
super().__init__()
assert isinstance(backbone, nn.Module)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of
# determining the exact dim of the output feature
# map for all networks, the feature metadata has
# reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of
# each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(
torch.zeros(1, in_channels, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
# last feature if backbone outputs list/tuple of features
o = o[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
# Use conv layer to embed
self.projection = build_conv_layer(
conv_cfg, feature_dim, embed_dim, kernel_size=1, stride=1)
self.init_weights()
def init_weights(self):
# Lecun norm from ClassyVision
kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
# last feature if backbone outputs list/tuple of features
x = x[-1]
x = self.projection(x).flatten(2).transpose(1, 2)
return x
# Modified from pytorch-image-models and mmdet
@BACKBONES.register_module()
class VisionTransformer(BaseBackbone):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
Args:
num_layers (int): Depth of transformer
embed_dim (int): Embedding dimension
num_heads (int): Number of attention heads
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
in_channels (int): Number of input channels
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0.
attn_drop (float): The drop out rate for attention layer.
Default 0.0.
hybrid_backbone (nn.Module): CNN backbone to use in-place of
PatchEmbed module. Default None.
norm_cfg
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization.
act_cfg (dict): The activation config for FFNs. Defalut GELU.
num_fcs (int): The number of fully-connected layers for FFNs.
Default 2.
"""
def __init__(self,
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=224,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.,
attn_drop_rate=0.,
hybrid_backbone=None,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2):
super(VisionTransformer, self).__init__()
self.embed_dim = embed_dim
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone,
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dim,
num_heads,
feedforward_channels,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
num_fcs=num_fcs))
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dim, postfix=1)
self.add_module(self.norm1_name, norm1)
self.init_weights()
def init_weights(self, pretrained=None):
super(VisionTransformer, self).init_weights(pretrained)
if pretrained is None:
# Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
for layer in self.layers:
x = layer(x)
x = self.norm1(x)[:, 0]
return x

View File

@ -2,7 +2,9 @@ from .cls_head import ClsHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead'
'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead',
'VisionTransformerClsHead'
]

View File

@ -0,0 +1,84 @@
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, constant_init, kaiming_init
from ..builder import HEADS
from .cls_head import ClsHead
@HEADS.register_module()
class VisionTransformerClsHead(ClsHead):
"""Vision Transformer classifier head.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
hidden_dim (int): Number of the dimensions for hidden layer. Only
available during pre-training. Default None.
act_cfg (dict): The activation config. Only available during
pre-training. Defalut Tanh.
loss (dict): Config of classification loss.
topk (int | tuple): Top-k accuracy.
cal_acc (bool): Whether to calculate accuracy during training.
If mixup is used, this should be False. Default False.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
hidden_dim=None,
act_cfg=dict(type='Tanh'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
cal_acc=False):
super(VisionTransformerClsHead, self).__init__(
loss=loss, topk=topk, cal_acc=cal_acc)
self.in_channels = in_channels
self.num_classes = num_classes
self.hidden_dim = hidden_dim
self.act_cfg = act_cfg
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self._init_layers()
def _init_layers(self):
if self.hidden_dim is None:
layers = [('head', nn.Linear(self.in_channels, self.num_classes))]
else:
layers = [
('pre_logits', nn.Linear(self.in_channels, self.hidden_dim)),
('act', build_activation_layer(self.act_cfg)),
('head', nn.Linear(self.hidden_dim, self.num_classes)),
]
self.layers = nn.Sequential(OrderedDict(layers))
def init_weights(self):
# Modified from ClassyVision
if hasattr(self.layers, 'pre_logits'):
# Lecun norm
kaiming_init(
self.layers.pre_logits, mode='fan_in', nonlinearity='linear')
constant_init(self.layers.head, 0)
def simple_test(self, img):
"""Test without augmentation."""
cls_score = self.layers(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
pred = list(pred.detach().cpu().numpy())
return pred
def forward_train(self, x, gt_label):
cls_score = self.layers(x)
losses = self.loss(cls_score, gt_label)
return losses

View File

@ -1,5 +1,6 @@
from .channel_shuffle import channel_shuffle
from .cutmix import BatchCutMixLayer
from .helpers import to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
from .mixup import BatchMixupLayer
@ -7,5 +8,6 @@ from .se_layer import SELayer
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'BatchMixupLayer',
'BatchCutMixLayer', 'SELayer'
'BatchCutMixLayer', 'SELayer', 'to_ntuple', 'to_2tuple', 'to_3tuple',
'to_4tuple'
]

View File

@ -0,0 +1,20 @@
import collections.abc
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

View File

@ -0,0 +1,57 @@
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import VGG, VisionTransformer
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_vit_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = VisionTransformer()
model.init_weights(pretrained=0)
# Test ViT base model with input size of 224
# and patch size of 16
model = VisionTransformer()
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 768))
def test_vit_hybrid_backbone():
# Test VGG11+ViT-B/16 hybrid model
backbone = VGG(11, norm_eval=True)
backbone.init_weights()
model = VisionTransformer(hybrid_backbone=backbone)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size((1, 768))

View File

@ -82,3 +82,36 @@ def test_image_classifier_with_label_smooth_loss():
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
def test_image_classifier_vit():
model_cfg = dict(
backbone=dict(
type='VisionTransformer',
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=224,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.1,
attn_drop_rate=0.),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
hidden_dim=3072,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True),
topk=(1, 5),
),
train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0