mmselfsup/tools/extract_backbone_weights.py

32 lines
952 B
Python
Raw Normal View History

2020-06-16 00:05:18 +08:00
import torch
import argparse
def parse_args():
    """Read the command-line arguments for this script.

    Returns:
        argparse.Namespace: ``checkpoint`` is the path of the checkpoint to
        read, ``output`` is the path the extracted weights are written to.
    """
    cli = argparse.ArgumentParser(
        description='This script extracts backbone weights from a checkpoint')
    cli.add_argument('checkpoint', help='checkpoint file')
    cli.add_argument('output', type=str, help='destination file name')
    return cli.parse_args()
def main():
    """Extract backbone weights from a checkpoint and save them on their own.

    Loads ``args.checkpoint`` onto the CPU and keeps every state-dict entry
    whose key starts with ``'backbone.'``, with that prefix stripped. It then
    saves ``dict(state_dict=<backbone weights>, author="OpenSelfSup")`` to
    ``args.output``.

    Raises:
        ValueError: If the output path does not end with ``.pth``.
        RuntimeError: If the checkpoint contains no backbone parameters.
    """
    args = parse_args()
    # An ``assert`` is stripped under ``python -O``; validate explicitly.
    if not args.output.endswith('.pth'):
        raise ValueError(
            'Output file name must end with ".pth", got {!r}'.format(
                args.output))

    # NOTE: torch.load unpickles the file, so only pass trusted checkpoints.
    ck = torch.load(args.checkpoint, map_location=torch.device('cpu'))
    # Training checkpoints nest the weights under 'state_dict'; also accept
    # a file that is itself a bare state dict.
    state_dict = ck.get('state_dict', ck)

    # Match on the full 'backbone.' prefix so keys such as
    # 'backbone_x.weight' are not picked up and mangled by a fixed slice.
    prefix = 'backbone.'
    backbone_state = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    if not backbone_state:
        raise RuntimeError('Cannot find a backbone module in the checkpoint.')

    output_dict = dict(state_dict=backbone_state, author="OpenSelfSup")
    torch.save(output_dict, args.output)
if __name__ == '__main__':
    # Run the extraction only when executed as a script, not on import.
    main()