# mmrazor/tools/model_converters/convert_attentivenas_nas_ck...
#
# 171 lines
# 7.4 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from pathlib import Path

import torch
def parse_args():
    """Parse the command line of the converter.

    Returns:
        argparse.Namespace: ``checkpoint`` (path of the input checkpoint) and
        ``inplace`` (``True`` when ``--inplace`` was passed).
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    cli.add_argument('checkpoint', help='input checkpoint filename')
    cli.add_argument('--inplace', action='store_true',
                     help='replace origin ckpt')
    return cli.parse_args()
# Number of blocks in each stage ``layer1`` ... ``layer7`` of the target
# backbone. The source checkpoint numbers its blocks consecutively
# (``blocks.0`` ... ``blocks.36``), so e.g. blocks.0-1 -> layer1.0-1,
# blocks.2-6 -> layer2.0-4, ..., blocks.35-36 -> layer7.0-1.
_STAGE_DEPTHS = (2, 5, 6, 6, 8, 8, 2)

# Matches a flat block reference such as ``blocks.12``. ``\d+`` is greedy,
# so ``blocks.12`` is never mistaken for ``blocks.1``.
_BLOCK_PATTERN = re.compile(r'blocks\.(\d+)')

# Sub-module renames applied after block renaming. Order matters: only the
# first entry whose source substring occurs in a key is applied.
_LAYER_RENAMES = (
    ('mobile_inverted_conv.depth_conv.conv.conv', 'depthwise_conv.conv'),
    ('mobile_inverted_conv.depth_conv.bn.bn', 'depthwise_conv.bn'),
    ('mobile_inverted_conv.point_linear.conv.conv', 'linear_conv.conv'),
    ('mobile_inverted_conv.point_linear.bn.bn', 'linear_conv.bn'),
    ('shortcut.conv.conv', 'shortcut.conv'),
    ('mobile_inverted_conv.inverted_bottleneck.conv.conv',
     'expand_conv.conv'),
    ('mobile_inverted_conv.inverted_bottleneck.bn.bn', 'expand_conv.bn'),
    ('mobile_inverted_conv.depth_conv.se.fc.reduce', 'se.conv1.conv'),
    ('mobile_inverted_conv.depth_conv.se.fc.expand', 'se.conv2.conv'),
    ('first_conv.conv.conv', 'first_conv.conv'),
    ('first_conv.bn.bn', 'first_conv.bn'),
    ('final_expand_layer.conv.conv', 'final_expand_layer.conv'),
    ('final_expand_layer.bn.bn', 'final_expand_layer.bn'),
    ('feature_mix_layer.conv.conv', 'feature_mix_layer.conv'),
    # The classifier leaves the backbone and becomes the head, so the
    # ``backbone.`` prefix added by ``_convert_key`` is consumed here too.
    ('backbone.classifier.linear.linear', 'head.fc'),
)


def _build_block_to_layer(stage_depths):
    """Map each flat block index to its ``layer<stage>.<position>`` name.

    Args:
        stage_depths (Sequence[int]): Number of blocks per stage, in order.
            Stages are numbered from 1 and positions from 0.

    Returns:
        dict[int, str]: e.g. ``{0: 'layer1.0', 1: 'layer1.1', 2: 'layer2.0'}``
        when the first stage has depth 2.
    """
    mapping = {}
    for stage, depth in enumerate(stage_depths, start=1):
        for position in range(depth):
            # len(mapping) is the running flat block index.
            mapping[len(mapping)] = f'layer{stage}.{position}'
    return mapping


_BLOCK_TO_LAYER = _build_block_to_layer(_STAGE_DEPTHS)


def _rename_block(match):
    """``re.sub`` callback that turns ``blocks.<i>`` into ``layer<s>.<j>``.

    Raises:
        ValueError: If ``i`` is outside the known stage layout. Raising here
            avoids silently producing a wrong layer name.
    """
    index = int(match.group(1))
    try:
        return _BLOCK_TO_LAYER[index]
    except KeyError:
        raise ValueError(
            f'Unexpected block index {index}: the converter only knows '
            f'{len(_BLOCK_TO_LAYER)} blocks (stage depths {_STAGE_DEPTHS}).'
        ) from None


def _convert_key(key):
    """Translate one source state-dict key into the mmrazor naming scheme."""
    key = key.replace('module.', 'architecture.backbone.')
    key = _BLOCK_PATTERN.sub(_rename_block, key)
    for old, new in _LAYER_RENAMES:
        if old in key:
            return key.replace(old, new)
    return key


def _convert_state_dict(state_dict):
    """Return a new dict with every key converted and values untouched."""
    return {_convert_key(key): value for key, value in state_dict.items()}


def main(args=None):
    """Convert a checkpoint's ``state_dict`` keys to mmrazor names.

    All other top-level checkpoint entries are kept as-is. The result
    overwrites the input when ``inplace`` is set. Otherwise it is written
    next to the input as ``<stem>_latest.pth``. Nothing is written if a key
    cannot be converted.

    Args:
        args (argparse.Namespace, optional): Parsed arguments with
            ``checkpoint`` (input path) and ``inplace`` (bool). When omitted,
            they are read from the command line via ``parse_args()``.

    Raises:
        KeyError: If the loaded checkpoint dict has no ``'state_dict'`` entry.
        ValueError: If a key references a block index outside the stage
            layout described by ``_STAGE_DEPTHS``.
    """
    if args is None:
        args = parse_args()
    # NOTE(review): torch.load unpickles arbitrary objects; only run this
    # converter on checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    checkpoint['state_dict'] = _convert_state_dict(checkpoint['state_dict'])
    if args.inplace:
        torch.save(checkpoint, args.checkpoint)
    else:
        ckpt_path = Path(args.checkpoint)
        new_ckpt_path = ckpt_path.parent / f'{ckpt_path.stem}_latest.pth'
        torch.save(checkpoint, new_ckpt_path)
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()