add model convert (#8)

pull/11/head
wanghonglie 2022-09-18 11:12:08 +08:00 committed by GitHub
parent 8945c76f81
commit b3d405aa4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 258 additions and 0 deletions

View File

@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import torch
convert_dict = {
'model.0': 'backbone.stem',
'model.1': 'backbone.stage1.0',
'model.2': 'backbone.stage1.1',
'model.3': 'backbone.stage2.0',
'model.4': 'backbone.stage2.1',
'model.5': 'backbone.stage3.0',
'model.6': 'backbone.stage3.1',
'model.7': 'backbone.stage4.0',
'model.8': 'backbone.stage4.1',
'model.9.cv1': 'backbone.stage4.2.conv1',
'model.9.cv2': 'backbone.stage4.2.conv2',
'model.10': 'neck.reduce_layers.2',
'model.13': 'neck.top_down_layers.0.0',
'model.14': 'neck.top_down_layers.0.1',
'model.17': 'neck.top_down_layers.1',
'model.18': 'neck.downsample_layers.0',
'model.20': 'neck.bottom_up_layers.0',
'model.21': 'neck.downsample_layers.1',
'model.23': 'neck.bottom_up_layers.1',
'model.24.m': 'bbox_head.head_module.convs_pred',
}
def convert(src, dst):
"""Convert keys in detectron pretrained YOLOV5 models to mmyolo style."""
yolov5_model = torch.load(src)['model']
blobs = yolov5_model.state_dict()
state_dict = OrderedDict()
for key, weight in blobs.items():
num, module = key.split('.')[1:3]
if num == '9' or num == '24':
if module == 'anchors':
continue
prefix = f'model.{num}.{module}'
else:
prefix = f'model.{num}'
new_key = key.replace(prefix, convert_dict[prefix])
if '.m.' in new_key:
new_key = new_key.replace('.m.', '.blocks.')
new_key = new_key.replace('.cv', '.conv')
else:
new_key = new_key.replace('.cv1', '.main_conv')
new_key = new_key.replace('.cv2', '.short_conv')
new_key = new_key.replace('.cv3', '.final_conv')
state_dict[new_key] = weight
print(f'Convert {key} to {new_key}')
# save checkpoint
checkpoint = dict()
checkpoint['state_dict'] = state_dict
torch.save(checkpoint, dst)
# Note: This script must be placed under the yolov5 repo to run.
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolov5s.pt', help='src yolov5 model path')
parser.add_argument('--dst', default='mmyolov5.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,71 @@
import argparse
from collections import OrderedDict
import torch
def convert(src, dst):
ckpt = torch.load(src, map_location=torch.device('cpu'))
# The saved model is the model before reparameterization
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
new_state_dict = OrderedDict()
for k, v in model.state_dict().items():
name = k
if 'detect' in k:
if 'proj' in k:
continue
name = k.replace('detect', 'bbox_head.head_module')
if k.find('anchors') >= 0 or k.find('anchor_grid') >= 0:
continue
if 'ERBlock_2' in k:
name = k.replace('ERBlock_2', 'stage1.0')
elif 'ERBlock_3' in k:
name = k.replace('ERBlock_3', 'stage2.0')
elif 'ERBlock_4' in k:
name = k.replace('ERBlock_4', 'stage3.0')
elif 'ERBlock_5' in k:
name = k.replace('ERBlock_5', 'stage4.0')
if 'stage4.0.2' in name:
name = name.replace('stage4.0.2', 'stage4.1')
name = name.replace('cv', 'conv')
elif 'reduce_layer0' in k:
name = k.replace('reduce_layer0', 'reduce_layers.2')
elif 'Rep_p4' in k:
name = k.replace('Rep_p4', 'top_down_layers.0.0')
elif 'reduce_layer1' in k:
name = k.replace('reduce_layer1', 'top_down_layers.0.1')
elif 'Rep_p3' in k:
name = k.replace('Rep_p3', 'top_down_layers.1')
elif 'upsample0' in k:
name = k.replace('upsample0.upsample_transpose',
'upsample_layers.0')
elif 'upsample1' in k:
name = k.replace('upsample1.upsample_transpose',
'upsample_layers.1')
elif 'Rep_n3' in k:
name = k.replace('Rep_n3', 'bottom_up_layers.0')
elif 'Rep_n4' in k:
name = k.replace('Rep_n4', 'bottom_up_layers.1')
elif 'downsample2' in k:
name = k.replace('downsample2', 'downsample_layers.0')
elif 'downsample1' in k:
name = k.replace('downsample1', 'downsample_layers.1')
new_state_dict[name] = v
data = {'state_dict': new_state_dict}
torch.save(data, dst)
# Note: This script must be placed under the yolov6 repo to run.
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolov6s.pt', help='src yolov6 model path')
parser.add_argument('--dst', default='mmyolov6.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import torch
neck_dict = {
'backbone.lateral_conv0': 'neck.reduce_layers.2',
'backbone.C3_p4.conv': 'neck.top_down_layers.0.0.cv',
'backbone.C3_p4.m.0.': 'neck.top_down_layers.0.0.m.0.',
'backbone.reduce_conv1': 'neck.top_down_layers.0.1',
'backbone.C3_p3.conv': 'neck.top_down_layers.1.cv',
'backbone.C3_p3.m.0.': 'neck.top_down_layers.1.m.0.',
'backbone.bu_conv2': 'neck.downsample_layers.0',
'backbone.C3_n3.conv': 'neck.bottom_up_layers.0.cv',
'backbone.C3_n3.m.0.': 'neck.bottom_up_layers.0.m.0.',
'backbone.bu_conv1': 'neck.downsample_layers.1',
'backbone.C3_n4.conv': 'neck.bottom_up_layers.1.cv',
'backbone.C3_n4.m.0.': 'neck.bottom_up_layers.1.m.0.',
}
def convert_stem(model_key, model_weight, state_dict, converted_names):
new_key = model_key[9:]
state_dict[new_key] = model_weight
converted_names.add(model_key)
print(f'Convert {model_key} to {new_key}')
def convert_backbone(model_key, model_weight, state_dict, converted_names):
new_key = model_key.replace('backbone.dark', 'stage')
num = int(new_key[14]) - 1
new_key = new_key[:14] + str(num) + new_key[15:]
if '.m.' in model_key:
new_key = new_key.replace('.m.', '.blocks.')
elif not new_key[16] == '0' and 'stage4.1' not in new_key:
new_key = new_key.replace('conv1', 'main_conv')
new_key = new_key.replace('conv2', 'short_conv')
new_key = new_key.replace('conv3', 'final_conv')
state_dict[new_key] = model_weight
converted_names.add(model_key)
print(f'Convert {model_key} to {new_key}')
def convert_neck(model_key, model_weight, state_dict, converted_names):
for old, new in neck_dict.items():
if old in model_key:
new_key = model_key.replace(old, new)
if '.m.' in model_key:
new_key = new_key.replace('.m.', '.blocks.')
elif '.C' in model_key:
new_key = new_key.replace('cv1', 'main_conv')
new_key = new_key.replace('cv2', 'short_conv')
new_key = new_key.replace('cv3', 'final_conv')
state_dict[new_key] = model_weight
converted_names.add(model_key)
print(f'Convert {model_key} to {new_key}')
def convert_head(model_key, model_weight, state_dict, converted_names):
if 'stem' in model_key:
new_key = model_key.replace('head.stem', 'neck.out_layer')
elif 'cls_convs' in model_key:
new_key = model_key.replace(
'head.cls_convs', 'bbox_head.head_module.multi_level_cls_convs')
elif 'reg_convs' in model_key:
new_key = model_key.replace(
'head.reg_convs', 'bbox_head.head_module.multi_level_reg_convs')
elif 'preds' in model_key:
new_key = model_key.replace('head.',
'bbox_head.head_module.multi_level_conv_')
new_key = new_key.replace('_preds', '')
state_dict[new_key] = model_weight
converted_names.add(model_key)
print(f'Convert {model_key} to {new_key}')
def convert(src, dst):
"""Convert keys in detectron pretrained YOLOX models to mmyolo style."""
blobs = torch.load(src)['model']
state_dict = OrderedDict()
converted_names = set()
for key, weight in blobs.items():
if 'backbone.stem' in key:
convert_stem(key, weight, state_dict, converted_names)
elif 'backbone.backbone' in key:
convert_backbone(key, weight, state_dict, converted_names)
elif 'backbone.neck' not in key and 'head' not in key:
convert_neck(key, weight, state_dict, converted_names)
elif 'head' in key:
convert_head(key, weight, state_dict, converted_names)
# save checkpoint
checkpoint = dict()
checkpoint['state_dict'] = state_dict
torch.save(checkpoint, dst)
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolox_s.pth', help='src yolox model path')
parser.add_argument('--dst', default='mmyoloxs.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
if __name__ == '__main__':
main()