diff --git a/configs/_base_/models/vit_base_patch32_384_finetune.py b/configs/_base_/models/vit_base_patch32_384_finetune.py
new file mode 100644
index 00000000..5017e390
--- /dev/null
+++ b/configs/_base_/models/vit_base_patch32_384_finetune.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        num_layers=12,
+        embed_dim=768,
+        num_heads=12,
+        img_size=384,
+        patch_size=32,
+        in_channels=3,
+        feedforward_channels=3072,
+        drop_rate=0.1),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vit_large_patch16_224_finetune.py b/configs/_base_/models/vit_large_patch16_224_finetune.py
new file mode 100644
index 00000000..62a11031
--- /dev/null
+++ b/configs/_base_/models/vit_large_patch16_224_finetune.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        num_layers=24,
+        embed_dim=1024,
+        num_heads=16,
+        img_size=224,
+        patch_size=16,
+        in_channels=3,
+        feedforward_channels=4096,
+        drop_rate=0.1),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vit_large_patch16_384_finetune.py b/configs/_base_/models/vit_large_patch16_384_finetune.py
new file mode 100644
index 00000000..6309f608
--- /dev/null
+++ b/configs/_base_/models/vit_large_patch16_384_finetune.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        num_layers=24,
+        embed_dim=1024,
+        num_heads=16,
+        img_size=384,
+        patch_size=16,
+        in_channels=3,
+        feedforward_channels=4096,
+        drop_rate=0.1),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vit_large_patch32_384_finetune.py b/configs/_base_/models/vit_large_patch32_384_finetune.py
new file mode 100644
index 00000000..9c2483b1
--- /dev/null
+++ b/configs/_base_/models/vit_large_patch32_384_finetune.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        num_layers=24,
+        embed_dim=1024,
+        num_heads=16,
+        img_size=384,
+        patch_size=32,
+        in_channels=3,
+        feedforward_channels=4096,
+        drop_rate=0.1),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
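Not part of the patch: the four `_base_` files above only describe the model architecture (a `VisionTransformer` backbone plus a `VisionTransformerClsHead`); datasets and schedules come from the full configs further down. A minimal sanity-check sketch, assuming an environment where `mmcv.Config` and `mmcls.models.build_classifier` are importable as they are elsewhere in this repository:

```python
# Hedged sketch (not part of the patch): build one of the new _base_ model
# configs into a classifier and count its parameters.
from mmcv import Config
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/_base_/models/vit_large_patch16_384_finetune.py')
model = build_classifier(cfg.model)  # VisionTransformer backbone + ViT cls head

num_params = sum(p.numel() for p in model.parameters())
print(f'{num_params / 1e6:.1f}M parameters')  # ViT-Large is on the order of 300M
```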
diff --git a/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py b/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py
new file mode 100644
index 00000000..bc97d597
--- /dev/null
+++ b/configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py
@@ -0,0 +1,21 @@
+# Refer to pytorch-image-models
+_base_ = [
+    '../_base_/models/vit_base_patch32_384_finetune.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/schedules/imagenet_bs256_epochstep.py',
+    '../_base_/default_runtime.py'
+]
+
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', size=(384, -1), backend='pillow'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(test=dict(pipeline=test_pipeline))
diff --git a/configs/vision_transformer/vit_large_patch16_224_finetune_imagenet.py b/configs/vision_transformer/vit_large_patch16_224_finetune_imagenet.py
new file mode 100644
index 00000000..7809d26b
--- /dev/null
+++ b/configs/vision_transformer/vit_large_patch16_224_finetune_imagenet.py
@@ -0,0 +1,21 @@
+# Refer to pytorch-image-models
+_base_ = [
+    '../_base_/models/vit_large_patch16_224_finetune.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/schedules/imagenet_bs256_epochstep.py',
+    '../_base_/default_runtime.py'
+]
+
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', size=(384, -1), backend='pillow'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(test=dict(pipeline=test_pipeline))
diff --git a/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py b/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py
new file mode 100644
index 00000000..0bb5a6ac
--- /dev/null
+++ b/configs/vision_transformer/vit_large_patch16_384_finetune_imagenet.py
@@ -0,0 +1,21 @@
+# Refer to pytorch-image-models
+_base_ = [
+    '../_base_/models/vit_large_patch16_384_finetune.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/schedules/imagenet_bs256_epochstep.py',
+    '../_base_/default_runtime.py'
+]
+
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', size=(384, -1), backend='pillow'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(test=dict(pipeline=test_pipeline))
diff --git a/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py b/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py
new file mode 100644
index 00000000..c4b62c77
--- /dev/null
+++ b/configs/vision_transformer/vit_large_patch32_384_finetune_imagenet.py
@@ -0,0 +1,21 @@
+# Refer to pytorch-image-models
+_base_ = [
+    '../_base_/models/vit_large_patch32_384_finetune.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/schedules/imagenet_bs256_epochstep.py',
+    '../_base_/default_runtime.py'
+]
+
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', size=(384, -1), backend='pillow'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(test=dict(pipeline=test_pipeline))
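Not part of the patch: each full config above pairs a `_base_` model with the `imagenet_bs32_pil_resize` dataset and the `imagenet_bs256_epochstep` schedule, overriding only the test pipeline (pillow resize of the short side to 384, 384 center crop, normalization with mean/std 127.5). A hedged single-image inference sketch, assuming the mmcls high-level APIs `init_model`/`inference_model` behave as recalled and that a checkpoint converted to the mmcls key layout exists at the hypothetical path below:

```python
# Hedged sketch (not part of the patch): single-image inference with one of
# the new configs. The checkpoint path is hypothetical.
from mmcls.apis import inference_model, init_model

config = 'configs/vision_transformer/vit_base_patch32_384_finetune_imagenet.py'
checkpoint = 'checkpoints/vit_base_patch32_384.pth'  # hypothetical path

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/demo.JPEG')  # demo image shipped with mmcls
print(result['pred_class'], result['pred_score'])
```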
diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py
index 5f36aad9..b6f71357 100644
--- a/mmcls/models/backbones/vision_transformer.py
+++ b/mmcls/models/backbones/vision_transformer.py
@@ -8,7 +8,6 @@ from ..utils import to_2tuple
 from .base_backbone import BaseBackbone
 
 
-# Modified from mmdet
 class FFN(nn.Module):
     """Implements feed-forward networks (FFNs) with residual connection.
 
@@ -87,7 +86,6 @@ class FFN(nn.Module):
         return repr_str
 
 
-# Modified from mmdet
 class MultiheadAttention(nn.Module):
     """A warpper for torch.nn.MultiheadAttention.
 
@@ -172,7 +170,6 @@ class MultiheadAttention(nn.Module):
         return residual + self.dropout(out)
 
 
-# Modified from mmdet
 class TransformerEncoderLayer(nn.Module):
     """Implements one encoder layer in Vision Transformer.
 
@@ -240,7 +237,6 @@ class TransformerEncoderLayer(nn.Module):
         return x
 
 
-# Modified from pytorch-image-models
 class PatchEmbed(nn.Module):
     """Image to Patch Embedding.
 
@@ -262,11 +258,9 @@ class PatchEmbed(nn.Module):
         super(PatchEmbed, self).__init__()
         if isinstance(img_size, int):
             img_size = to_2tuple(img_size)
-            # img_size = tuple(repeat(img_size, 2))
         elif isinstance(img_size, tuple):
             if len(img_size) == 1:
                 img_size = to_2tuple(img_size[0])
-                # img_size = tuple(repeat(img_size[0], 2))
             assert len(img_size) == 2, \
                 f'The size of image should have length 1 or 2, ' \
                 f'but got {len(img_size)}'
@@ -378,7 +372,6 @@ class HybridEmbed(nn.Module):
         return x
 
 
-# Modified from pytorch-image-models and mmdet
 @BACKBONES.register_module()
 class VisionTransformer(BaseBackbone):
     """ Vision Transformer
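Not part of the patch: the backbone change only removes provenance comments and commented-out code, so the constructor arguments exercised by the new configs are unchanged. A hedged sketch of the patch-grid arithmetic behind the ViT-B/32 at 384x384 config, assuming `mmcls.models.build_backbone` accepts exactly the kwargs shown above:

```python
# Hedged sketch (not part of the patch): token-count arithmetic for ViT-B/32
# at 384x384, using the backbone kwargs from the new config.
from mmcls.models import build_backbone

vit_b32_384 = dict(
    type='VisionTransformer',
    num_layers=12,
    embed_dim=768,
    num_heads=12,
    img_size=384,
    patch_size=32,
    in_channels=3,
    feedforward_channels=3072,
    drop_rate=0.1)

backbone = build_backbone(vit_b32_384)  # resolved via the BACKBONES registry

# 384 / 32 = 12 patches per side -> 144 patch tokens, plus one class token,
# so the encoder and its positional embedding see a sequence of 145 tokens.
seq_len = (384 // 32) ** 2 + 1
print(seq_len)  # 145
```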