diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md index 6eaaf3099..769401dbf 100644 --- a/doc/doc_ch/knowledge_distillation.md +++ b/doc/doc_ch/knowledge_distillation.md @@ -569,7 +569,7 @@ all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams") # 查看权重参数的keys print(all_params.keys()) # 学生模型的权重提取 -s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} +s_params = {key[len("student_model."):]: all_params[key] for key in all_params if "student_model." in key} # 查看学生模型权重参数的keys print(s_params.keys()) # 保存