rstrudel ee47c41740 [Feature] Support Segmenter (#955)
* segmenter: add model

* update

* readme: update

* config: update

* segmenter: update readme

* segmenter: update

* segmenter: update

* segmenter: update

* configs: set checkpoint path to pretrain folder

* segmenter: modify vit-s/lin, remove data config

* rreadme: update

* configs: transfer from _base_ to segmenter

* configs: add 8x1 suffix

* configs: remove redundant lines

* configs: cleanup

* first attempt

* swipe CI error

* Update mmseg/models/decode_heads/__init__.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* segmenter_linear: use fcn backbone

* segmenter_mask: update

* models: add segmenter vit

* decoders: yapf+remove unused imports

* apply precommit

* segmenter/linear_head: fix

* segmenter/linear_header: fix

* segmenter: fix mask transformer

* fix error

* segmenter/mask_head: use trunc_normal init

* refactor segmenter head

* Fetch upstream (#1)

* [Feature] Change options to cfg-option (#1129)

* [Feature] Change option to cfg-option

* add expire date and fix the docs

* modify docstring

* [Fix] Add <!-- [ABSTRACT] --> in metafile #1127

* [Fix] Fix correct num_classes of HRNet in LoveDA dataset #1136

* Bump to v0.20.1 (#1138)

* bump version 0.20.1

* bump version 0.20.1

* [Fix] revise --option to --options #1140

Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>

* decode_head: switch from linear to fcn

* fix init list formatting

* configs: remove variants, keep only vit-s on ade

* align inference metric of vit-s-mask

* configs: add vit t/b/l

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* model_converters: use torch instead of einops

* setup: remove einops

* segmenter_mask: fix missing imports

* add necessary imported init funtion

* segmenter/seg-l: set resolution to 640

* segmenter/seg-l: fix test size

* fix vitjax2mmseg

* add README and unittest

* fix unittest

* add docstring

* refactor config and add pretrained link

* fix typo

* add paper name in readme

* change segmenter config names

* fix typo in readme

* fix typos in readme

* fix segmenter typo

* fix segmenter typo

* delete redundant comma in config files

* delete redundant comma in config files

* fix convert script

* update lateset master version

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
2022-01-26 13:50:51 +08:00

123 lines
4.5 KiB
Python

import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
def vit_jax_to_torch(jax_weights, num_layer=12):
torch_weights = dict()
# patch embedding
conv_filters = jax_weights['embedding/kernel']
conv_filters = conv_filters.permute(3, 2, 0, 1)
torch_weights['patch_embed.projection.weight'] = conv_filters
torch_weights['patch_embed.projection.bias'] = jax_weights[
'embedding/bias']
# pos embedding
torch_weights['pos_embed'] = jax_weights[
'Transformer/posembed_input/pos_embedding']
# cls token
torch_weights['cls_token'] = jax_weights['cls']
# head
torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
# transformer blocks
for i in range(num_layer):
jax_block = f'Transformer/encoderblock_{i}'
torch_block = f'layers.{i}'
# attention norm
torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
f'{jax_block}/LayerNorm_0/scale']
torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
f'{jax_block}/LayerNorm_0/bias']
# attention
query_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
query_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
key_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
key_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
value_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
value_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
qkv_weight = torch.from_numpy(
np.stack((query_weight, key_weight, value_weight), 1))
qkv_weight = torch.flatten(qkv_weight, start_dim=1)
qkv_bias = torch.from_numpy(
np.stack((query_bias, key_bias, value_bias), 0))
qkv_bias = torch.flatten(qkv_bias, start_dim=0)
torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
to_out_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
torch_weights[
f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
# mlp norm
torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
f'{jax_block}/LayerNorm_2/scale']
torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
f'{jax_block}/LayerNorm_2/bias']
# mlp
torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/kernel']
torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/bias']
torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/kernel']
torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/bias']
# transpose weights
for k, v in torch_weights.items():
if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
v = v.permute(1, 0)
torch_weights[k] = v
return torch_weights
def main():
# stole refactoring code from Robin Strudel, thanks
parser = argparse.ArgumentParser(
description='Convert keys from jax official pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
jax_weights = np.load(args.src)
jax_weights_tensor = {}
for key in jax_weights.files:
value = torch.from_numpy(jax_weights[key])
jax_weights_tensor[key] = value
if 'L_16-i21k' in args.src:
num_layer = 24
else:
num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst)
if __name__ == '__main__':
main()