fix convert weight

pull/1925/head
littletomatodonkey 2022-05-17 12:46:49 +00:00
parent 5a81627859
commit 6f631e4340
1 changed files with 14 additions and 3 deletions

View File

@ -16,16 +16,27 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ == ["extract_subnet_weights"]
import os
import paddle
def convert_distill_weights(distill_weights_path, student_weights_path):
def extract_subnet_weights(distill_weights_path,
student_weights_path,
student_name="Student"):
assert os.path.exists(distill_weights_path), \
"Given distill_weights_path {} not exist.".format(distill_weights_path)
# Load teacher and student weights
all_params = paddle.load(distill_weights_path)
# Extract student weights
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# Save student weights
student_prefix = student_name + "."
s_params = {
key[len(student_prefix):]: all_params[key]
for key in all_params if student_prefix in key
}
assert len(
s_params
) > 0, f"extracted params length must be > 0 but got {len(s_params)}"
# Save subnet weights
paddle.save(s_params, student_weights_path)