[Enhancement] Delete convert function and add instruction to ViT/Swin README.md (#791)

* delete convert function and add instruction to README.md

* unified model convert and README

* remove url

* fix import error

* fix unittest

* rename pretrain

* rename vit and deit pretrain

* Update upernet_deit-b16_512x512_160k_ade20k.py

* Update upernet_deit-b16_512x512_80k_ade20k.py

* Update upernet_deit-b16_ln_mln_512x512_160k_ade20k.py

* Update upernet_deit-b16_mln_512x512_160k_ade20k.py

* Update upernet_deit-s16_512x512_160k_ade20k.py

* Update upernet_deit-s16_512x512_80k_ade20k.py

* Update upernet_deit-s16_ln_mln_512x512_160k_ade20k.py

* Update upernet_deit-s16_mln_512x512_160k_ade20k.py

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
pull/1801/head
谢昕辰 2021-08-26 06:00:41 +08:00 committed by GitHub
parent 4e9c26bbbc
commit 119bbd838d
35 changed files with 131 additions and 217 deletions

View File

@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),

View File

@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),

View File

@ -3,8 +3,7 @@ backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth',
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),

View File

@ -2,7 +2,7 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', # noqa
pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
backbone=dict(
type='VisionTransformer',
img_size=(512, 512),

View File

@ -13,6 +13,18 @@
}
```
## Usage
To use other repositories' pre-trained models, it is necessary to convert keys.
We provide a script [`mit2mmseg.py`](../../tools/model_converters/mit2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/NVlabs/SegFormer) to MMSegmentation style.
```shell
python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
## Results and models
### ADE20k
@ -61,13 +73,3 @@ test_pipeline = [
])
]
```
## How to use segformer official pretrain weights
We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.
You may follow below steps to start segformer training preparation:
1. Download segformer pretrain weights (Suggest put in `pretrain/`);
2. Run convert script to convert official pretrain weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
3. Modify `pretrained` of segformer model config, for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`;

View File

@ -4,6 +4,7 @@ _base_ = [
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
pretrained='pretrain/vit_large_patch16_384.pth',
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[

View File

@ -4,6 +4,7 @@ _base_ = [
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
pretrained='pretrain/vit_large_patch16_384.pth',
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[

View File

@ -4,6 +4,7 @@ _base_ = [
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
pretrained='pretrain/vit_large_patch16_384.pth',
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[

View File

@ -13,6 +13,24 @@
}
```
## Usage
To use other repositories' pre-trained models, it is necessary to convert keys.
We provide a script [`swin2mmseg.py`](../../tools/model_converters/swin2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation) to MMSegmentation style.
```shell
python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
E.g.
```shell
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth
```
This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
## Results and models
### ADE20K

View File

@ -3,8 +3,7 @@ _base_ = [
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', # noqa
pretrained='pretrain/swin_base_patch4_window12_384.pth',
backbone=dict(
pretrain_img_size=384,
embed_dims=128,

View File

@ -2,7 +2,4 @@ _base_ = [
'./upernet_swin_base_patch4_window12_512x512_160k_ade20k_'
'pretrain_384x384_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', # noqa
)
model = dict(pretrained='pretrain/swin_base_patch4_window12_384_22k.pth')

View File

@ -3,11 +3,8 @@ _base_ = [
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth', # noqa
pretrained='pretrain/swin_base_patch4_window7_224.pth',
backbone=dict(
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32]),
embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

View File

@ -2,7 +2,4 @@ _base_ = [
'./upernet_swin_base_patch4_window7_512x512_160k_ade20k_'
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', # noqa
)
model = dict(pretrained='pretrain/swin_base_patch4_window7_224_22k.pth')

View File

@ -3,15 +3,7 @@ _base_ = [
'pretrain_224x224_1K.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', # noqa
backbone=dict(
depths=[2, 2, 18, 2]),
decode_head=dict(
in_channels=[96, 192, 384, 768],
num_classes=150
),
auxiliary_head=dict(
in_channels=384,
num_classes=150
))
pretrained='pretrain/swin_small_patch4_window7_224.pth',
backbone=dict(depths=[2, 2, 18, 2]),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))

View File

@ -3,8 +3,7 @@ _base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained=\
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', # noqa
pretrained='pretrain/swin_tiny_patch4_window7_224.pth',
backbone=dict(
embed_dims=96,
depths=[2, 2, 6, 2],

View File

@ -13,6 +13,24 @@
}
```
## Usage
To use other repositories' pre-trained models, it is necessary to convert keys.
We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the key of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.
```shell
python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
E.g.
```shell
python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
```
This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
## Results and models
### ADE20K

View File

@ -1,6 +1,6 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa
pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
backbone=dict(drop_path_rate=0.1),
neck=None) # yapf: disable
neck=None)

View File

@ -1,6 +1,6 @@
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa
pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
backbone=dict(drop_path_rate=0.1),
neck=None) # yapf: disable
neck=None)

View File

@ -1,5 +1,5 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa
backbone=dict(drop_path_rate=0.1, final_norm=True)) # yapf: disable
pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
backbone=dict(drop_path_rate=0.1, final_norm=True))

View File

@ -1,5 +1,6 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', # noqa
backbone=dict(drop_path_rate=0.1),) # yapf: disable
pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth',
backbone=dict(drop_path_rate=0.1),
)

View File

@ -1,8 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa
pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
neck=None,
auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable
auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@ -1,8 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa
pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
neck=None,
auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable
auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@ -1,12 +1,9 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa
pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
backbone=dict(
num_heads=6,
embed_dims=384,
drop_path_rate=0.1,
final_norm=True),
num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True),
decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable
auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@ -1,8 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', # noqa
pretrained='pretrain/deit_small_patch16_224-cd65a155.pth',
backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
auxiliary_head=dict(num_classes=150, in_channels=384)) # yapf: disable
auxiliary_head=dict(num_classes=150, in_channels=384))

View File

@ -5,6 +5,7 @@ _base_ = [
]
model = dict(
pretrained='pretrain/vit_base_patch16_224.pth',
backbone=dict(drop_path_rate=0.1, final_norm=True),
decode_head=dict(num_classes=150),
auxiliary_head=dict(num_classes=150))

View File

@ -5,7 +5,9 @@ _base_ = [
]
model = dict(
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
pretrained='pretrain/vit_base_patch16_224.pth',
decode_head=dict(num_classes=150),
auxiliary_head=dict(num_classes=150))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone

View File

@ -5,7 +5,9 @@ _base_ = [
]
model = dict(
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
pretrained='pretrain/vit_base_patch16_224.pth',
decode_head=dict(num_classes=150),
auxiliary_head=dict(num_classes=150))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone

View File

@ -17,7 +17,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES
from ..utils import PatchEmbed, swin_convert
from ..utils import PatchEmbed
class PatchMerging(BaseModule):
@ -564,8 +564,6 @@ class SwinTransformer(BaseModule):
Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
@ -591,7 +589,6 @@ class SwinTransformer(BaseModule):
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
pretrain_style='official',
pretrained=None,
init_cfg=None):
super(SwinTransformer, self).__init__()
@ -605,9 +602,6 @@ class SwinTransformer(BaseModule):
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert pretrain_style in ['official', 'mmcls'], 'We only support load '
'official ckpt and mmcls ckpt.'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
@ -617,7 +611,6 @@ class SwinTransformer(BaseModule):
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
@ -713,9 +706,6 @@ class SwinTransformer(BaseModule):
else:
state_dict = ckpt
if self.pretrain_style == 'official':
state_dict = swin_convert(state_dict)
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}

View File

@ -14,7 +14,7 @@ from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, vit_convert
from ..utils import PatchEmbed
class TransformerEncoderLayer(BaseModule):
@ -140,8 +140,6 @@ class VisionTransformer(BaseModule):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
Default: timm.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
@ -170,7 +168,6 @@ class VisionTransformer(BaseModule):
num_fcs=2,
norm_eval=False,
with_cp=False,
pretrain_style='timm',
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__()
@ -184,8 +181,6 @@ class VisionTransformer(BaseModule):
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
assert pretrain_style in ['timm', 'mmcls']
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
@ -201,7 +196,6 @@ class VisionTransformer(BaseModule):
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
@ -272,17 +266,9 @@ class VisionTransformer(BaseModule):
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if self.pretrain_style == 'timm':
# Because the refactor of vit is blocked by mmcls,
# so we firstly use timm pretrain weights to train
# downstream model.
state_dict = vit_convert(state_dict)
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
logger.info(msg=f'Resize the pos_embed shape from '

View File

@ -1,5 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ckpt_convert import swin_convert, vit_convert
from .embed import PatchEmbed
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
@ -11,6 +9,6 @@ from .up_conv_block import UpConvBlock
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw'
]

View File

@ -1,91 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
def swin_convert(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def vit_convert(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt

View File

@ -8,10 +8,6 @@ from mmseg.models.backbones import SwinTransformer
def test_swin_transformer():
"""Test Swin Transformer backbone."""
with pytest.raises(AssertionError):
# We only support 'official' or 'mmcls' for this arg.
model = SwinTransformer(pretrain_style='swin')
with pytest.raises(TypeError):
# Pretrained arg must be str or None.
model = SwinTransformer(pretrained=123)

View File

@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_mit(ckpt):
@ -54,24 +57,26 @@ def convert_mit(ckpt):
return new_ckpt
def parse_args():
def main():
parser = argparse.ArgumentParser(
'Convert official segformer backbone weights to mmseg style.')
parser.add_argument(
'src', help='Source path of official segformer backbone weights.')
parser.add_argument(
'dst',
help='Destination path of converted segformer backbone weights.')
description='Convert keys in official pretrained segformer to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
return parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_mit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
args = parse_args()
src_path = args.src
dst_path = args.dst
ckpt = torch.load(src_path, map_location='cpu')
ckpt = convert_mit(ckpt)
torch.save(ckpt, dst_path)
main()

View File

@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_swin(ckpt):
@ -62,12 +66,12 @@ def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained swin models to'
'MMSegmentation style.')
parser.add_argument('src', help='src segmentation model path')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = torch.load(args.src, map_location='cpu')
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
@ -75,8 +79,8 @@ def main():
else:
state_dict = checkpoint
weight = convert_swin(state_dict)
with open(args.dst, 'wb') as f:
torch.save(weight, f)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':

View File

@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_vit(ckpt):
@ -43,12 +47,12 @@ def main():
parser = argparse.ArgumentParser(
description='Convert keys in timm pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src segmentation model path')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = torch.load(args.src, map_location='cpu')
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
@ -58,8 +62,8 @@ def main():
else:
state_dict = checkpoint
weight = convert_vit(state_dict)
with open(args.dst, 'wb') as f:
torch.save(weight, f)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':