diff --git a/ppcls/utils/convert_weights.py b/ppcls/utils/convert_weights.py index ee45d9f38..99ecec5a2 100644 --- a/ppcls/utils/convert_weights.py +++ b/ppcls/utils/convert_weights.py @@ -16,16 +16,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +__all__ == ["extract_subnet_weights"] + import os import paddle -def convert_distill_weights(distill_weights_path, student_weights_path): +def extract_subnet_weights(distill_weights_path, + student_weights_path, + student_name="Student"): assert os.path.exists(distill_weights_path), \ "Given distill_weights_path {} not exist.".format(distill_weights_path) # Load teacher and student weights all_params = paddle.load(distill_weights_path) # Extract student weights - s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} - # Save student weights + student_prefix = student_name + "." + s_params = { + key[len(student_prefix):]: all_params[key] + for key in all_params if student_prefix in key + } + assert len( + s_params + ) > 0, f"extracted params length must be > 0 but got {len(s_params)}" + # Save subnet weights paddle.save(s_params, student_weights_path)