diff --git a/docs/zh_CN/training/advanced/knowledge_distillation.md b/docs/zh_CN/training/advanced/knowledge_distillation.md index 24b48528b..cbddf3485 100644 --- a/docs/zh_CN/training/advanced/knowledge_distillation.md +++ b/docs/zh_CN/training/advanced/knowledge_distillation.md @@ -583,76 +583,6 @@ Loss: weight: 1.0 ``` -<<<<<<< HEAD -======= - - -#### 1.2.8 WSL - -##### 1.2.8.1 WSL 算法介绍 - -论文信息: - - -> [Rethinking Soft Labels For Knowledge Distillation: A Bias-variance Tradeoff Perspective](https://arxiv.org/abs/2102.0650) -> -> Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, Qian Zhang -> -> ICLR, 2021 - -WSL (Weighted Soft Labels) 损失函数根据教师模型与学生模型关于真值标签的 CE Loss 比值,对每个样本的 KD Loss 分别赋予权重。若学生模型相对教师模型在某个样本上预测结果更好,则对该样本赋予较小的权重。该方法简单、有效,使各个样本的权重可自适应调节,提升了蒸馏精度。 - -在ImageNet1k公开数据集上,效果如下所示。 - -| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 | -| --- | --- | --- | --- | --- | -| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - | -| WSL | ResNet18 | [resnet34_distill_resnet18_wsl.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml) | 72.23%(**+1.43%**) | - | - - -##### 1.2.8.2 WSL 配置 - -WSL 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义`DistillationGTCELoss`(学生与真值标签之间的CE loss)以及`DistillationWSLLoss`(学生与教师之间的WSL loss),作为训练的损失函数。 - - -```yaml -# model architecture -Arch: - name: "DistillationModel" - # if not null, its lengths should be same as models - pretrained_list: - # if not null, its lengths should be same as models - freeze_params_list: - - True - - False - models: - - Teacher: - name: ResNet34 - pretrained: True - - - Student: - name: ResNet18 - pretrained: False - - infer_model_name: "Student" - - -# loss function config for traing/eval process -Loss: - Train: - - DistillationGTCELoss: - weight: 1.0 - model_names: ["Student"] - - DistillationWSLLoss: - weight: 2.5 - model_name_pairs: [["Student", "Teacher"]] - temperature: 2 - Eval: - - CELoss: - weight: 1.0 -``` - ->>>>>>> 1f6f4797 (docs: refactor & fix link & rename) ## 2. 模型训练、评估和预测