From 119bbd838deba739ba488a59a7d724a3e1c56154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Thu, 26 Aug 2021 06:00:41 +0800 Subject: [PATCH] [Enhancement] Delete convert function and add instruction to ViT/Swin README.md (#791) * delete convert function and add instruction to README.md * unified model convert and README * remove url * fix import error * fix unittest * rename pretrain * rename vit and deit pretrain * Update upernet_deit-b16_512x512_160k_ade20k.py * Update upernet_deit-b16_512x512_80k_ade20k.py * Update upernet_deit-b16_ln_mln_512x512_160k_ade20k.py * Update upernet_deit-b16_mln_512x512_160k_ade20k.py * Update upernet_deit-s16_512x512_160k_ade20k.py * Update upernet_deit-s16_512x512_80k_ade20k.py * Update upernet_deit-s16_ln_mln_512x512_160k_ade20k.py * Update upernet_deit-s16_mln_512x512_160k_ade20k.py Co-authored-by: Jiarui XU Co-authored-by: Junjun2016 --- configs/_base_/models/setr_mla.py | 3 +- configs/_base_/models/setr_naive.py | 3 +- configs/_base_/models/setr_pup.py | 3 +- .../_base_/models/upernet_vit-b16_ln_mln.py | 2 +- configs/segformer/README.md | 22 +++-- .../setr/setr_mla_512x512_160k_b8_ade20k.py | 1 + .../setr_naive_512x512_160k_b16_ade20k.py | 1 + .../setr/setr_pup_512x512_160k_b16_ade20k.py | 1 + configs/swin/README.md | 18 ++++ ...512x512_160k_ade20k_pretrain_384x384_1K.py | 3 +- ...12x512_160k_ade20k_pretrain_384x384_22K.py | 5 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 7 +- ...12x512_160k_ade20k_pretrain_224x224_22K.py | 5 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 16 +--- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 3 +- configs/vit/README.md | 18 ++++ .../upernet_deit-b16_512x512_160k_ade20k.py | 4 +- .../upernet_deit-b16_512x512_80k_ade20k.py | 4 +- ...net_deit-b16_ln_mln_512x512_160k_ade20k.py | 4 +- ...pernet_deit-b16_mln_512x512_160k_ade20k.py | 5 +- .../upernet_deit-s16_512x512_160k_ade20k.py | 4 +- .../upernet_deit-s16_512x512_80k_ade20k.py | 4 +- ...net_deit-s16_ln_mln_512x512_160k_ade20k.py | 9 +- ...pernet_deit-s16_mln_512x512_160k_ade20k.py | 4 +- ...rnet_vit-b16_ln_mln_512x512_160k_ade20k.py | 1 + ...upernet_vit-b16_mln_512x512_160k_ade20k.py | 4 +- .../upernet_vit-b16_mln_512x512_80k_ade20k.py | 4 +- mmseg/models/backbones/swin.py | 12 +-- mmseg/models/backbones/vit.py | 16 +--- mmseg/models/utils/__init__.py | 6 +- mmseg/models/utils/ckpt_convert.py | 91 ------------------- tests/test_models/test_backbones/test_swin.py | 4 - .../{mit_convert.py => mit2mmseg.py} | 37 ++++---- tools/model_converters/swin2mmseg.py | 12 ++- tools/model_converters/vit2mmseg.py | 12 ++- 35 files changed, 131 insertions(+), 217 deletions(-) delete mode 100644 mmseg/models/utils/ckpt_convert.py rename tools/model_converters/{mit_convert.py => mit2mmseg.py} (73%) diff --git a/configs/_base_/models/setr_mla.py b/configs/_base_/models/setr_mla.py index facd255f9..af4ba2492 100644 --- a/configs/_base_/models/setr_mla.py +++ b/configs/_base_/models/setr_mla.py @@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', - pretrained=\ - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', backbone=dict( type='VisionTransformer', img_size=(768, 768), diff --git a/configs/_base_/models/setr_naive.py b/configs/_base_/models/setr_naive.py index 64d1395b5..0c330ea2d 100644 --- 
a/configs/_base_/models/setr_naive.py +++ b/configs/_base_/models/setr_naive.py @@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', - pretrained=\ - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', backbone=dict( type='VisionTransformer', img_size=(768, 768), diff --git a/configs/_base_/models/setr_pup.py b/configs/_base_/models/setr_pup.py index f87e88b8a..8e5f23b9c 100644 --- a/configs/_base_/models/setr_pup.py +++ b/configs/_base_/models/setr_pup.py @@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', - pretrained=\ - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', backbone=dict( type='VisionTransformer', img_size=(768, 768), diff --git a/configs/_base_/models/upernet_vit-b16_ln_mln.py b/configs/_base_/models/upernet_vit-b16_ln_mln.py index 1a5a56972..cd6587dfe 100644 --- a/configs/_base_/models/upernet_vit-b16_ln_mln.py +++ b/configs/_base_/models/upernet_vit-b16_ln_mln.py @@ -2,7 +2,7 @@ norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( type='EncoderDecoder', - pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', # noqa + pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', backbone=dict( type='VisionTransformer', img_size=(512, 512), diff --git a/configs/segformer/README.md b/configs/segformer/README.md index 7a9a5ef74..d325589c6 100644 --- a/configs/segformer/README.md +++ b/configs/segformer/README.md @@ -13,6 +13,18 @@ } ``` +## Usage + +To use other repositories' pre-trained models, it is necessary to convert keys. + +We provide a script [`mit2mmseg.py`](../../tools/model_converters/mit2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/NVlabs/SegFormer) to MMSegmentation style. + +```shell +python tools/model_converters/mit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} +``` + +This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`. + ## Results and models ### ADE20k @@ -61,13 +73,3 @@ test_pipeline = [ ]) ] ``` - -## How to use segformer official pretrain weights - -We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`. - -You may follow below steps to start segformer training preparation: - -1. Download segformer pretrain weights (Suggest put in `pretrain/`); -2. Run convert script to convert official pretrain weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`; -3. 
Modify `pretrained` of segformer model config, for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`; diff --git a/configs/setr/setr_mla_512x512_160k_b8_ade20k.py b/configs/setr/setr_mla_512x512_160k_b8_ade20k.py index b47cc60af..2958a6df6 100644 --- a/configs/setr/setr_mla_512x512_160k_b8_ade20k.py +++ b/configs/setr/setr_mla_512x512_160k_b8_ade20k.py @@ -4,6 +4,7 @@ _base_ = [ ] norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( + pretrained='pretrain/vit_large_patch16_384.pth', backbone=dict(img_size=(512, 512), drop_rate=0.), decode_head=dict(num_classes=150), auxiliary_head=[ diff --git a/configs/setr/setr_naive_512x512_160k_b16_ade20k.py b/configs/setr/setr_naive_512x512_160k_b16_ade20k.py index f01b1b876..2abf9df77 100644 --- a/configs/setr/setr_naive_512x512_160k_b16_ade20k.py +++ b/configs/setr/setr_naive_512x512_160k_b16_ade20k.py @@ -4,6 +4,7 @@ _base_ = [ ] norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( + pretrained='pretrain/vit_large_patch16_384.pth', backbone=dict(img_size=(512, 512), drop_rate=0.), decode_head=dict(num_classes=150), auxiliary_head=[ diff --git a/configs/setr/setr_pup_512x512_160k_b16_ade20k.py b/configs/setr/setr_pup_512x512_160k_b16_ade20k.py index 31c24de65..da3828364 100644 --- a/configs/setr/setr_pup_512x512_160k_b16_ade20k.py +++ b/configs/setr/setr_pup_512x512_160k_b16_ade20k.py @@ -4,6 +4,7 @@ _base_ = [ ] norm_cfg = dict(type='SyncBN', requires_grad=True) model = dict( + pretrained='pretrain/vit_large_patch16_384.pth', backbone=dict(img_size=(512, 512), drop_rate=0.), decode_head=dict(num_classes=150), auxiliary_head=[ diff --git a/configs/swin/README.md b/configs/swin/README.md index 2e50049a7..72f77f523 100644 --- a/configs/swin/README.md +++ b/configs/swin/README.md @@ -13,6 +13,24 @@ } ``` +## Usage + +To use other repositories' pre-trained models, it is necessary to convert keys. + +We provide a script [`swin2mmseg.py`](../../tools/model_converters/swin2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation) to MMSegmentation style. + +```shell +python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} +``` + +E.g. + +```shell +python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth +``` + +This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`. 
+ ## Results and models ### ADE20K diff --git a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py index d89f57cab..a4c2920c2 100644 --- a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py +++ b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py @@ -3,8 +3,7 @@ _base_ = [ 'pretrain_224x224_1K.py' ] model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', # noqa + pretrained='pretrain/swin_base_patch4_window12_384.pth', backbone=dict( pretrain_img_size=384, embed_dims=128, diff --git a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py index 38fed2648..ecb58936b 100644 --- a/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py +++ b/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py @@ -2,7 +2,4 @@ _base_ = [ './upernet_swin_base_patch4_window12_512x512_160k_ade20k_' 'pretrain_384x384_1K.py' ] -model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', # noqa -) +model = dict(pretrained='pretrain/swin_base_patch4_window12_384_22k.pth') diff --git a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index c34594a46..dde63d29f 100644 --- a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -3,11 +3,8 @@ _base_ = [ 'pretrain_224x224_1K.py' ] model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth', # noqa + pretrained='pretrain/swin_base_patch4_window7_224.pth', backbone=dict( - embed_dims=128, - depths=[2, 2, 18, 2], - num_heads=[4, 8, 16, 32]), + embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]), decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150), auxiliary_head=dict(in_channels=512, num_classes=150)) diff --git a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py index 5bb51d878..ea3e21059 100644 --- a/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py +++ b/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py @@ -2,7 +2,4 @@ _base_ = [ './upernet_swin_base_patch4_window7_512x512_160k_ade20k_' 'pretrain_224x224_1K.py' ] -model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', # noqa -) +model = dict(pretrained='pretrain/swin_base_patch4_window7_224_22k.pth') diff --git a/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 469b957c2..919e0c41a 100644 --- 
a/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -3,15 +3,7 @@ _base_ = [ 'pretrain_224x224_1K.py' ] model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', # noqa - backbone=dict( - depths=[2, 2, 18, 2]), - decode_head=dict( - in_channels=[96, 192, 384, 768], - num_classes=150 - ), - auxiliary_head=dict( - in_channels=384, - num_classes=150 - )) + pretrained='pretrain/swin_small_patch4_window7_224.pth', + backbone=dict(depths=[2, 2, 18, 2]), + decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), + auxiliary_head=dict(in_channels=384, num_classes=150)) diff --git a/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 7be1cf582..8dd840450 100644 --- a/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -3,8 +3,7 @@ _base_ = [ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' ] model = dict( - pretrained=\ - 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', # noqa + pretrained='pretrain/swin_tiny_patch4_window7_224.pth', backbone=dict( embed_dims=96, depths=[2, 2, 6, 2], diff --git a/configs/vit/README.md b/configs/vit/README.md index f0b0e1688..0751ae341 100644 --- a/configs/vit/README.md +++ b/configs/vit/README.md @@ -13,6 +13,24 @@ } ``` +## Usage + +To use other repositories' pre-trained models, it is necessary to convert keys. + +We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the keys of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style. + +```shell +python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} +``` + +E.g. + +```shell +python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth +``` + +This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`. 
+ ## Results and models ### ADE20K diff --git a/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py b/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py index 6f17d7a64..68f4bd42b 100644 --- a/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py @@ -1,6 +1,6 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa + pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', backbone=dict(drop_path_rate=0.1), - neck=None) # yapf: disable + neck=None) diff --git a/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py b/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py index 7bff28a10..720482616 100644 --- a/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py +++ b/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py @@ -1,6 +1,6 @@ _base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa + pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', backbone=dict(drop_path_rate=0.1), - neck=None) # yapf: disable + neck=None) diff --git a/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py b/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py index f5b2411df..32909ffa1 100644 --- a/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py @@ -1,5 +1,5 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa - backbone=dict(drop_path_rate=0.1, final_norm=True)) # yapf: disable + pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', + backbone=dict(drop_path_rate=0.1, final_norm=True)) diff --git a/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py b/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py index 68efd4893..4abefe8dc 100644 --- a/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py @@ -1,5 +1,6 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa - backbone=dict(drop_path_rate=0.1),) # yapf: disable + pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', + backbone=dict(drop_path_rate=0.1), +) diff --git a/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py b/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py index cae6f466c..290ff19ed 100644 --- a/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py @@ -1,8 +1,8 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa + pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=None, - auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable + auxiliary_head=dict(num_classes=150, in_channels=384)) diff --git a/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py b/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py index b176abb79..605d264a7 100644 --- a/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py +++ b/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py @@ -1,8 +1,8 @@ _base_ 
= './upernet_vit-b16_mln_512x512_80k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa + pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=None, - auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable + auxiliary_head=dict(num_classes=150, in_channels=384)) diff --git a/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py b/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py index f328ca860..ef743a20e 100644 --- a/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py @@ -1,12 +1,9 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa + pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', backbone=dict( - num_heads=6, - embed_dims=384, - drop_path_rate=0.1, - final_norm=True), + num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=dict(in_channels=[384, 384, 384, 384], out_channels=384), - auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable + auxiliary_head=dict(num_classes=150, in_channels=384)) diff --git a/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py b/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py index a1e1c2a4e..069cab74f 100644 --- a/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py @@ -1,8 +1,8 @@ _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' model = dict( - pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa + pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=dict(in_channels=[384, 384, 384, 384], out_channels=384), - auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable + auxiliary_head=dict(num_classes=150, in_channels=384)) diff --git a/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py b/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py index f6f85378b..51eeda012 100644 --- a/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py @@ -5,6 +5,7 @@ _base_ = [ ] model = dict( + pretrained='pretrain/vit_base_patch16_224.pth', backbone=dict(drop_path_rate=0.1, final_norm=True), decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) diff --git a/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py b/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py index cc286f1fb..5b148d725 100644 --- a/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py +++ b/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py @@ -5,7 +5,9 @@ _base_ = [ ] model = dict( - decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) + pretrained='pretrain/vit_base_patch16_224.pth', + decode_head=dict(num_classes=150), + auxiliary_head=dict(num_classes=150)) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone diff --git a/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py b/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py 
index d80b0d9fd..f893500d3 100644 --- a/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py +++ b/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py @@ -5,7 +5,9 @@ _base_ = [ ] model = dict( - decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) + pretrained='pretrain/vit_base_patch16_224.pth', + decode_head=dict(num_classes=150), + auxiliary_head=dict(num_classes=150)) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index c75bf5fc8..e3e835a03 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -17,7 +17,7 @@ from torch.nn.modules.utils import _pair as to_2tuple from mmseg.ops import resize from ...utils import get_root_logger from ..builder import ATTENTION, BACKBONES -from ..utils import PatchEmbed, swin_convert +from ..utils import PatchEmbed class PatchMerging(BaseModule): @@ -564,8 +564,6 @@ class SwinTransformer(BaseModule): Default: dict(type='LN'). norm_cfg (dict): Config dict for normalization layer at output of backone. Defaults: dict(type='LN'). - pretrain_style (str): Choose to use official or mmcls pretrain weights. - Default: official. pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict, optional): The Config for initialization. Defaults to None. @@ -591,7 +589,6 @@ class SwinTransformer(BaseModule): use_abs_pos_embed=False, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), - pretrain_style='official', pretrained=None, init_cfg=None): super(SwinTransformer, self).__init__() @@ -605,9 +602,6 @@ class SwinTransformer(BaseModule): f'The size of image should have length 1 or 2, ' \ f'but got {len(pretrain_img_size)}' - assert pretrain_style in ['official', 'mmcls'], 'We only support load ' - 'official ckpt and mmcls ckpt.' - if isinstance(pretrained, str) or pretrained is None: warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 'please use "init_cfg" instead') @@ -617,7 +611,6 @@ class SwinTransformer(BaseModule): num_layers = len(depths) self.out_indices = out_indices self.use_abs_pos_embed = use_abs_pos_embed - self.pretrain_style = pretrain_style self.pretrained = pretrained self.init_cfg = init_cfg @@ -713,9 +706,6 @@ class SwinTransformer(BaseModule): else: state_dict = ckpt - if self.pretrain_style == 'official': - state_dict = swin_convert(state_dict) - # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): state_dict = {k[7:]: v for k, v in state_dict.items()} diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 5bee596fe..003fa537e 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -14,7 +14,7 @@ from torch.nn.modules.utils import _pair as to_2tuple from mmseg.ops import resize from mmseg.utils import get_root_logger from ..builder import BACKBONES -from ..utils import PatchEmbed, vit_convert +from ..utils import PatchEmbed class TransformerEncoderLayer(BaseModule): @@ -140,8 +140,6 @@ class VisionTransformer(BaseModule): and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. - pretrain_style (str): Choose to use timm or mmcls pretrain weights. - Default: timm. pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. 
@@ -170,7 +168,6 @@ class VisionTransformer(BaseModule): num_fcs=2, norm_eval=False, with_cp=False, - pretrain_style='timm', pretrained=None, init_cfg=None): super(VisionTransformer, self).__init__() @@ -184,8 +181,6 @@ class VisionTransformer(BaseModule): f'The size of image should have length 1 or 2, ' \ f'but got {len(img_size)}' - assert pretrain_style in ['timm', 'mmcls'] - if output_cls_token: assert with_cls_token is True, f'with_cls_token must be True if' \ f'set output_cls_token to True, but got {with_cls_token}' @@ -201,7 +196,6 @@ class VisionTransformer(BaseModule): self.interpolate_mode = interpolate_mode self.norm_eval = norm_eval self.with_cp = with_cp - self.pretrain_style = pretrain_style self.pretrained = pretrained self.init_cfg = init_cfg @@ -272,17 +266,9 @@ class VisionTransformer(BaseModule): self.pretrained, logger=logger, map_location='cpu') if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] else: state_dict = checkpoint - if self.pretrain_style == 'timm': - # Because the refactor of vit is blocked by mmcls, - # so we firstly use timm pretrain weights to train - # downstream model. - state_dict = vit_convert(state_dict) - if 'pos_embed' in state_dict.keys(): if self.pos_embed.shape != state_dict['pos_embed'].shape: logger.info(msg=f'Resize the pos_embed shape from ' diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 817ab9cc6..2417c5183 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,5 +1,3 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .ckpt_convert import swin_convert, vit_convert from .embed import PatchEmbed from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible @@ -11,6 +9,6 @@ from .up_conv_block import UpConvBlock __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', - 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert', - 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw' + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', + 'nchw_to_nlc', 'nlc_to_nchw' ] diff --git a/mmseg/models/utils/ckpt_convert.py b/mmseg/models/utils/ckpt_convert.py deleted file mode 100644 index fd4632065..000000000 --- a/mmseg/models/utils/ckpt_convert.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from collections import OrderedDict - - -def swin_convert(ckpt): - new_ckpt = OrderedDict() - - def correct_unfold_reduction_order(x): - out_channel, in_channel = x.shape - x = x.reshape(out_channel, 4, in_channel // 4) - x = x[:, [0, 2, 1, 3], :].transpose(1, - 2).reshape(out_channel, in_channel) - return x - - def correct_unfold_norm_order(x): - in_channel = x.shape[0] - x = x.reshape(4, in_channel // 4) - x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) - return x - - for k, v in ckpt.items(): - if k.startswith('head'): - continue - elif k.startswith('layers'): - new_v = v - if 'attn.' in k: - new_k = k.replace('attn.', 'attn.w_msa.') - elif 'mlp.' in k: - if 'mlp.fc1.' in k: - new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') - elif 'mlp.fc2.' in k: - new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') - else: - new_k = k.replace('mlp.', 'ffn.') - elif 'downsample' in k: - new_k = k - if 'reduction.' in k: - new_v = correct_unfold_reduction_order(v) - elif 'norm.' 
in k: - new_v = correct_unfold_norm_order(v) - else: - new_k = k - new_k = new_k.replace('layers', 'stages', 1) - elif k.startswith('patch_embed'): - new_v = v - if 'proj' in k: - new_k = k.replace('proj', 'projection') - else: - new_k = k - else: - new_v = v - new_k = k - - new_ckpt[new_k] = new_v - - return new_ckpt - - -def vit_convert(ckpt): - - new_ckpt = OrderedDict() - - for k, v in ckpt.items(): - if k.startswith('head'): - continue - if k.startswith('norm'): - new_k = k.replace('norm.', 'ln1.') - elif k.startswith('patch_embed'): - if 'proj' in k: - new_k = k.replace('proj', 'projection') - else: - new_k = k - elif k.startswith('blocks'): - if 'norm' in k: - new_k = k.replace('norm', 'ln') - elif 'mlp.fc1' in k: - new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') - elif 'mlp.fc2' in k: - new_k = k.replace('mlp.fc2', 'ffn.layers.1') - elif 'attn.qkv' in k: - new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') - elif 'attn.proj' in k: - new_k = k.replace('attn.proj', 'attn.attn.out_proj') - else: - new_k = k - new_k = new_k.replace('blocks.', 'layers.') - else: - new_k = k - new_ckpt[new_k] = v - - return new_ckpt diff --git a/tests/test_models/test_backbones/test_swin.py b/tests/test_models/test_backbones/test_swin.py index d82a4ba10..edb2f833e 100644 --- a/tests/test_models/test_backbones/test_swin.py +++ b/tests/test_models/test_backbones/test_swin.py @@ -8,10 +8,6 @@ from mmseg.models.backbones import SwinTransformer def test_swin_transformer(): """Test Swin Transformer backbone.""" - with pytest.raises(AssertionError): - # We only support 'official' or 'mmcls' for this arg. - model = SwinTransformer(pretrain_style='swin') - with pytest.raises(TypeError): # Pretrained arg must be str or None. model = SwinTransformer(pretrained=123) diff --git a/tools/model_converters/mit_convert.py b/tools/model_converters/mit2mmseg.py similarity index 73% rename from tools/model_converters/mit_convert.py rename to tools/model_converters/mit2mmseg.py index 5138e55c6..37e9b9476 100644 --- a/tools/model_converters/mit_convert.py +++ b/tools/model_converters/mit2mmseg.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import os.path as osp from collections import OrderedDict +import mmcv import torch +from mmcv.runner import CheckpointLoader def convert_mit(ckpt): @@ -54,24 +57,26 @@ def convert_mit(ckpt): return new_ckpt -def parse_args(): +def main(): parser = argparse.ArgumentParser( - 'Convert official segformer backbone weights to mmseg style.') - parser.add_argument( - 'src', help='Source path of official segformer backbone weights.') - parser.add_argument( - 'dst', - help='Destination path of converted segformer backbone weights.') + description='Convert keys in official pretrained segformer to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. 
+ parser.add_argument('dst', help='save path') + args = parser.parse_args() - return parser.parse_args() + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_mit(state_dict) + mmcv.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) if __name__ == '__main__': - args = parse_args() - src_path = args.src - dst_path = args.dst - - ckpt = torch.load(src_path, map_location='cpu') - - ckpt = convert_mit(ckpt) - torch.save(ckpt, dst_path) + main() diff --git a/tools/model_converters/swin2mmseg.py b/tools/model_converters/swin2mmseg.py index 5a720f376..03b24ceaa 100644 --- a/tools/model_converters/swin2mmseg.py +++ b/tools/model_converters/swin2mmseg.py @@ -1,7 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. import argparse +import os.path as osp from collections import OrderedDict +import mmcv import torch +from mmcv.runner import CheckpointLoader def convert_swin(ckpt): @@ -62,12 +66,12 @@ def main(): parser = argparse.ArgumentParser( description='Convert keys in official pretrained swin models to' 'MMSegmentation style.') - parser.add_argument('src', help='src segmentation model path') + parser.add_argument('src', help='src model path or url') # The dst path must be a full path of the new checkpoint. parser.add_argument('dst', help='save path') args = parser.parse_args() - checkpoint = torch.load(args.src, map_location='cpu') + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: @@ -75,8 +79,8 @@ def main(): else: state_dict = checkpoint weight = convert_swin(state_dict) - with open(args.dst, 'wb') as f: - torch.save(weight, f) + mmcv.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) if __name__ == '__main__': diff --git a/tools/model_converters/vit2mmseg.py b/tools/model_converters/vit2mmseg.py index 176c03a53..bc18ebed8 100644 --- a/tools/model_converters/vit2mmseg.py +++ b/tools/model_converters/vit2mmseg.py @@ -1,7 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. import argparse +import os.path as osp from collections import OrderedDict +import mmcv import torch +from mmcv.runner import CheckpointLoader def convert_vit(ckpt): @@ -43,12 +47,12 @@ def main(): parser = argparse.ArgumentParser( description='Convert keys in timm pretrained vit models to ' 'MMSegmentation style.') - parser.add_argument('src', help='src segmentation model path') + parser.add_argument('src', help='src model path or url') # The dst path must be a full path of the new checkpoint. parser.add_argument('dst', help='save path') args = parser.parse_args() - checkpoint = torch.load(args.src, map_location='cpu') + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') if 'state_dict' in checkpoint: # timm checkpoint state_dict = checkpoint['state_dict'] @@ -58,8 +62,8 @@ def main(): else: state_dict = checkpoint weight = convert_vit(state_dict) - with open(args.dst, 'wb') as f: - torch.save(weight, f) + mmcv.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) if __name__ == '__main__':