#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
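
# Purpose: convert a MoCo v3 self-supervised checkpoint into the DEiT
# checkpoint format, keeping only the 'base_encoder' backbone weights
# (minus its head) and storing them under a 'model' key.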

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert MoCo Pre-Trained Model to DEiT')
    parser.add_argument('--input', default='', type=str, metavar='PATH', required=True,
                        help='path to moco pre-trained checkpoint')
    parser.add_argument('--output', default='', type=str, metavar='PATH', required=True,
                        help='path to output checkpoint in DEiT format')
    args = parser.parse_args()
    print(args)

    # load input
    checkpoint = torch.load(args.input, map_location="cpu")
    state_dict = checkpoint['state_dict']
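
    # The checkpoint is assumed to come from MoCo v3's DistributedDataParallel
    # training, so parameter names carry a 'module.base_encoder.' prefix, e.g.
    # (hypothetical key) 'module.base_encoder.blocks.0.attn.qkv.weight';
    # the loop below strips that prefix and drops everything else.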
    for k in list(state_dict.keys()):
        # retain only base_encoder up to before the embedding layer
        if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.head'):
            # remove prefix
            state_dict[k[len("module.base_encoder."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
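
    # Note: the 'base_encoder.head' weights are deliberately excluded above;
    # downstream DeiT-style fine-tuning is expected to create and train its
    # own classification head, so only the backbone weights are carried over.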

    # make output directory if necessary (args.output may be a bare filename)
    output_dir = os.path.dirname(args.output)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    # save to output
    torch.save({'model': state_dict}, args.output)
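
    # Example usage (file names are hypothetical):
    #   python convert_to_deit.py --input moco_v3_vit.pth.tar --output vit_deit.pth
    # The saved file stores the backbone under a 'model' key, matching what the
    # DeiT codebase reads when loading a checkpoint for fine-tuning.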