update config and docs

parent c5884bb24f
commit 996467bd14
@@ -17,6 +17,7 @@
- [1.2.6 DIST](#1.2.6)
- [1.2.7 MGD](#1.2.7)
- [1.2.8 WSL](#1.2.8)
- [1.2.9 SKD](#1.2.9)
- [2. Usage](#2)
- [2.1 Environment Configuration](#2.1)
- [2.2 Data Preparation](#2.2)

@@ -655,6 +656,72 @@ Loss:
        weight: 1.0
```

<a name='1.2.9'></a>

#### 1.2.9 SKD

##### 1.2.9.1 Introduction to SKD

Paper:

> [Reducing the Teacher-Student Gap via Spherical Knowledge Distillation](https://arxiv.org/abs/2010.07485)
>
> Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai
>
> 2022, under review

Due to the limited capacity of the student, student performance can unexpectedly drop when distilling from an oversized teacher. Spherical Knowledge Distillation (SKD) explicitly eliminates the confidence gap between teacher and student, easing this capacity-gap problem. SKD achieves a significant improvement over the previous SOTA when distilling ResNet18 on ImageNet1k.
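
The core trick can be pictured as rescaling the teacher's and the student's logits to a common scale before distillation, so the student no longer has to reproduce the teacher's absolute confidence. Below is a minimal NumPy sketch of that idea; the function name, the per-sample standardization, and the `multiplier` default are illustrative assumptions, not the code used in this repository.

```python
import numpy as np

def rescale_logits(logits, multiplier=2.0, eps=1e-6):
    """Standardize each sample's logits and rescale them to a fixed
    "spherical" scale, removing the confidence gap between models."""
    mean = logits.mean(axis=-1, keepdims=True)
    std = logits.std(axis=-1, keepdims=True)
    return (logits - mean) / (std + eps) * multiplier

# an over-confident teacher typically produces much larger logits than the student
teacher_logits = 5.0 * np.random.randn(4, 10)   # batch of 4 samples, 10 classes
student_logits = np.random.randn(4, 10)

# after rescaling, both sets of logits live on the same scale, so the
# student only needs to match the shape of the teacher's distribution
t_scaled = rescale_logits(teacher_logits)
s_scaled = rescale_logits(student_logits)
```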

Performance on ImageNet1k is shown below.

| Strategy | Backbone | Config | Top-1 acc | Download Link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| SKD | ResNet18 | [resnet34_distill_resnet18_skd.yaml](../../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_skd.yaml) | 72.84% (**+2.04%**) | - |

##### 1.2.9.2 Configuration of SKD

The SKD configuration is shown below. In the `Arch` field, you need to define both the student model and the teacher model. The teacher model's parameters are frozen, and its pretrained weights are loaded. In the `Loss` field, you need to define `DistillationSKDLoss` (the SKD loss between student and teacher). Note that the SKD loss already includes the KL divergence loss against the teacher and the CE loss against the ground-truth labels, so `DistillationGTCELoss` does not need to be defined.

```yaml
# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True

    - Student:
        name: ResNet18
        pretrained: False

  infer_model_name: "Student"


# loss function config for training/eval process
Loss:
  Train:
    - DistillationSKDLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 1.0
        multiplier: 2.0
        alpha: 0.9
  Eval:
    - CELoss:
        weight: 1.0
```
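
For intuition only, the sketch below shows how the fields above (`temperature`, `multiplier`, `alpha`) could be combined into a single training objective: a KL-divergence term against the rescaled teacher plus a CE term against the ground-truth labels, which is why no separate `DistillationGTCELoss` entry is needed. The helper names and the exact weighting are assumptions; the actual `DistillationSKDLoss` implementation in the code base may differ in detail.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def skd_like_loss(student_logits, teacher_logits, labels,
                  temperature=1.0, multiplier=2.0, alpha=0.9, eps=1e-12):
    """Illustrative combination of a distillation term and a CE term,
    mirroring the configuration keys above (not the library code)."""
    def rescale(z):  # same spherical rescaling as sketched in 1.2.9.1
        return (z - z.mean(-1, keepdims=True)) / (z.std(-1, keepdims=True) + 1e-6) * multiplier

    p_teacher = softmax(rescale(teacher_logits) / temperature)
    log_p_student = np.log(softmax(rescale(student_logits) / temperature) + eps)
    kl = (p_teacher * (np.log(p_teacher + eps) - log_p_student)).sum(-1).mean()

    # cross entropy of the raw student logits against the hard labels
    ce = -np.log(softmax(student_logits)[np.arange(len(labels)), labels] + eps).mean()

    # alpha balances distillation against supervised learning
    return alpha * kl + (1.0 - alpha) * ce

# toy usage: 4 samples, 10 classes
labels = np.array([1, 3, 5, 7])
loss = skd_like_loss(np.random.randn(4, 10), 5.0 * np.random.randn(4, 10), labels)
```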

<a name="2"></a>

## 2. Training, Evaluation and Prediction

@@ -682,7 +682,7 @@ Loss:

##### 1.2.9.2 Configuration of SKD

- The SKD configuration is shown below. In the `Arch` field, both the student model and the teacher model need to be defined; the teacher model's parameters are fixed, and its pretrained weights need to be loaded. In the `Loss` field, `DistillationSKDLoss` (the SKD loss between student and teacher) needs to be defined as the training loss.
+ The SKD configuration is shown below. In the `Arch` field, both the student model and the teacher model need to be defined; the teacher model's parameters are fixed, and its pretrained weights need to be loaded. In the `Loss` field, `DistillationSKDLoss` (the SKD loss between student and teacher) needs to be defined as the training loss. Note that the SKD loss includes the KL divergence loss between student and teacher and the CE loss between the student and the ground-truth labels, so `DistillationGTCELoss` does not need to be defined.

@@ -55,7 +55,7 @@ Optimizer:
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
-   learning_rate: 0.2
+   learning_rate: 0.1
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

@@ -84,7 +84,7 @@ DataLoader:
    sampler:
      name: DistributedBatchSampler
-     batch_size: 128
+     batch_size: 64
      drop_last: False
      shuffle: True
    loader: