fix convert weight
parent
5a81627859
commit
6f631e4340
|
@ -16,16 +16,27 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ == ["extract_subnet_weights"]
|
||||
|
||||
import os
|
||||
import paddle
|
||||
|
||||
|
||||
def convert_distill_weights(distill_weights_path, student_weights_path):
|
||||
def extract_subnet_weights(distill_weights_path,
|
||||
student_weights_path,
|
||||
student_name="Student"):
|
||||
assert os.path.exists(distill_weights_path), \
|
||||
"Given distill_weights_path {} not exist.".format(distill_weights_path)
|
||||
# Load teacher and student weights
|
||||
all_params = paddle.load(distill_weights_path)
|
||||
# Extract student weights
|
||||
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
|
||||
# Save student weights
|
||||
student_prefix = student_name + "."
|
||||
s_params = {
|
||||
key[len(student_prefix):]: all_params[key]
|
||||
for key in all_params if student_prefix in key
|
||||
}
|
||||
assert len(
|
||||
s_params
|
||||
) > 0, f"extracted params length must be > 0 but got {len(s_params)}"
|
||||
# Save subnet weights
|
||||
paddle.save(s_params, student_weights_path)
|
||||
|
|
Loading…
Reference in New Issue