[Feature]: Add pretraining for FGIA (#607)
* [Feature]: Add pretraining for FGIA
* [Fix]: Add requirements and title
* [Refactor]: Move readme to root folder
* [Feature]: Add cls link
* [Fix]: Fix typo
parent 83e0917482
commit a194464863

@@ -0,0 +1,69 @@
# Solution of FGIA ACCV 2022 (1st Place)

## Requirements

```text
PyTorch 1.11.0
torchvision 0.12.0
CUDA 11.3
MMEngine >= 0.1.0
MMCV >= 2.0.0rc0
MMClassification >= 1.0.0rc1
```
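
One possible way to install matching versions is via pip and OpenMIM; this is only a sketch that assumes a machine with CUDA 11.3 drivers, and the installation page linked in the pre-training section below remains the authoritative reference:

```shell
# install PyTorch/torchvision built against CUDA 11.3
pip install torch==1.11.0 torchvision==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
# install the OpenMMLab packages through OpenMIM
pip install -U openmim
mim install "mmengine>=0.1.0" "mmcv>=2.0.0rc0" "mmcls>=1.0.0rc1"
```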

## Preparing the dataset

First, organize your dataset folder in the following format:

```text
mmselfsup
|── data
|   |── WebiNat5000
|   |   |── meta
|   |   |   |── train.txt
|   |   |── train
|   |   |── testa
|   |   |── testb
```

The `train`, `testa`, and `testb` folders contain the same content as those provided on the official website of the competition.
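
Since the config reads annotations through `mmcls.ImageNet`, `meta/train.txt` is expected to follow the ImageNet-style list format: one image path (relative to the `train/` folder) and one integer class label per line. The file names below are hypothetical placeholders:

```text
images/000001.jpg 0
images/000002.jpg 0
images/000003.jpg 1
```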

## Start pre-training

First, install all the requirements by following this [page](https://mmselfsup.readthedocs.io/en/dev-1.x/get_started.html).
Then change your current directory to the root of MMSelfSup:

```shell
cd $MMSelfSup
```

Then you have the following two choices to start pre-training.

### Slurm

If you have a cluster managed by Slurm, you can use the following command:

```shell
# we use 16 NVIDIA 80G A100 GPUs for pre-training
GPUS_PER_NODE=8 GPUS=16 SRUN_ARGS=${SRUN_ARGS} bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py [optional arguments]
```
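
For example, a filled-in command might look like the following; the partition and job names here are placeholders, not values from the original project:

```shell
# hypothetical partition/job names; 2 nodes x 8 GPUs = 16 GPUs in total
GPUS_PER_NODE=8 GPUS=16 bash tools/slurm_train.sh my_partition fgia_mae \
    projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py
```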

### PyTorch

Or you can use the following two commands to start distributed training on two separate nodes:

```shell
# node 1
NNODES=2 NODE_RANK=0 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} bash tools/dist_train.sh projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py 8
```

```shell
# node 2
NNODES=2 NODE_RANK=1 PORT=${MASTER_PORT} MASTER_ADDR=${MASTER_ADDR} bash tools/dist_train.sh projects/fgia_accv2022_1st/config/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py 8
```
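
Both commands assume `MASTER_ADDR` and `MASTER_PORT` are exported with identical values on the two nodes beforehand; a minimal sketch, with a placeholder address and port:

```shell
export MASTER_ADDR=192.168.1.100  # reachable IP of the rank-0 node (placeholder)
export MASTER_PORT=29500          # any free TCP port, same value on both nodes
```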

All the logs and checkpoints will be saved under the folder `work_dirs` in the root.

Then you can use the pre-trained weights to initialize the model for downstream fine-tuning, following this [project](https://github.com/open-mmlab/mmclassification/tree/dev-1.x/projects/projects/fgia_accv2022_1st) in MMClassification.

@@ -0,0 +1,120 @@
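# model settings: MAE with a ViT-Large (patch 16) backbone that masks 75% of
# the input patches; `init_cfg` warm-starts from the official MAE checkpoint
# pre-trained for 1600 epochs on ImageNet-1k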
model = dict(
    type='MAE',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),
    neck=dict(
        type='MAEPretrainDecoder',
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.0),
    head=dict(
        type='MAEPretrainHead',
        norm_pix=True,
        patch_size=16,
        loss=dict(type='MAEReconstructionLoss')),
    init_cfg=dict(
        type='Pretrained',
        checkpoint=  # noqa: E251
        'https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth'  # noqa
    ))
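
# dataset settings: annotations are read through `mmcls.ImageNet`, hence the
# explicit import of `mmcls.datasets`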
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/WebiNat5000/'
file_client_args = dict(backend='disk')
train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='RandomResizedCrop',
        size=224,
        scale=(0.2, 1.0),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
train_dataloader = dict(
    batch_size=256,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    pin_memory=True,
    dataset=dict(
        type='mmcls.ImageNet',
        data_root='data/WebiNat5000/',
        ann_file='data/WebiNat5000/meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=[
            dict(type='LoadImageFromFile', file_client_args=file_client_args),
            dict(
                type='RandomResizedCrop',
                size=224,
                scale=(0.2, 1.0),
                backend='pillow',
                interpolation='bicubic'),
            dict(type='RandomFlip', prob=0.5),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))
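
# optimizer settings: AdamW wrapped with AMP (dynamic loss scaling); lr=0.0024
# is consistent with the MAE linear lr scaling rule, 1.5e-4 * 4096 / 256,
# assuming the 16-GPU x 256-images-per-GPU setup described in the README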
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys=dict(
            ln=dict(decay_mult=0.0),
            bias=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0),
            mask_token=dict(decay_mult=0.0),
            cls_token=dict(decay_mult=0.0))),
    loss_scale='dynamic')
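
# learning-rate schedule: 40 epochs of linear warmup, then cosine annealing
# over the remaining 1560 of the 1600 total epochs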
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.0001,
        by_epoch=True,
        begin=0,
        end=40,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=1560,
        by_epoch=True,
        begin=40,
        end=1600,
        convert_to_iter_based=True)
]
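
# runtime settings: 1600-epoch training loop; a checkpoint is saved every
# epoch but only the latest one is kept on disk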
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_scope = 'mmselfsup'
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
log_processor = dict(
    window_size=10,
    custom_cfg=[dict(data_src='', method_name='mean', window_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SelfSupVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_level = 'INFO'
load_from = None
resume = False
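# diff_rank_seed=True derives a different random seed for each rank from the
# base seed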
randomness = dict(seed=0, diff_rank_seed=True)
launcher = 'slurm'
work_dir = './work_dirs/selfsup/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k'