38 lines
979 B
Python
38 lines
979 B
Python
#!/usr/bin/env python
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import pickle as pkl
|
|
import sys
|
|
|
|
import torch
|
|
|
|
def convert_key(k):
    """Return the renamed form of one checkpoint parameter key.

    The following textual rewrites are applied, in this order:

    * keys that do not contain ``'layer'`` get a ``'stem.'`` prefix;
    * ``layer1`` .. ``layer4`` become ``res2`` .. ``res5``;
    * ``bn1`` .. ``bn3`` become ``conv1.norm`` .. ``conv3.norm``;
    * ``downsample.0`` becomes ``shortcut`` and ``downsample.1`` becomes
      ``shortcut.norm``.

    NOTE(review): the target naming looks like the ResNet layout a
    Detectron2-style loader expects -- confirm against the consumer.

    Args:
        k: Original parameter name (a string key of the state dict).

    Returns:
        The renamed key.
    """
    if 'layer' not in k:
        k = 'stem.' + k
    for t in [1, 2, 3, 4]:
        k = k.replace('layer{}'.format(t), 'res{}'.format(t + 1))
    for t in [1, 2, 3]:
        k = k.replace('bn{}'.format(t), 'conv{}.norm'.format(t))
    k = k.replace('downsample.0', 'shortcut')
    k = k.replace('downsample.1', 'shortcut.norm')
    return k


def main(argv):
    """Convert a checkpoint's ``state_dict`` into a pickled dict of arrays.

    Args:
        argv: Command-line vector; ``argv[1]`` is the input checkpoint path
            and ``argv[2]`` is the output path, which must end in ``.pkl``.
            Extra trailing arguments are ignored.

    Raises:
        SystemExit: If fewer than two paths are given or the output path
            does not end in ``.pkl``.
    """
    # Validate the command line before the (potentially slow) checkpoint
    # load so that a typo fails immediately. An explicit check is used rather
    # than `assert`, which is stripped when Python runs with -O.
    if len(argv) < 3:
        raise SystemExit('usage: <script> <input checkpoint> <output.pkl>')
    src_path, dst_path = argv[1], argv[2]
    if not dst_path.endswith('.pkl'):
        raise SystemExit(
            "output path must end with '.pkl', got: {}".format(dst_path))

    # SECURITY: torch.load unpickles the file; only run this on checkpoints
    # from a trusted source.
    obj = torch.load(src_path, map_location='cpu')
    obj = obj['state_dict']

    newmodel = {}
    for old_k, v in obj.items():
        k = convert_key(old_k)
        print(old_k, '->', k)
        newmodel[k] = v.numpy()

    res = {
        'model': newmodel,
        '__author__': 'MMSelfSup',
        'matching_heuristics': True
    }

    with open(dst_path, 'wb') as f:
        pkl.dump(res, f)


if __name__ == '__main__':
    main(sys.argv)
|