mirror of https://github.com/open-mmlab/mmyolo.git
185 lines
7.6 KiB
Python
185 lines
7.6 KiB
Python
import argparse
|
|
import pickle
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
|
|
def convert_bn(k: str):
|
|
name = k.replace('._mean',
|
|
'.running_mean').replace('._variance', '.running_var')
|
|
return name
|
|
|
|
|
|
def convert_repvgg(k: str):
|
|
if '.conv2.conv1.' in k:
|
|
name = k.replace('.conv2.conv1.', '.conv2.rbr_dense.')
|
|
return name
|
|
elif '.conv2.conv2.' in k:
|
|
name = k.replace('.conv2.conv2.', '.conv2.rbr_1x1.')
|
|
return name
|
|
else:
|
|
return k
|
|
|
|
|
|
def convert(src: str, dst: str, imagenet_pretrain: bool = False):
|
|
with open(src, 'rb') as f:
|
|
model = pickle.load(f)
|
|
|
|
new_state_dict = OrderedDict()
|
|
if imagenet_pretrain:
|
|
for k, v in model.items():
|
|
if '@@' in k:
|
|
continue
|
|
if 'stem.' in k:
|
|
# backbone.stem.conv1.conv.weight
|
|
# -> backbone.stem.0.conv.weight
|
|
org_ind = k.split('.')[1][-1]
|
|
new_ind = str(int(org_ind) - 1)
|
|
name = k.replace('stem.conv%s.' % org_ind,
|
|
'stem.%s.' % new_ind)
|
|
else:
|
|
# backbone.stages.1.conv2.bn._variance
|
|
# -> backbone.stage2.0.conv2.bn.running_var
|
|
org_stage_ind = k.split('.')[1]
|
|
new_stage_ind = str(int(org_stage_ind) + 1)
|
|
name = k.replace('stages.%s.' % org_stage_ind,
|
|
'stage%s.0.' % new_stage_ind)
|
|
name = convert_repvgg(name)
|
|
if '.attn.' in k:
|
|
name = name.replace('.attn.fc.', '.attn.fc.conv.')
|
|
name = convert_bn(name)
|
|
name = 'backbone.' + name
|
|
|
|
new_state_dict[name] = torch.from_numpy(v)
|
|
else:
|
|
for k, v in model.items():
|
|
name = k
|
|
if k.startswith('backbone.'):
|
|
if '.stem.' in k:
|
|
# backbone.stem.conv1.conv.weight
|
|
# -> backbone.stem.0.conv.weight
|
|
org_ind = k.split('.')[2][-1]
|
|
new_ind = str(int(org_ind) - 1)
|
|
name = k.replace('.stem.conv%s.' % org_ind,
|
|
'.stem.%s.' % new_ind)
|
|
else:
|
|
# backbone.stages.1.conv2.bn._variance
|
|
# -> backbone.stage2.0.conv2.bn.running_var
|
|
org_stage_ind = k.split('.')[2]
|
|
new_stage_ind = str(int(org_stage_ind) + 1)
|
|
name = k.replace('.stages.%s.' % org_stage_ind,
|
|
'.stage%s.0.' % new_stage_ind)
|
|
name = convert_repvgg(name)
|
|
if '.attn.' in k:
|
|
name = name.replace('.attn.fc.', '.attn.fc.conv.')
|
|
name = convert_bn(name)
|
|
elif k.startswith('neck.'):
|
|
# fpn_stages
|
|
if k.startswith('neck.fpn_stages.'):
|
|
# neck.fpn_stages.0.0.conv1.conv.weight
|
|
# -> neck.reduce_layers.2.0.conv1.conv.weight
|
|
if k.startswith('neck.fpn_stages.0.0.'):
|
|
name = k.replace('neck.fpn_stages.0.0.',
|
|
'neck.reduce_layers.2.0.')
|
|
if '.spp.' in name:
|
|
name = name.replace('.spp.conv.', '.spp.conv2.')
|
|
# neck.fpn_stages.1.0.conv1.conv.weight
|
|
# -> neck.top_down_layers.0.0.conv1.conv.weight
|
|
elif k.startswith('neck.fpn_stages.1.0.'):
|
|
name = k.replace('neck.fpn_stages.1.0.',
|
|
'neck.top_down_layers.0.0.')
|
|
elif k.startswith('neck.fpn_stages.2.0.'):
|
|
name = k.replace('neck.fpn_stages.2.0.',
|
|
'neck.top_down_layers.1.0.')
|
|
else:
|
|
raise NotImplementedError('Not implemented.')
|
|
name = name.replace('.0.convs.', '.0.blocks.')
|
|
elif k.startswith('neck.fpn_routes.'):
|
|
# neck.fpn_routes.0.conv.weight
|
|
# -> neck.upsample_layers.0.0.conv.weight
|
|
index = k.split('.')[2]
|
|
name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
|
|
k.split('.')[-2:])
|
|
name = name.replace('.0.convs.', '.0.blocks.')
|
|
elif k.startswith('neck.pan_stages.'):
|
|
# neck.pan_stages.0.0.conv1.conv.weight
|
|
# -> neck.bottom_up_layers.1.0.conv1.conv.weight
|
|
ind = k.split('.')[2]
|
|
name = k.replace(
|
|
'neck.pan_stages.' + ind, 'neck.bottom_up_layers.' +
|
|
('0' if ind == '1' else '1'))
|
|
name = name.replace('.0.convs.', '.0.blocks.')
|
|
elif k.startswith('neck.pan_routes.'):
|
|
# neck.pan_routes.0.conv.weight
|
|
# -> neck.downsample_layers.0.conv.weight
|
|
ind = k.split('.')[2]
|
|
name = k.replace(
|
|
'neck.pan_routes.' + ind, 'neck.downsample_layers.' +
|
|
('0' if ind == '1' else '1'))
|
|
name = name.replace('.0.convs.', '.0.blocks.')
|
|
|
|
else:
|
|
raise NotImplementedError('Not implement.')
|
|
name = convert_repvgg(name)
|
|
name = convert_bn(name)
|
|
elif k.startswith('yolo_head.'):
|
|
if ('anchor_points' in k) or ('stride_tensor' in k):
|
|
continue
|
|
if 'proj_conv' in k:
|
|
name = k.replace('yolo_head.proj_conv.',
|
|
'bbox_head.head_module.proj_conv.')
|
|
else:
|
|
for org_key, rep_key in [
|
|
[
|
|
'yolo_head.stem_cls.',
|
|
'bbox_head.head_module.cls_stems.'
|
|
],
|
|
[
|
|
'yolo_head.stem_reg.',
|
|
'bbox_head.head_module.reg_stems.'
|
|
],
|
|
[
|
|
'yolo_head.pred_cls.',
|
|
'bbox_head.head_module.cls_preds.'
|
|
],
|
|
[
|
|
'yolo_head.pred_reg.',
|
|
'bbox_head.head_module.reg_preds.'
|
|
]
|
|
]:
|
|
name = name.replace(org_key, rep_key)
|
|
name = name.split('.')
|
|
ind = name[3]
|
|
name[3] = str(2 - int(ind))
|
|
name = '.'.join(name)
|
|
name = convert_bn(name)
|
|
else:
|
|
continue
|
|
|
|
new_state_dict[name] = torch.from_numpy(v)
|
|
data = {'state_dict': new_state_dict}
|
|
torch.save(data, dst)
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='Convert model keys')
|
|
parser.add_argument(
|
|
'--src',
|
|
default='ppyoloe_plus_crn_s_80e_coco.pdparams',
|
|
help='src ppyoloe model path')
|
|
parser.add_argument(
|
|
'--dst', default='mmppyoloe_plus_s.pt', help='save path')
|
|
parser.add_argument(
|
|
'--imagenet-pretrain',
|
|
action='store_true',
|
|
default=False,
|
|
help='Load model pretrained on imagenet dataset which only '
|
|
'have weight for backbone.')
|
|
args = parser.parse_args()
|
|
convert(args.src, args.dst, args.imagenet_pretrain)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|