[Enhancement] Delete convert function and add instruction to ViT/Swin README.md (#791)
* delete convert function and add instruction to README.md
* unified model convert and README
* remove url
* fix import error
* fix unittest
* rename pretrain
* rename vit and deit pretrain
* Update upernet_deit-b16_512x512_160k_ade20k.py
* Update upernet_deit-b16_512x512_80k_ade20k.py
* Update upernet_deit-b16_ln_mln_512x512_160k_ade20k.py
* Update upernet_deit-b16_mln_512x512_160k_ade20k.py
* Update upernet_deit-s16_512x512_160k_ade20k.py
* Update upernet_deit-s16_512x512_80k_ade20k.py
* Update upernet_deit-s16_ln_mln_512x512_160k_ade20k.py
* Update upernet_deit-s16_mln_512x512_160k_ade20k.py

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
parent 4e9c26bbbc
commit 119bbd838d
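The net effect of the diff below: configs stop downloading backbone weights from hard-coded URLs, and users instead convert the upstream checkpoints once with the `tools/model_converters` scripts and point `pretrained` at a local file under `pretrain/`. A minimal sketch of the resulting workflow, reusing the ViT command from the README instructions added further down (run from the repository root; the `pretrain/` directory is simply the convention these configs assume):

```shell
# Convert a timm ViT checkpoint into MMSegmentation key style and store it locally.
python tools/model_converters/vit2mmseg.py \
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth \
    pretrain/jx_vit_base_p16_224-80ecf9dd.pth
# The updated configs then reference the local copy, e.g.
#   pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth'
```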
@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),
@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),
@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),
@@ -2,7 +2,7 @@
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',  # noqa
+    pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(512, 512),
@@ -13,6 +13,18 @@
 }
 ```
 
+## Usage
+
+To use other repositories' pre-trained models, it is necessary to convert keys.
+
+We provide a script [`mit2mmseg.py`](../../tools/model_converters/mit2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/NVlabs/SegFormer) to MMSegmentation style.
+
+```shell
+python tools/model_converters/mit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
+
 ## Results and models
 
 ### ADE20k
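Unlike the Swin and ViT usage sections added later in this diff, the SegFormer section stops at the templated command. A concrete invocation, sketched here with the `mit_b0.pth` filename taken from the old instructions removed in the next hunk (the source path is illustrative and assumes the weights were downloaded to `pretrain/`), would look like:

```shell
# Hypothetical example: convert the downloaded mit_b0.pth in place, then set
# pretrained='pretrain/mit_b0.pth' in the segformer config.
python tools/model_converters/mit2mmseg.py pretrain/mit_b0.pth pretrain/mit_b0.pth
```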
@@ -61,13 +73,3 @@ test_pipeline = [
     ])
 ]
 ```
-
-## How to use segformer official pretrain weights
-
-We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.
-
-You may follow below steps to start segformer training preparation:
-
-1. Download segformer pretrain weights (Suggest put in `pretrain/`);
-2. Run convert script to convert official pretrain weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
-3. Modify `pretrained` of segformer model config, for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`;
@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[
@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[
@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[
@@ -13,6 +13,24 @@
 }
 ```
 
+## Usage
+
+To use other repositories' pre-trained models, it is necessary to convert keys.
+
+We provide a script [`swin2mmseg.py`](../../tools/model_converters/swin2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation) to MMSegmentation style.
+
+```shell
+python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth
+```
+
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
+
 ## Results and models
 
 ### ADE20K
@@ -3,8 +3,7 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',  # noqa
+    pretrained='pretrain/swin_base_patch4_window12_384.pth',
     backbone=dict(
         pretrain_img_size=384,
         embed_dims=128,
@@ -2,7 +2,4 @@ _base_ = [
     './upernet_swin_base_patch4_window12_512x512_160k_ade20k_'
     'pretrain_384x384_1K.py'
 ]
-model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',  # noqa
-)
+model = dict(pretrained='pretrain/swin_base_patch4_window12_384_22k.pth')
@@ -3,11 +3,8 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',  # noqa
+    pretrained='pretrain/swin_base_patch4_window7_224.pth',
     backbone=dict(
-        embed_dims=128,
-        depths=[2, 2, 18, 2],
-        num_heads=[4, 8, 16, 32]),
+        embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
     decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
     auxiliary_head=dict(in_channels=512, num_classes=150))
@@ -2,7 +2,4 @@ _base_ = [
     './upernet_swin_base_patch4_window7_512x512_160k_ade20k_'
     'pretrain_224x224_1K.py'
 ]
-model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',  # noqa
-)
+model = dict(pretrained='pretrain/swin_base_patch4_window7_224_22k.pth')
@@ -3,15 +3,7 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',  # noqa
-    backbone=dict(
-        depths=[2, 2, 18, 2]),
-    decode_head=dict(
-        in_channels=[96, 192, 384, 768],
-        num_classes=150
-    ),
-    auxiliary_head=dict(
-        in_channels=384,
-        num_classes=150
-    ))
+    pretrained='pretrain/swin_small_patch4_window7_224.pth',
+    backbone=dict(depths=[2, 2, 18, 2]),
+    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
+    auxiliary_head=dict(in_channels=384, num_classes=150))
@@ -3,8 +3,7 @@ _base_ = [
     '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
 ]
 model = dict(
-    pretrained=\
-    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',  # noqa
+    pretrained='pretrain/swin_tiny_patch4_window7_224.pth',
     backbone=dict(
         embed_dims=96,
         depths=[2, 2, 6, 2],
@@ -13,6 +13,24 @@
 }
 ```
 
+## Usage
+
+To use other repositories' pre-trained models, it is necessary to convert keys.
+
+We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the key of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.
+
+```shell
+python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
+```
+
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
+
 ## Results and models
 
 ### ADE20K
@@ -1,6 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
     backbone=dict(drop_path_rate=0.1),
-    neck=None)  # yapf: disable
+    neck=None)
@@ -1,6 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
     backbone=dict(drop_path_rate=0.1),
-    neck=None)  # yapf: disable
+    neck=None)
@@ -1,5 +1,5 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
-    backbone=dict(drop_path_rate=0.1, final_norm=True))  # yapf: disable
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
+    backbone=dict(drop_path_rate=0.1, final_norm=True))
@@ -1,5 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
-    backbone=dict(drop_path_rate=0.1),)  # yapf: disable
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
+    backbone=dict(drop_path_rate=0.1),
+)
@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=None,
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))
@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=None,
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))
@@ -1,12 +1,9 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(
-        num_heads=6,
-        embed_dims=384,
-        drop_path_rate=0.1,
-        final_norm=True),
+        num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))
@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))
@@ -5,6 +5,7 @@ _base_ = [
 ]
 
 model = dict(
+    pretrained='pretrain/vit_base_patch16_224.pth',
     backbone=dict(drop_path_rate=0.1, final_norm=True),
     decode_head=dict(num_classes=150),
     auxiliary_head=dict(num_classes=150))
@@ -5,7 +5,9 @@ _base_ = [
 ]
 
 model = dict(
-    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
+    pretrained='pretrain/vit_base_patch16_224.pth',
+    decode_head=dict(num_classes=150),
+    auxiliary_head=dict(num_classes=150))
 
 # AdamW optimizer, no weight decay for position embedding & layer norm
 # in backbone
@@ -5,7 +5,9 @@ _base_ = [
 ]
 
 model = dict(
-    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
+    pretrained='pretrain/vit_base_patch16_224.pth',
+    decode_head=dict(num_classes=150),
+    auxiliary_head=dict(num_classes=150))
 
 # AdamW optimizer, no weight decay for position embedding & layer norm
 # in backbone
@@ -17,7 +17,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
 from mmseg.ops import resize
 from ...utils import get_root_logger
 from ..builder import ATTENTION, BACKBONES
-from ..utils import PatchEmbed, swin_convert
+from ..utils import PatchEmbed
 
 
 class PatchMerging(BaseModule):
@@ -564,8 +564,6 @@ class SwinTransformer(BaseModule):
             Default: dict(type='LN').
         norm_cfg (dict): Config dict for normalization layer at
             output of backone. Defaults: dict(type='LN').
-        pretrain_style (str): Choose to use official or mmcls pretrain weights.
-            Default: official.
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict, optional): The Config for initialization.
             Defaults to None.
@@ -591,7 +589,6 @@ class SwinTransformer(BaseModule):
                  use_abs_pos_embed=False,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
-                 pretrain_style='official',
                  pretrained=None,
                  init_cfg=None):
         super(SwinTransformer, self).__init__()
@@ -605,9 +602,6 @@ class SwinTransformer(BaseModule):
             f'The size of image should have length 1 or 2, ' \
             f'but got {len(pretrain_img_size)}'
 
-        assert pretrain_style in ['official', 'mmcls'], 'We only support load '
-        'official ckpt and mmcls ckpt.'
-
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                           'please use "init_cfg" instead')
@@ -617,7 +611,6 @@ class SwinTransformer(BaseModule):
         num_layers = len(depths)
         self.out_indices = out_indices
         self.use_abs_pos_embed = use_abs_pos_embed
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg
 
@@ -713,9 +706,6 @@ class SwinTransformer(BaseModule):
             else:
                 state_dict = ckpt
 
-            if self.pretrain_style == 'official':
-                state_dict = swin_convert(state_dict)
-
             # strip prefix of state_dict
             if list(state_dict.keys())[0].startswith('module.'):
                 state_dict = {k[7:]: v for k, v in state_dict.items()}
@@ -14,7 +14,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
 from mmseg.ops import resize
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import PatchEmbed, vit_convert
+from ..utils import PatchEmbed
 
 
 class TransformerEncoderLayer(BaseModule):
@@ -140,8 +140,6 @@ class VisionTransformer(BaseModule):
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
-        pretrain_style (str): Choose to use timm or mmcls pretrain weights.
-            Default: timm.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
@@ -170,7 +168,6 @@ class VisionTransformer(BaseModule):
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm',
                  pretrained=None,
                  init_cfg=None):
         super(VisionTransformer, self).__init__()
@@ -184,8 +181,6 @@ class VisionTransformer(BaseModule):
             f'The size of image should have length 1 or 2, ' \
             f'but got {len(img_size)}'
 
-        assert pretrain_style in ['timm', 'mmcls']
-
         if output_cls_token:
             assert with_cls_token is True, f'with_cls_token must be True if' \
                 f'set output_cls_token to True, but got {with_cls_token}'
@@ -201,7 +196,6 @@ class VisionTransformer(BaseModule):
         self.interpolate_mode = interpolate_mode
         self.norm_eval = norm_eval
         self.with_cp = with_cp
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg
 
@@ -272,17 +266,9 @@ class VisionTransformer(BaseModule):
                 self.pretrained, logger=logger, map_location='cpu')
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
             else:
                 state_dict = checkpoint
 
-            if self.pretrain_style == 'timm':
-                # Because the refactor of vit is blocked by mmcls,
-                # so we firstly use timm pretrain weights to train
-                # downstream model.
-                state_dict = vit_convert(state_dict)
-
             if 'pos_embed' in state_dict.keys():
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
                     logger.info(msg=f'Resize the pos_embed shape from '
@@ -1,5 +1,3 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .ckpt_convert import swin_convert, vit_convert
 from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
@@ -11,6 +9,6 @@ from .up_conv_block import UpConvBlock
 
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
-    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
-    'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
+    'nchw_to_nlc', 'nlc_to_nchw'
 ]
@@ -1,91 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections import OrderedDict
-
-
-def swin_convert(ckpt):
-    new_ckpt = OrderedDict()
-
-    def correct_unfold_reduction_order(x):
-        out_channel, in_channel = x.shape
-        x = x.reshape(out_channel, 4, in_channel // 4)
-        x = x[:, [0, 2, 1, 3], :].transpose(1,
-                                            2).reshape(out_channel, in_channel)
-        return x
-
-    def correct_unfold_norm_order(x):
-        in_channel = x.shape[0]
-        x = x.reshape(4, in_channel // 4)
-        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
-        return x
-
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        elif k.startswith('layers'):
-            new_v = v
-            if 'attn.' in k:
-                new_k = k.replace('attn.', 'attn.w_msa.')
-            elif 'mlp.' in k:
-                if 'mlp.fc1.' in k:
-                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
-                elif 'mlp.fc2.' in k:
-                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
-                else:
-                    new_k = k.replace('mlp.', 'ffn.')
-            elif 'downsample' in k:
-                new_k = k
-                if 'reduction.' in k:
-                    new_v = correct_unfold_reduction_order(v)
-                elif 'norm.' in k:
-                    new_v = correct_unfold_norm_order(v)
-            else:
-                new_k = k
-            new_k = new_k.replace('layers', 'stages', 1)
-        elif k.startswith('patch_embed'):
-            new_v = v
-            if 'proj' in k:
-                new_k = k.replace('proj', 'projection')
-            else:
-                new_k = k
-        else:
-            new_v = v
-            new_k = k
-
-        new_ckpt[new_k] = new_v
-
-    return new_ckpt
-
-
-def vit_convert(ckpt):
-
-    new_ckpt = OrderedDict()
-
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        if k.startswith('norm'):
-            new_k = k.replace('norm.', 'ln1.')
-        elif k.startswith('patch_embed'):
-            if 'proj' in k:
-                new_k = k.replace('proj', 'projection')
-            else:
-                new_k = k
-        elif k.startswith('blocks'):
-            if 'norm' in k:
-                new_k = k.replace('norm', 'ln')
-            elif 'mlp.fc1' in k:
-                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
-            elif 'mlp.fc2' in k:
-                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
-            elif 'attn.qkv' in k:
-                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
-            elif 'attn.proj' in k:
-                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
-            else:
-                new_k = k
-            new_k = new_k.replace('blocks.', 'layers.')
-        else:
-            new_k = k
-        new_ckpt[new_k] = v
-
-    return new_ckpt
@@ -8,10 +8,6 @@ from mmseg.models.backbones import SwinTransformer
 def test_swin_transformer():
     """Test Swin Transformer backbone."""
 
-    with pytest.raises(AssertionError):
-        # We only support 'official' or 'mmcls' for this arg.
-        model = SwinTransformer(pretrain_style='swin')
-
     with pytest.raises(TypeError):
         # Pretrained arg must be str or None.
         model = SwinTransformer(pretrained=123)
@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict
 
+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader
 
 
 def convert_mit(ckpt):
@@ -54,24 +57,26 @@ def convert_mit(ckpt):
     return new_ckpt
 
 
-def parse_args():
+def main():
     parser = argparse.ArgumentParser(
-        'Convert official segformer backbone weights to mmseg style.')
-    parser.add_argument(
-        'src', help='Source path of official segformer backbone weights.')
-    parser.add_argument(
-        'dst',
-        help='Destination path of converted segformer backbone weights.')
-
-    return parser.parse_args()
+        description='Convert keys in official pretrained segformer to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    weight = convert_mit(state_dict)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
 
 
 if __name__ == '__main__':
-    args = parse_args()
-    src_path = args.src
-    dst_path = args.dst
-
-    ckpt = torch.load(src_path, map_location='cpu')
-
-    ckpt = convert_mit(ckpt)
-    torch.save(ckpt, dst_path)
+    main()
@@ -1,7 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict
 
+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader
 
 
 def convert_swin(ckpt):
@@ -62,12 +66,12 @@ def main():
     parser = argparse.ArgumentParser(
         description='Convert keys in official pretrained swin models to'
         'MMSegmentation style.')
-    parser.add_argument('src', help='src segmentation model path')
+    parser.add_argument('src', help='src model path or url')
     # The dst path must be a full path of the new checkpoint.
     parser.add_argument('dst', help='save path')
     args = parser.parse_args()
 
-    checkpoint = torch.load(args.src, map_location='cpu')
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
     if 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     elif 'model' in checkpoint:
@@ -75,8 +79,8 @@ def main():
     else:
         state_dict = checkpoint
     weight = convert_swin(state_dict)
-    with open(args.dst, 'wb') as f:
-        torch.save(weight, f)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
 
 
 if __name__ == '__main__':
@@ -1,7 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict
 
+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader
 
 
 def convert_vit(ckpt):
@@ -43,12 +47,12 @@ def main():
     parser = argparse.ArgumentParser(
         description='Convert keys in timm pretrained vit models to '
        'MMSegmentation style.')
-    parser.add_argument('src', help='src segmentation model path')
+    parser.add_argument('src', help='src model path or url')
     # The dst path must be a full path of the new checkpoint.
     parser.add_argument('dst', help='save path')
     args = parser.parse_args()
 
-    checkpoint = torch.load(args.src, map_location='cpu')
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
     if 'state_dict' in checkpoint:
         # timm checkpoint
         state_dict = checkpoint['state_dict']
@@ -58,8 +62,8 @@ def main():
     else:
         state_dict = checkpoint
     weight = convert_vit(state_dict)
-    with open(args.dst, 'wb') as f:
-        torch.save(weight, f)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
 
 
 if __name__ == '__main__':