mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
37 lines
980 B
Python
37 lines
980 B
Python
|
#!/usr/bin/env python
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||
|
|
||
|
import pickle as pkl
|
||
|
import sys
|
||
|
import torch
|
||
|
|
||
|
def convert_key(key):
    """Rename one ResNet parameter key to the stem/resN/shortcut naming scheme.

    The rules are applied in this order:
      * keys not containing "layer" get a "stem." prefix;
      * "layer1".."layer4" become "res2".."res5";
      * "bn1".."bn3" become "conv1.norm".."conv3.norm";
      * "downsample.0" becomes "shortcut", "downsample.1" becomes "shortcut.norm".

    Example:
        >>> convert_key("layer1.0.downsample.1.weight")
        'res2.0.shortcut.norm.weight'
    """
    if "layer" not in key:
        # Everything outside the residual stages (conv1/bn1, but also any other
        # top-level key such as a classifier head) is placed under "stem.".
        key = "stem." + key
    for stage in (1, 2, 3, 4):
        key = key.replace("layer{}".format(stage), "res{}".format(stage + 1))
    for idx in (1, 2, 3):
        key = key.replace("bn{}".format(idx), "conv{}.norm".format(idx))
    key = key.replace("downsample.0", "shortcut")
    key = key.replace("downsample.1", "shortcut.norm")
    return key


def convert_state_dict(state_dict, verbose=True):
    """Rename every entry of *state_dict* and convert its tensors to numpy.

    Args:
        state_dict: mapping of parameter name -> ``torch.Tensor``.
        verbose: when True, print each ``old -> new`` renaming.

    Returns:
        A new dict mapping renamed keys (see ``convert_key``) to numpy arrays.
    """
    converted = {}
    for old_key, value in state_dict.items():
        new_key = convert_key(old_key)
        if verbose:
            print(old_key, "->", new_key)
        # detach() lets tensors that still require grad be converted;
        # cpu() is a no-op for tensors already loaded with map_location="cpu".
        converted[new_key] = value.detach().cpu().numpy()
    return converted


def main(argv=None):
    """Convert a checkpoint into a pickled ``{"model": ..., ...}`` dict.

    Usage: ``<script> INPUT_CHECKPOINT OUTPUT.pkl``

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``. Arguments beyond the first two are ignored.

    Raises:
        SystemExit: if fewer than two arguments are given or the output
            path does not end with ".pkl".
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit("usage: {} INPUT_CHECKPOINT OUTPUT.pkl".format(sys.argv[0]))
    src, dst = args[0], args[1]
    # Validate the output path before the (possibly slow) load, and with an
    # explicit check rather than `assert`, which `python -O` strips out.
    if not dst.endswith(".pkl"):
        sys.exit("output file must end with '.pkl', got: {}".format(dst))

    checkpoint = torch.load(src, map_location="cpu")
    # Weights are expected under "state_dict"; a bare state dict (no such
    # key) is used as-is instead of failing with KeyError.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    res = {
        "model": convert_state_dict(checkpoint),
        "__author__": "OpenSelfSup",
        "matching_heuristics": True,
    }
    with open(dst, "wb") as f:
        pkl.dump(res, f)


if __name__ == "__main__":
    main()
|