[Feature]: Add pretraining for FGIA (#607)

* [Feature]: Add pretraining for FGIA

* [Fix]: Add requirements and title

* [Refactor]: Move readme to root folder

* [Feature]: Add cls link

* [Fix]: Fix typo
pull/630/head
Yuan Liu 2022-12-12 18:26:33 +08:00 committed by GitHub
parent 83e0917482
commit a194464863
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 189 additions and 0 deletions

View File

@ -0,0 +1,69 @@
# Solution of FGIA ACCV 2022 (1st Place)
## Requirements
```shell
PyTorch 1.11.0
torchvision 0.12.0
CUDA 11.3
MMEngine >= 0.1.0
MMCV >= 2.0.0rc0
MMClassification >= 1.0.0rc1
```
## Preparing the dataset
First you should refactor the folder of your dataset in the following format:
```text
mmselfsup
|
|── data
| |── WebiNat5000
| | |── meta
| | | |── train.txt
| | |── train
| | |── testa
| | |── testb
```
The `train`, `testa`, and `testb` folders contain the same content with
those provided by the official website of the competition.
## Start pre-training
First, you should install all these requirements, following this [page](https://mmselfsup.readthedocs.io/en/dev-1.x/get_started.html).
Then change your current directory to the root of MMSelfSup
```shell
cd $MMSelfSup
```
Then you have the following two choices to start pre-training
### Slurm
If you have a cluster managed by Slurm, you can use the following command:
```shell
## we use 16 NVIDIA 80G A100 GPUs for pre-training
GPUS_PER_NODE=8 GPUS=16 SRUN_ARGS=${SRUN_ARGS} bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py [optional arguments]
```
### Pytorch
Or you can use the following two commands to start distributed training on two separate nodes:
```shell
# node 1
NNODES=2 NODE_RANK=0 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} bash tools/dist_train.sh projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py 8
```
```shell
# node 2
NNODES=2 NODE_RANK=1 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} bash tools/dist_train.sh projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py 8
```
All these logs and checkpoints will be saved under the folder `work_dirs`in the root.
Then you can use the pre-trained weights to initialize the model for downstream fine-tuning, following this [project](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/projects/fgia_accv2022_1st) in MMClassification.

View File

@ -0,0 +1,120 @@
model = dict(
type='MAE',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),
neck=dict(
type='MAEPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=1024,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.0),
head=dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='MAEReconstructionLoss')),
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth' # noqa
))
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/WebiNat5000/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
size=224,
scale=(0.2, 1.0),
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
train_dataloader = dict(
batch_size=256,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
pin_memory=True,
dataset=dict(
type='mmcls.ImageNet',
data_root='data/WebiNat5000/',
ann_file='data/WebiNat5000/meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
size=224,
scale=(0.2, 1.0),
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]))
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
paramwise_cfg=dict(
custom_keys=dict(
ln=dict(decay_mult=0.0),
bias=dict(decay_mult=0.0),
pos_embed=dict(decay_mult=0.0),
mask_token=dict(decay_mult=0.0),
cls_token=dict(decay_mult=0.0))),
loss_scale='dynamic')
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_scope = 'mmselfsup'
default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
log_processor = dict(
window_size=10,
custom_cfg=[dict(data_src='', method='mean', windows_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SelfSupVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
log_level = 'INFO'
load_from = None
resume = False
randomness = dict(seed=0, diff_rank_seed=True)
launcher = 'slurm'
work_dir = './work_dirs/selfsup/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k'