[Enhancement] Delete convert function and add instruction to ViT/Swin README.md (#791)

* delete convert function and add instruction to README.md

* unified model convert and README

* remove url

* fix import error

* fix unittest

* rename pretrain

* rename vit and deit pretrain

* Update upernet_deit-b16_512x512_160k_ade20k.py

* Update upernet_deit-b16_512x512_80k_ade20k.py

* Update upernet_deit-b16_ln_mln_512x512_160k_ade20k.py

* Update upernet_deit-b16_mln_512x512_160k_ade20k.py

* Update upernet_deit-s16_512x512_160k_ade20k.py

* Update upernet_deit-s16_512x512_80k_ade20k.py

* Update upernet_deit-s16_ln_mln_512x512_160k_ade20k.py

* Update upernet_deit-s16_mln_512x512_160k_ade20k.py

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
pull/1801/head
谢昕辰 2021-08-26 06:00:41 +08:00 committed by GitHub
parent 4e9c26bbbc
commit 119bbd838d
35 changed files with 131 additions and 217 deletions

View File

@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),

View File

@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),

View File

@@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained=\
-        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
+    pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(768, 768),

View File

@@ -2,7 +2,7 @@
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
     type='EncoderDecoder',
-    pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',  # noqa
+    pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
     backbone=dict(
         type='VisionTransformer',
         img_size=(512, 512),

View File

@@ -13,6 +13,18 @@
 }
 ```
+## Usage
+To use pre-trained models from other repositories, their keys must be converted first.
+We provide a script [`mit2mmseg.py`](../../tools/model_converters/mit2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/NVlabs/SegFormer) to MMSegmentation style.
+```shell
+python tools/model_converters/mit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
 ## Results and models
 ### ADE20k
@@ -61,13 +73,3 @@ test_pipeline = [
     ])
 ]
 ```
-## How to use segformer official pretrain weights
-We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.
-You may follow below steps to start segformer training preparation:
-1. Download segformer pretrain weights (Suggest put in `pretrain/`);
-2. Run convert script to convert official pretrain weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
-3. Modify `pretrained` of segformer model config, for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`;
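Note: all of the `tools/model_converters` scripts touched by this commit follow the same load, remap keys, save pattern. A minimal standalone sketch of that pattern is shown below; the file names and the single rename rule are illustrative placeholders, not the real MiT mapping implemented in `mit2mmseg.py`.

```python
# Generic sketch of the converter pattern: load a checkpoint, remap its keys,
# save the result. Paths and the rename rule are placeholders for illustration.
from collections import OrderedDict

import torch

ckpt = torch.load('pretrain/mit_b0.pth', map_location='cpu')  # assumed path
state_dict = ckpt.get('state_dict', ckpt)

converted = OrderedDict()
for key, value in state_dict.items():
    # Placeholder rule; the real convert_mit() applies many such renames.
    converted[key.replace('proj.', 'projection.')] = value

torch.save(converted, 'pretrain/mit_b0_converted.pth')
```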

View File

@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[

View File

@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[

View File

@@ -4,6 +4,7 @@ _base_ = [
 ]
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 model = dict(
+    pretrained='pretrain/vit_large_patch16_384.pth',
     backbone=dict(img_size=(512, 512), drop_rate=0.),
     decode_head=dict(num_classes=150),
     auxiliary_head=[

View File

@@ -13,6 +13,24 @@
 }
 ```
+## Usage
+To use pre-trained models from other repositories, their keys must be converted first.
+We provide a script [`swin2mmseg.py`](../../tools/model_converters/swin2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation) to MMSegmentation style.
+```shell
+python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+E.g.
+```shell
+python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth
+```
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
 ## Results and models
 ### ADE20K
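A converted Swin checkpoint can be sanity-checked in Python; the expected key patterns (`stages.`, `w_msa`) follow the mapping of the `swin_convert()` function that this commit deletes from `mmseg/models/utils/ckpt_convert.py`. The path below assumes the E.g. command above was used.

```python
# Sanity-check sketch for a converted Swin checkpoint (path taken from the
# example command above); expected key patterns follow the deleted swin_convert().
import torch

state_dict = torch.load(
    'pretrain/swin_base_patch4_window7_224.pth', map_location='cpu')
if 'state_dict' in state_dict:
    state_dict = state_dict['state_dict']

assert any(k.startswith('stages.') for k in state_dict), 'keys not converted'
assert any('w_msa' in k for k in state_dict), 'attention keys not converted'
```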

View File

@@ -3,8 +3,7 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',  # noqa
+    pretrained='pretrain/swin_base_patch4_window12_384.pth',
     backbone=dict(
         pretrain_img_size=384,
         embed_dims=128,

View File

@@ -2,7 +2,4 @@ _base_ = [
     './upernet_swin_base_patch4_window12_512x512_160k_ade20k_'
     'pretrain_384x384_1K.py'
 ]
-model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',  # noqa
-)
+model = dict(pretrained='pretrain/swin_base_patch4_window12_384_22k.pth')

View File

@@ -3,11 +3,8 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',  # noqa
+    pretrained='pretrain/swin_base_patch4_window7_224.pth',
     backbone=dict(
-        embed_dims=128,
-        depths=[2, 2, 18, 2],
-        num_heads=[4, 8, 16, 32]),
+        embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
     decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
     auxiliary_head=dict(in_channels=512, num_classes=150))

View File

@@ -2,7 +2,4 @@ _base_ = [
     './upernet_swin_base_patch4_window7_512x512_160k_ade20k_'
     'pretrain_224x224_1K.py'
 ]
-model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',  # noqa
-)
+model = dict(pretrained='pretrain/swin_base_patch4_window7_224_22k.pth')

View File

@@ -3,15 +3,7 @@ _base_ = [
     'pretrain_224x224_1K.py'
 ]
 model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',  # noqa
-    backbone=dict(
-        depths=[2, 2, 18, 2]),
-    decode_head=dict(
-        in_channels=[96, 192, 384, 768],
-        num_classes=150
-    ),
-    auxiliary_head=dict(
-        in_channels=384,
-        num_classes=150
-    ))
+    pretrained='pretrain/swin_small_patch4_window7_224.pth',
+    backbone=dict(depths=[2, 2, 18, 2]),
+    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
+    auxiliary_head=dict(in_channels=384, num_classes=150))

View File

@@ -3,8 +3,7 @@ _base_ = [
     '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
 ]
 model = dict(
-    pretrained=\
-        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',  # noqa
+    pretrained='pretrain/swin_tiny_patch4_window7_224.pth',
     backbone=dict(
         embed_dims=96,
         depths=[2, 2, 6, 2],

View File

@@ -13,6 +13,24 @@
 }
 ```
+## Usage
+To use pre-trained models from other repositories, their keys must be converted first.
+We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the keys of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.
+```shell
+python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+E.g.
+```shell
+python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
+```
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
 ## Results and models
 ### ADE20K
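The converted checkpoint is then referenced from the model config by its local path, which is exactly what the config changes elsewhere in this commit do. An abridged sketch with only the fields relevant to weight loading:

```python
# Abridged config sketch: point `pretrained` at the converted local file.
# Mirrors the config edits in this commit; other backbone fields are omitted.
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
    backbone=dict(type='VisionTransformer', img_size=(512, 512)))
```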

View File

@@ -1,6 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
     backbone=dict(drop_path_rate=0.1),
-    neck=None)  # yapf: disable
+    neck=None)

View File

@@ -1,6 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
     backbone=dict(drop_path_rate=0.1),
-    neck=None)  # yapf: disable
+    neck=None)

View File

@@ -1,5 +1,5 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
-    backbone=dict(drop_path_rate=0.1, final_norm=True))  # yapf: disable
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
+    backbone=dict(drop_path_rate=0.1, final_norm=True))

View File

@@ -1,5 +1,6 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
-    backbone=dict(drop_path_rate=0.1),)  # yapf: disable
+    pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
+    backbone=dict(drop_path_rate=0.1),
+)

View File

@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=None,
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=None,
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@@ -1,12 +1,9 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(
-        num_heads=6,
-        embed_dims=384,
-        drop_path_rate=0.1,
-        final_norm=True),
+        num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@@ -1,8 +1,8 @@
 _base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
 model = dict(
-    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
+    pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
     backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
     decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
     neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
-    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
+    auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@@ -5,6 +5,7 @@ _base_ = [
 ]
 model = dict(
+    pretrained='pretrain/vit_base_patch16_224.pth',
     backbone=dict(drop_path_rate=0.1, final_norm=True),
     decode_head=dict(num_classes=150),
     auxiliary_head=dict(num_classes=150))

View File

@@ -5,7 +5,9 @@ _base_ = [
 ]
 model = dict(
-    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
+    pretrained='pretrain/vit_base_patch16_224.pth',
+    decode_head=dict(num_classes=150),
+    auxiliary_head=dict(num_classes=150))
 # AdamW optimizer, no weight decay for position embedding & layer norm
 # in backbone

View File

@@ -5,7 +5,9 @@ _base_ = [
 ]
 model = dict(
-    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
+    pretrained='pretrain/vit_base_patch16_224.pth',
+    decode_head=dict(num_classes=150),
+    auxiliary_head=dict(num_classes=150))
 # AdamW optimizer, no weight decay for position embedding & layer norm
 # in backbone

View File

@@ -17,7 +17,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
 from mmseg.ops import resize
 from ...utils import get_root_logger
 from ..builder import ATTENTION, BACKBONES
-from ..utils import PatchEmbed, swin_convert
+from ..utils import PatchEmbed


 class PatchMerging(BaseModule):
@@ -564,8 +564,6 @@ class SwinTransformer(BaseModule):
             Default: dict(type='LN').
         norm_cfg (dict): Config dict for normalization layer at
             output of backone. Defaults: dict(type='LN').
-        pretrain_style (str): Choose to use official or mmcls pretrain weights.
-            Default: official.
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict, optional): The Config for initialization.
             Defaults to None.
@@ -591,7 +589,6 @@ class SwinTransformer(BaseModule):
                  use_abs_pos_embed=False,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN'),
-                 pretrain_style='official',
                  pretrained=None,
                  init_cfg=None):
         super(SwinTransformer, self).__init__()
@@ -605,9 +602,6 @@ class SwinTransformer(BaseModule):
             f'The size of image should have length 1 or 2, ' \
             f'but got {len(pretrain_img_size)}'

-        assert pretrain_style in ['official', 'mmcls'], 'We only support load '
-        'official ckpt and mmcls ckpt.'
-
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                           'please use "init_cfg" instead')
@@ -617,7 +611,6 @@ class SwinTransformer(BaseModule):
         num_layers = len(depths)
         self.out_indices = out_indices
         self.use_abs_pos_embed = use_abs_pos_embed
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg
@@ -713,9 +706,6 @@
             else:
                 state_dict = ckpt

-            if self.pretrain_style == 'official':
-                state_dict = swin_convert(state_dict)
-
             # strip prefix of state_dict
             if list(state_dict.keys())[0].startswith('module.'):
                 state_dict = {k[7:]: v for k, v in state_dict.items()}
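With `pretrain_style` gone, `SwinTransformer` no longer rewrites keys at load time and must be handed an already-converted checkpoint. The deprecation warning kept above points to `init_cfg` as the preferred way to pass it; a hypothetical sketch of that route follows (the checkpoint path is an assumption, and whether this snapshot fully wires up `init_cfg`-based loading is not verified here).

```python
# Hypothetical sketch of the `init_cfg` route suggested by the
# DeprecationWarning; the checkpoint path is an assumption.
model = dict(
    backbone=dict(
        type='SwinTransformer',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrain/swin_tiny_patch4_window7_224.pth')))
```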

View File

@@ -14,7 +14,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
 from mmseg.ops import resize
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import PatchEmbed, vit_convert
+from ..utils import PatchEmbed


 class TransformerEncoderLayer(BaseModule):
@@ -140,8 +140,6 @@ class VisionTransformer(BaseModule):
             and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save
             some memory while slowing down the training speed. Default: False.
-        pretrain_style (str): Choose to use timm or mmcls pretrain weights.
-            Default: timm.
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
@@ -170,7 +168,6 @@ class VisionTransformer(BaseModule):
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm',
                  pretrained=None,
                  init_cfg=None):
         super(VisionTransformer, self).__init__()
@@ -184,8 +181,6 @@ class VisionTransformer(BaseModule):
             f'The size of image should have length 1 or 2, ' \
             f'but got {len(img_size)}'

-        assert pretrain_style in ['timm', 'mmcls']
-
         if output_cls_token:
             assert with_cls_token is True, f'with_cls_token must be True if' \
                 f'set output_cls_token to True, but got {with_cls_token}'
@@ -201,7 +196,6 @@ class VisionTransformer(BaseModule):
         self.interpolate_mode = interpolate_mode
         self.norm_eval = norm_eval
         self.with_cp = with_cp
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg
@@ -272,17 +266,9 @@
                 self.pretrained, logger=logger, map_location='cpu')
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
             else:
                 state_dict = checkpoint

-            if self.pretrain_style == 'timm':
-                # Because the refactor of vit is blocked by mmcls,
-                # so we firstly use timm pretrain weights to train
-                # downstream model.
-                state_dict = vit_convert(state_dict)
-
             if 'pos_embed' in state_dict.keys():
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
                     logger.info(msg=f'Resize the pos_embed shape from '

View File

@@ -1,5 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .ckpt_convert import swin_convert, vit_convert
 from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
@@ -11,6 +9,6 @@ from .up_conv_block import UpConvBlock
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
-    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
-    'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
+    'nchw_to_nlc', 'nlc_to_nchw'
 ]

View File

@@ -1,91 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections import OrderedDict
-
-
-def swin_convert(ckpt):
-    new_ckpt = OrderedDict()
-
-    def correct_unfold_reduction_order(x):
-        out_channel, in_channel = x.shape
-        x = x.reshape(out_channel, 4, in_channel // 4)
-        x = x[:, [0, 2, 1, 3], :].transpose(1,
-                                            2).reshape(out_channel, in_channel)
-        return x
-
-    def correct_unfold_norm_order(x):
-        in_channel = x.shape[0]
-        x = x.reshape(4, in_channel // 4)
-        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
-        return x
-
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        elif k.startswith('layers'):
-            new_v = v
-            if 'attn.' in k:
-                new_k = k.replace('attn.', 'attn.w_msa.')
-            elif 'mlp.' in k:
-                if 'mlp.fc1.' in k:
-                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
-                elif 'mlp.fc2.' in k:
-                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
-                else:
-                    new_k = k.replace('mlp.', 'ffn.')
-            elif 'downsample' in k:
-                new_k = k
-                if 'reduction.' in k:
-                    new_v = correct_unfold_reduction_order(v)
-                elif 'norm.' in k:
-                    new_v = correct_unfold_norm_order(v)
-            else:
-                new_k = k
-            new_k = new_k.replace('layers', 'stages', 1)
-        elif k.startswith('patch_embed'):
-            new_v = v
-            if 'proj' in k:
-                new_k = k.replace('proj', 'projection')
-            else:
-                new_k = k
-        else:
-            new_v = v
-            new_k = k
-        new_ckpt[new_k] = new_v
-
-    return new_ckpt
-
-
-def vit_convert(ckpt):
-    new_ckpt = OrderedDict()
-
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        if k.startswith('norm'):
-            new_k = k.replace('norm.', 'ln1.')
-        elif k.startswith('patch_embed'):
-            if 'proj' in k:
-                new_k = k.replace('proj', 'projection')
-            else:
-                new_k = k
-        elif k.startswith('blocks'):
-            if 'norm' in k:
-                new_k = k.replace('norm', 'ln')
-            elif 'mlp.fc1' in k:
-                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
-            elif 'mlp.fc2' in k:
-                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
-            elif 'attn.qkv' in k:
-                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
-            elif 'attn.proj' in k:
-                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
-            else:
-                new_k = k
-            new_k = new_k.replace('blocks.', 'layers.')
-        else:
-            new_k = k
-        new_ckpt[new_k] = v
-
-    return new_ckpt

View File

@@ -8,10 +8,6 @@ from mmseg.models.backbones import SwinTransformer

 def test_swin_transformer():
     """Test Swin Transformer backbone."""
-    with pytest.raises(AssertionError):
-        # We only support 'official' or 'mmcls' for this arg.
-        model = SwinTransformer(pretrain_style='swin')
-
     with pytest.raises(TypeError):
         # Pretrained arg must be str or None.
         model = SwinTransformer(pretrained=123)

View File

@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict

+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader


 def convert_mit(ckpt):
@@ -54,24 +57,26 @@
     return new_ckpt


-def parse_args():
+def main():
     parser = argparse.ArgumentParser(
-        'Convert official segformer backbone weights to mmseg style.')
-    parser.add_argument(
-        'src', help='Source path of official segformer backbone weights.')
-    parser.add_argument(
-        'dst',
-        help='Destination path of converted segformer backbone weights.')
-    return parser.parse_args()
+        description='Convert keys in official pretrained segformer to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    weight = convert_mit(state_dict)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)


 if __name__ == '__main__':
-    args = parse_args()
-    src_path = args.src
-    dst_path = args.dst
-    ckpt = torch.load(src_path, map_location='cpu')
-    ckpt = convert_mit(ckpt)
-    torch.save(ckpt, dst_path)
+    main()
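The switch from `torch.load` to `CheckpointLoader.load_checkpoint` in the three converters is what lets `src` be a URL as well as a local path, matching the new `help='src model path or url'`. A small sketch of the call the scripts now rely on; the URL is the one already used in the ViT README example above.

```python
# Sketch: CheckpointLoader resolves local paths and http(s) URLs alike,
# which is why the converters now accept a URL as `src`.
from mmcv.runner import CheckpointLoader

checkpoint = CheckpointLoader.load_checkpoint(
    'https://github.com/rwightman/pytorch-image-models/releases/download/'
    'v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
    map_location='cpu')
state_dict = checkpoint.get('state_dict', checkpoint)
print(list(state_dict)[:5])  # peek at a few (still timm-style) keys
```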

View File

@@ -1,7 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict

+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader


 def convert_swin(ckpt):
@@ -62,12 +66,12 @@ def main():
     parser = argparse.ArgumentParser(
         description='Convert keys in official pretrained swin models to'
         'MMSegmentation style.')
-    parser.add_argument('src', help='src segmentation model path')
+    parser.add_argument('src', help='src model path or url')
     # The dst path must be a full path of the new checkpoint.
     parser.add_argument('dst', help='save path')
     args = parser.parse_args()

-    checkpoint = torch.load(args.src, map_location='cpu')
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
     if 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     elif 'model' in checkpoint:
@@ -75,8 +79,8 @@ def main():
     else:
         state_dict = checkpoint
     weight = convert_swin(state_dict)
-    with open(args.dst, 'wb') as f:
-        torch.save(weight, f)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)


 if __name__ == '__main__':

View File

@@ -1,7 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 from collections import OrderedDict

+import mmcv
 import torch
+from mmcv.runner import CheckpointLoader


 def convert_vit(ckpt):
@@ -43,12 +47,12 @@ def main():
     parser = argparse.ArgumentParser(
         description='Convert keys in timm pretrained vit models to '
         'MMSegmentation style.')
-    parser.add_argument('src', help='src segmentation model path')
+    parser.add_argument('src', help='src model path or url')
     # The dst path must be a full path of the new checkpoint.
     parser.add_argument('dst', help='save path')
     args = parser.parse_args()

-    checkpoint = torch.load(args.src, map_location='cpu')
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
     if 'state_dict' in checkpoint:
         # timm checkpoint
         state_dict = checkpoint['state_dict']
@@ -58,8 +62,8 @@ def main():
     else:
         state_dict = checkpoint
     weight = convert_vit(state_dict)
-    with open(args.dst, 'wb') as f:
-        torch.save(weight, f)
+    mmcv.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)


 if __name__ == '__main__':