[Feature]Add Vit (#214)
* add imagenet bs 4096
* add vit_base_patch16_224_finetune
* add vit_base_patch16_224_pretrain
* add vit_base_patch16_384_finetune
* add vit_b_p16_224_finetune_imagenet
* add vit_b_p16_224_pretrain_imagenet
* add vit_b_p16_384_finetune_imagenet
* add vit
* add vit head
* add vit unittest
* keep up with ClsHead
* test vit
* add flag to determine whether to calculate acc during training
* changes related to mmcv 1.3.0
* change checkpoint saving interval to 10
* add label smooth
* recover default_runtime.py
* run docformatter
* delete 2 lines of comments
* delete configs/_base_/schedules/imagenet_bs4096.py
* add configs/_base_/schedules/imagenet_bs2048_AdamW.py
* rename imagenet_bs4096.py to imagenet_bs2048_AdamW.py
* add helpers.py
* test vit hybrid backbone
* fix HybridEmbed
* use to_2tuple instead
parent 7d618e6606
commit affb39fe07
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        num_layers=12,
        embed_dim=768,
        num_heads=12,
        img_size=224,
        patch_size=16,
        in_channels=3,
        feedforward_channels=3072,
        drop_rate=0.1),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
@@ -0,0 +1,24 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        num_layers=12,
        embed_dim=768,
        num_heads=12,
        img_size=224,
        patch_size=16,
        in_channels=3,
        feedforward_channels=3072,
        drop_rate=0.1,
        attn_drop_rate=0.),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        hidden_dim=3072,
        loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1),
        topk=(1, 5),
    ),
    train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        num_layers=12,
        embed_dim=768,
        num_heads=12,
        img_size=384,
        patch_size=16,
        in_channels=3,
        feedforward_channels=3072,
        drop_rate=0.1),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
@@ -0,0 +1,13 @@
# optimizer
# In ClassyVision, the lr is set to 0.003 for bs4096.
# In this implementation (bs2048), lr = 0.003 / 4096 * (32 bs * 64 gpus) = 0.0015
optimizer = dict(type='AdamW', lr=0.0015, weight_decay=0.3)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=10000,
    warmup_ratio=1e-4)
runner = dict(type='EpochBasedRunner', max_epochs=300)
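The linear learning-rate scaling rule stated in the comment can be sanity-checked with a few lines of Python (a sketch of the arithmetic above, not part of the diff; variable names are mine):

# Linear LR scaling: keep lr / batch_size constant relative to the
# reference setting (lr=0.003 at batch size 4096, per the comment above).
base_lr, base_batch = 0.003, 4096
samples_per_gpu, num_gpus = 32, 64          # effective batch size 2048
scaled_lr = base_lr / base_batch * (samples_per_gpu * num_gpus)
print(scaled_lr)                            # 0.0015, the lr used by AdamW above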
@@ -0,0 +1,10 @@
# Refer to pytorch-image-models
_base_ = [
    '../_base_/models/vit_base_patch16_224_finetune.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_epochstep.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/vit_base_patch16_224_pretrain.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs2048_AdamW.py',
    '../_base_/default_runtime.py'
]
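These `_base_` entries compose through mmcv's config inheritance; one way to inspect the merged result is via `Config.fromfile` (a sketch, not part of the diff; the config path below is hypothetical and depends on where this file lands in the checkout):

from mmcv import Config

# Hypothetical path; adjust to the actual location of this config.
cfg = Config.fromfile('configs/vision_transformer/vit_b_p16_224_pretrain_imagenet.py')
print(cfg.model.backbone.type)   # 'VisionTransformer', merged from the model base
print(cfg.optimizer)             # the AdamW settings from imagenet_bs2048_AdamW.py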
@@ -0,0 +1,21 @@
# Refer to pytorch-image-models
_base_ = [
    '../_base_/models/vit_base_patch16_384_finetune.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_epochstep.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(384, -1), backend='pillow'),
    dict(type='CenterCrop', crop_size=384),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(test=dict(pipeline=test_pipeline))
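For reference, `size=(384, -1)` in this pipeline rescales the short edge to 384 while keeping the aspect ratio, after which the 384 center crop takes a square patch; a small sketch of that arithmetic (my reading of the config, not part of the diff):

def short_edge_resize(h, w, short=384):
    # Scale so the shorter side becomes `short`, preserving aspect ratio,
    # mirroring what Resize(size=(384, -1)) is meant to do here.
    scale = short / min(h, w)
    return round(h * scale), round(w * scale)

print(short_edge_resize(480, 640))  # (384, 512); CenterCrop(384) then yields 384x384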
@@ -12,9 +12,10 @@ from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .vgg import VGG
+from .vision_transformer import VisionTransformer

__all__ = [
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
-    'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3', 'VisionTransformer'
]
@@ -0,0 +1,493 @@
import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, kaiming_init)

from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone


# Modified from mmdet
class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be
            zeroed. Default 0.0.
        add_residual (bool, optional): Add residual connection.
            Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='GELU'),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2, got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = nn.ModuleList()
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_channels, feedforward_channels),
                    self.activate, nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(nn.Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Bias init is different from our API,
                # therefore initialize them separately.
                # The initialization is in sync with ClassyVision.
                nn.init.xavier_normal_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x, residual=None):
        """Forward function for `FFN`."""
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'add_residual={self.add_residual})'
        return repr_str
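A minimal usage sketch for the FFN block above (illustration only, not part of the diff; the token shape is an assumption):

ffn = FFN(embed_dims=768, feedforward_channels=3072, dropout=0.1)
tokens = torch.randn(197, 2, 768)   # [num_query, batch, embed_dims]
out = ffn(tokens)                   # two Linear layers + GELU, residual added
assert out.shape == tokens.shape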
# Modified from mmdet
class MultiheadAttention(nn.Module):
    """A wrapper for torch.nn.MultiheadAttention.

    This module implements MultiheadAttention with residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float): A Dropout layer on attn_output_weights. Default 0.0.
        proj_drop (float): The drop out rate after attention. Default 0.0.
    """

    def __init__(self, embed_dims, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiheadAttention, self).__init__()
        assert embed_dims % num_heads == 0, 'embed_dims must be ' \
            f'divisible by num_heads, got {embed_dims} and {num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop)
        self.dropout = nn.Dropout(proj_drop)

    def forward(self,
                x,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None):
        """Forward function for `MultiheadAttention`.

        Args:
            x (Tensor): The input query with shape [num_query, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_key, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                Default None. If None, the `query` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Default None.
                If None, the `key` will be used.
            residual (Tensor): The tensor used for addition, with the
                same shape as `x`. Default None. If None, `x` will be used.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. Default None. If not None, it will
                be added to `x` before forward function.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Default None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (Tensor): ByteTensor mask with shape [num_query,
                num_key]. Same in `nn.MultiheadAttention.forward`.
                Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `nn.MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        query = x
        if key is None:
            key = query
        if value is None:
            value = key
        if residual is None:
            residual = x
        if key_pos is None:
            if query_pos is not None and key is not None:
                if query_pos.shape == key.shape:
                    key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        out = self.attn(
            query,
            key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        return residual + self.dropout(out)
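Likewise, a short self-attention sketch for the wrapper above (not part of the diff; note the sequence-first layout expected by `nn.MultiheadAttention`):

attn = MultiheadAttention(embed_dims=768, num_heads=12)
q = torch.randn(197, 2, 768)   # [num_query, batch, embed_dims]
out = attn(q)                  # key/value default to the query; residual added
assert out.shape == q.shape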
# Modified from mmdet
class TransformerEncoderLayer(nn.Module):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension. Same as `FFN`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        attn_drop (float): The drop out rate for attention layer.
            Default 0.0.
        proj_drop (float): Probability of an element to be zeroed
            after the feed forward layer. Default 0.0.
        act_cfg (dict): The activation config for FFNs. Default GELU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 attn_drop=0.,
                 proj_drop=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerEncoderLayer, self).__init__()
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.attn = MultiheadAttention(
            embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)
        self.mlp = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       proj_drop)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        norm_x = self.norm1(x)
        # Reason for permute: the pretrained weights from
        # pytorch-image-models expect inputs of shape
        # [batch_size, num_query, embed_dim], but nn.MultiheadAttention
        # expects [num_query, batch_size, embed_dim].
        x = x.permute(1, 0, 2)
        norm_x = norm_x.permute(1, 0, 2)
        x = self.attn(norm_x, residual=x)
        # Convert the shape back to [batch_size, num_query, embed_dim] in
        # order to make use of the pretrained weights.
        x = x.permute(1, 0, 2)
        x = self.mlp(self.norm2(x), residual=x)
        return x
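The permutes above mean the layer is batch-first at its boundary even though `nn.MultiheadAttention` is sequence-first internally; a quick shape check (a sketch, not part of the diff):

layer = TransformerEncoderLayer(
    embed_dims=768, num_heads=12, feedforward_channels=3072)
x = torch.randn(2, 197, 768)   # [batch, num_query, embed_dim] at the boundary
y = layer(x)                   # pre-norm attention + pre-norm FFN, both residual
assert y.shape == x.shape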
# Modified from pytorch-image-models
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Args:
        img_size (int | tuple): The size of input image.
        patch_size (int): The size of one patch.
        in_channels (int): The num of input channels.
        embed_dim (int): The dimensions of embedding.
        conv_cfg (dict | None): The config dict for conv layers.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768,
                 conv_cfg=None):
        super(PatchEmbed, self).__init__()
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.patch_size = to_2tuple(patch_size)

        num_patches = (self.img_size[1] // self.patch_size[1]) * (
            self.img_size[0] // self.patch_size[0])
        assert num_patches * self.patch_size[0] * self.patch_size[1] == \
            self.img_size[0] * self.img_size[1], \
            'The image size H*W must be divisible by patch size'
        self.num_patches = num_patches

        # Use conv layer to embed
        self.projection = build_conv_layer(
            conv_cfg,
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size)

        self.init_weights()

    def init_weights(self):
        # Lecun norm from ClassyVision
        kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # The output size is (B, N, D), where N = H*W/P/P and D is embed_dim
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
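Worked numbers for the patch embedding (a sketch, not part of the diff): a 224x224 image cut into 16x16 patches yields 196 tokens of dimension 768.

patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
imgs = torch.randn(2, 3, 224, 224)
patches = patch_embed(imgs)                          # strided conv, then flatten
assert patch_embed.num_patches == (224 // 16) ** 2   # 196 patches
assert patches.shape == (2, 196, 768)                # (B, N, D)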
class HybridEmbed(nn.Module):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dim=768,
                 conv_cfg=None):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but the most reliable way of
                # determining the exact dim of the output feature
                # map for all networks. The feature metadata has
                # reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of
                # each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(
                    torch.zeros(1, in_channels, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    # last feature if backbone outputs list/tuple of features
                    o = o[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use conv layer to embed
        self.projection = build_conv_layer(
            conv_cfg, feature_dim, embed_dim, kernel_size=1, stride=1)

        self.init_weights()

    def init_weights(self):
        # Lecun norm from ClassyVision
        kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x

# Modified from pytorch-image-models and mmdet
@BACKBONES.register_module()
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words:
    Transformers for Image Recognition at Scale`
    (https://arxiv.org/abs/2010.11929).

    Args:
        num_layers (int): Depth of transformer.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        img_size (int | tuple): Input image size.
        patch_size (int | tuple): The patch size.
        in_channels (int): Number of input channels.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        hybrid_backbone (nn.Module): CNN backbone to use in place of
            PatchEmbed module. Default None.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        act_cfg (dict): The activation config for FFNs. Default GELU.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2.
    """

    def __init__(self,
                 num_layers=12,
                 embed_dim=768,
                 num_heads=12,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 feedforward_channels=3072,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 hybrid_backbone=None,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 num_fcs=2):
        super(VisionTransformer, self).__init__()
        self.embed_dim = embed_dim

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone,
                img_size=img_size,
                in_channels=in_channels,
                embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dim,
                    num_heads,
                    feedforward_channels,
                    attn_drop=attn_drop_rate,
                    proj_drop=drop_rate,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    num_fcs=num_fcs))

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dim, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.init_weights()

    def init_weights(self, pretrained=None):
        super(VisionTransformer, self).init_weights(pretrained)
        if pretrained is None:
            # Modified from ClassyVision
            nn.init.normal_(self.pos_embed, std=0.02)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm1(x)[:, 0]
        return x
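End to end, the backbone maps an image batch to the class-token feature; a sketch mirroring the unit test further down (not part of the diff):

model = VisionTransformer()        # ViT-B/16 defaults from the signature above
model.init_weights()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)                 # cls token after the final LayerNorm
assert feat.shape == (1, 768)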
@@ -2,7 +2,9 @@ from .cls_head import ClsHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
+from .vision_transformer_head import VisionTransformerClsHead

__all__ = [
-    'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead'
+    'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead',
+    'VisionTransformerClsHead'
]
@@ -0,0 +1,84 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, constant_init, kaiming_init

from ..builder import HEADS
from .cls_head import ClsHead


@HEADS.register_module()
class VisionTransformerClsHead(ClsHead):
    """Vision Transformer classifier head.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        hidden_dim (int): Number of the dimensions for hidden layer. Only
            available during pre-training. Default None.
        act_cfg (dict): The activation config. Only available during
            pre-training. Default Tanh.
        loss (dict): Config of classification loss.
        topk (int | tuple): Top-k accuracy.
        cal_acc (bool): Whether to calculate accuracy during training.
            If mixup is used, this should be False. Default False.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 hidden_dim=None,
                 act_cfg=dict(type='Tanh'),
                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 topk=(1, ),
                 cal_acc=False):
        super(VisionTransformerClsHead, self).__init__(
            loss=loss, topk=topk, cal_acc=cal_acc)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.act_cfg = act_cfg

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self._init_layers()

    def _init_layers(self):
        if self.hidden_dim is None:
            layers = [('head', nn.Linear(self.in_channels, self.num_classes))]
        else:
            layers = [
                ('pre_logits', nn.Linear(self.in_channels, self.hidden_dim)),
                ('act', build_activation_layer(self.act_cfg)),
                ('head', nn.Linear(self.hidden_dim, self.num_classes)),
            ]
        self.layers = nn.Sequential(OrderedDict(layers))

    def init_weights(self):
        # Modified from ClassyVision
        if hasattr(self.layers, 'pre_logits'):
            # Lecun norm
            kaiming_init(
                self.layers.pre_logits, mode='fan_in', nonlinearity='linear')
        constant_init(self.layers.head, 0)

    def simple_test(self, img):
        """Test without augmentation."""
        cls_score = self.layers(img)
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
        if torch.onnx.is_in_onnx_export():
            return pred
        pred = list(pred.detach().cpu().numpy())
        return pred

    def forward_train(self, x, gt_label):
        cls_score = self.layers(x)
        losses = self.loss(cls_score, gt_label)
        return losses
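The `hidden_dim` switch gives the head two shapes, matching the pretrain/finetune configs earlier in the diff; a sketch (not part of the diff):

# Pre-training configuration: pre_logits Linear + Tanh before the classifier.
pretrain_head = VisionTransformerClsHead(
    num_classes=1000, in_channels=768, hidden_dim=3072)
print(pretrain_head.layers)   # Sequential(pre_logits, act, head)

# Fine-tuning configuration: hidden_dim=None collapses to one Linear layer.
finetune_head = VisionTransformerClsHead(num_classes=1000, in_channels=768)
print(finetune_head.layers)   # Sequential(head)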
@@ -1,5 +1,6 @@
from .channel_shuffle import channel_shuffle
from .cutmix import BatchCutMixLayer
+from .helpers import to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
from .mixup import BatchMixupLayer
@@ -7,5 +8,6 @@ from .se_layer import SELayer

__all__ = [
    'channel_shuffle', 'make_divisible', 'InvertedResidual', 'BatchMixupLayer',
-    'BatchCutMixLayer', 'SELayer'
+    'BatchCutMixLayer', 'SELayer', 'to_ntuple', 'to_2tuple', 'to_3tuple',
+    'to_4tuple'
]
@@ -0,0 +1,20 @@
import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
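Quick checks of the helper's behavior (a sketch, not part of the diff):

assert to_2tuple(224) == (224, 224)          # scalars are repeated n times
assert to_2tuple((224, 320)) == (224, 320)   # iterables pass through unchanged
assert to_ntuple(3)(7) == (7, 7, 7)          # to_ntuple builds a parser for any n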
@@ -0,0 +1,57 @@
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.backbones import VGG, VisionTransformer


def is_norm(modules):
    """Check if is one of the norms."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_vit_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = VisionTransformer()
        model.init_weights(pretrained=0)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer()
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))


def test_vit_hybrid_backbone():

    # Test VGG11+ViT-B/16 hybrid model
    backbone = VGG(11, norm_eval=True)
    backbone.init_weights()
    model = VisionTransformer(hybrid_backbone=backbone)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
@@ -82,3 +82,36 @@ def test_image_classifier_with_label_smooth_loss():

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0
+
+
+def test_image_classifier_vit():
+
+    model_cfg = dict(
+        backbone=dict(
+            type='VisionTransformer',
+            num_layers=12,
+            embed_dim=768,
+            num_heads=12,
+            img_size=224,
+            patch_size=16,
+            in_channels=3,
+            feedforward_channels=3072,
+            drop_rate=0.1,
+            attn_drop_rate=0.),
+        neck=None,
+        head=dict(
+            type='VisionTransformerClsHead',
+            num_classes=1000,
+            in_channels=768,
+            hidden_dim=3072,
+            loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True),
+            topk=(1, 5),
+        ),
+        train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
+    img_classifier = ImageClassifier(**model_cfg)
+    img_classifier.init_weights()
+    imgs = torch.randn(16, 3, 224, 224)
+    label = torch.randint(0, 1000, (16, ))
+
+    losses = img_classifier.forward_train(imgs, label)
+    assert losses['loss'].item() > 0