101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
||
"""Convert EfficientNetV2 weights from timm
(https://github.com/rwightman/pytorch-image-models) to the mmpretrain format."""
|
||
import argparse
|
||
import os.path as osp
|
||
|
||
import mmengine
|
||
import torch
|
||
from mmengine.runner import CheckpointLoader
|
||
|
||
|
||
# Ordered ``(old, new)`` substring renames applied to the part of a
# ``blocks.<stage>.<rest>`` key after the stage index. Order matters: a
# replacement is applied to the output of the previous one, so longer
# patterns (``conv_pwl``) must come before their prefixes (``conv_pw``).

# Stage 0: keys look like ``<idx>.conv.*`` / ``<idx>.bn1.*``.
_STAGE0_RENAMES = (
    ('bn1', 'bn'),
)

# Stages 1-2: keys use ``conv_exp`` / ``conv_pwl`` with ``bn1`` / ``bn2``
# (presumably timm's fused-MBConv blocks — verify against timm if the
# architecture changes).
_FUSED_RENAMES = (
    ('conv_exp', 'conv1.conv'),
    ('conv_pwl', 'conv2.conv'),
    ('bn1', 'conv1.bn'),
    ('bn2', 'conv2.bn'),
)

# Stages >= 3: keys use ``conv_pw`` / ``conv_dw`` / ``conv_pwl``, three BNs
# and a squeeze-excitation sub-module ``se``.
_MBCONV_RENAMES = (
    ('conv_pwl', 'linear_conv.conv'),
    ('conv_pw', 'expand_conv.conv'),
    ('conv_dw', 'depthwise_conv.conv'),
    ('bn1', 'expand_conv.bn'),
    ('bn2', 'depthwise_conv.bn'),
    ('bn3', 'linear_conv.bn'),
    ('se.conv_reduce', 'se.conv1.conv'),
    ('se.conv_expand', 'se.conv2.conv'),
)


def _apply_renames(name, renames):
    """Apply each ``(old, new)`` substring replacement to ``name`` in order."""
    for old, new in renames:
        name = name.replace(old, new)
    return name


def _split_block_key(name):
    """Split ``'blocks.<stage>.<rest>'`` into ``(int(stage), rest)``."""
    _, stage, rest = name.split('.', 2)
    return int(stage), rest


def convert_from_efficientnetv2_timm(param):
    """Rename the keys of a timm EfficientNetV2 state dict for mmpretrain.

    The stem becomes ``backbone.layers.0``, timm stage ``i`` becomes
    ``backbone.layers.<i + 1>``, the final conv/bn becomes the layer right
    after the last stage, and ``classifier`` becomes ``head.fc``.

    Args:
        param (dict): State dict loaded from a timm EfficientNetV2
            checkpoint, mapping parameter names to values.

    Returns:
        dict: A new dict holding the same values under mmpretrain key names.
        Keys matching none of the known patterns are copied unchanged.

    Raises:
        ValueError: If ``param`` has no ``blocks.<stage>.*`` keys, because the
            index of the final conv layer cannot be derived without them.
    """
    block_keys = [name for name in param if name.startswith('blocks.')]
    if not block_keys:
        raise ValueError(
            "No 'blocks.<stage>.*' keys found; is this a timm EfficientNetV2 "
            'state dict?')

    # Stem is layer 0 and stage i maps to layer i + 1, so the head conv is
    # at last_stage + 2 (7 for efficientnetv2 s/base/b1/b2/b3, 8 for
    # m/l/xl). Computing it from all block keys does not depend on the
    # order or number of the trailing head/classifier keys.
    last_stage = max(_split_block_key(name)[0] for name in block_keys)
    head_idx = last_stage + 2
    non_block_renames = (
        ('conv_stem', 'backbone.layers.0.conv'),
        ('bn1', 'backbone.layers.0.bn'),
        ('conv_head', f'backbone.layers.{head_idx}.conv'),
        ('bn2', f'backbone.layers.{head_idx}.bn'),
        ('classifier', 'head.fc'),
    )

    converted = {}
    for name, value in param.items():
        if name.startswith('blocks.'):
            stage, rest = _split_block_key(name)
            if stage == 0:
                renames = _STAGE0_RENAMES
            elif stage < 3:
                renames = _FUSED_RENAMES
            else:
                renames = _MBCONV_RENAMES
            rest = _apply_renames(rest, renames)
            new_name = f'backbone.layers.{stage + 1}.{rest}'
        else:
            new_name = _apply_renames(name, non_block_renames)
        converted[new_name] = value
    return converted
|
||
|
||
|
||
def main():
    """Command-line entry point.

    Loads the checkpoint at ``src`` (path or URL) onto the CPU, renames its
    parameters with ``convert_from_efficientnetv2_timm`` and saves the
    result to ``dst``, creating the destination directory when needed.
    """
    cli = argparse.ArgumentParser(
        description='Convert pretrained efficientnetv2 models in timm to '
        'mmpretrain style.')
    cli.add_argument('src', help='src model path or url')
    # `dst` is the full file path of the checkpoint to write.
    cli.add_argument('dst', help='save path')
    opts = cli.parse_args()

    ckpt = CheckpointLoader.load_checkpoint(opts.src, map_location='cpu')
    # Accept both wrapped ({'state_dict': ...}) and bare state dicts.
    src_state = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

    converted = convert_from_efficientnetv2_timm(src_state)
    mmengine.mkdir_or_exist(osp.dirname(opts.dst))
    torch.save(converted, opts.dst)

    print('Done!!')
|
||
|
||
|
||
if __name__ == '__main__':
    # Run the converter only when executed as a script, not when imported.
    main()
|