mmpretrain/configs/vision_transformer/vit-large-p16_ft-64xb64_in1k-384.py
Ma Zerun 2932f9d8a3
[Refactor] Refator ViT (Continue #295) (#395)
* [Squash] Refator ViT (from #295)

* Use base variable to simplify auto_aug setting

* Use common PatchEmbed, remove HybridEmbed and refactor ViT init
structure.

* Add `output_cls_token` option and change the output format of ViT and
input format of ViT head.

* Update unit tests and add test for `output_cls_token`.

* Support out_indices.

* Standardize config files

* Support resize position embedding.

* Add readme file of vit

* Rename config file

* Improve docs about ViT.

* Update docstring

* Use local version `MultiheadAttention` instead of mmcv version.

* Fix MultiheadAttention

* Support `qk_scale` argument in `MultiheadAttention`

* Improve docs and change `layer_cfg` to `layer_cfgs` and support
sequence.

* Use init_cfg to init Linear layer in VisionTransformerHead

* update metafile

* Update checkpoints and configs

* Imporve docstring.

* Update README

* Revert GAP modification.
2021-10-18 16:07:00 +08:00

37 lines
1.1 KiB
Python

_base_ = [
'../_base_/models/vit-large-p16.py',
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
model = dict(backbone=dict(img_size=384))
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=384, backend='pillow'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(384, -1), backend='pillow'),
dict(type='CenterCrop', crop_size=384),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline),
)