add dkd (#1888)
* add dkd
* update dkd
* update dkd
* update dkd
* update dkd
* update dkd
* update dkd and add tipc
parent f5da904497
commit 283ae9b327
@@ -0,0 +1,155 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False

  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDKDLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 1
        alpha: 1.0
        beta: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.2
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: DistillationPostProcess
    func: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
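Read together, the Loss.Train entries above define the overall distillation objective, restated here for clarity with z_S, z_T the Student/Teacher logits and y the ground-truth label:

$$\mathcal{L}_{\mathrm{train}} = 1.0\cdot\mathrm{CE}(z_S,\,y) + 1.0\cdot\mathrm{DKD}(z_S,\,z_T;\ T{=}1,\ \alpha{=}1.0,\ \beta{=}1.0)$$

while Loss.Eval monitors only the Student's plain cross-entropy. The DKD term itself is implemented in the new dkdloss.py added later in this diff.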
@@ -23,6 +23,7 @@ from .distillationloss import DistillationDMLLoss
 from .distillationloss import DistillationDistanceLoss
 from .distillationloss import DistillationRKDLoss
 from .distillationloss import DistillationKLDivLoss
+from .distillationloss import DistillationDKDLoss
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss
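Re-exporting DistillationDKDLoss from the loss package is what lets the Loss section of the YAML config above refer to it by name. A rough sketch of that name-to-class resolution (build_loss_by_name is a hypothetical helper for illustration only, not the repo's actual builder code):

import importlib

def build_loss_by_name(name, **kwargs):
    # Hypothetical helper: resolve a loss class from the ppcls.loss package by the
    # name given in the config. This only works because __init__.py imports the
    # class, as in the hunk above.
    module = importlib.import_module("ppcls.loss")
    return getattr(module, name)(**kwargs)

# e.g. build_loss_by_name("DistillationDKDLoss", model_name_pairs=[["Student", "Teacher"]])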
@@ -21,6 +21,7 @@ from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
 from .kldivloss import KLDivLoss
+from .dkdloss import DKDLoss


 class DistillationCELoss(CELoss):
@@ -204,3 +205,33 @@ class DistillationKLDivLoss(KLDivLoss):
             for key in loss:
                 loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
         return loss_dict
+
+
+class DistillationDKDLoss(DKDLoss):
+    """
+    DistillationDKDLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 temperature=1.0,
+                 alpha=1.0,
+                 beta=1.0,
+                 name="loss_dkd"):
+        super().__init__(temperature=temperature, alpha=alpha, beta=beta)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2, batch)
+            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+        return loss_dict
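Below is a minimal usage sketch for the DistillationDKDLoss wrapper added above; the batch size, class count, and random logits are illustrative assumptions, while the "Student"/"Teacher" keys and the returned dict key follow the code and config in this PR:

import paddle

loss_fn = DistillationDKDLoss(
    model_name_pairs=[["Student", "Teacher"]],
    temperature=1.0, alpha=1.0, beta=1.0)

# predicts mimics the per-model logits dict keyed by model name
predicts = {
    "Student": paddle.randn([8, 1000]),   # student logits for a batch of 8
    "Teacher": paddle.randn([8, 1000]),   # teacher logits
}
labels = paddle.randint(0, 1000, [8, 1])  # ground-truth class ids

out = loss_fn(predicts, labels)
# -> {"loss_dkd_Student_Teacher": Tensor}, following the f-string key in forward()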
@@ -0,0 +1,61 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class DKDLoss(nn.Layer):
    """
    DKDLoss
    Reference: https://arxiv.org/abs/2203.08679
    Code was heavily based on https://github.com/megvii-research/mdistiller
    """

    def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta

    def forward(self, logits_student, logits_teacher, target):
        gt_mask = _get_gt_mask(logits_student, target)
        other_mask = 1 - gt_mask
        pred_student = F.softmax(logits_student / self.temperature, axis=1)
        pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
        pred_student = cat_mask(pred_student, gt_mask, other_mask)
        pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
        log_pred_student = paddle.log(pred_student)
        tckd_loss = (F.kl_div(
            log_pred_student, pred_teacher,
            reduction='sum') * (self.temperature**2) / target.shape[0])
        pred_teacher_part2 = F.softmax(
            logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
        log_pred_student_part2 = F.log_softmax(
            logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
        nckd_loss = (F.kl_div(
            log_pred_student_part2, pred_teacher_part2,
            reduction='sum') * (self.temperature**2) / target.shape[0])
        return self.alpha * tckd_loss + self.beta * nckd_loss


def _get_gt_mask(logits, target):
    target = target.reshape([-1]).unsqueeze(1)
    updates = paddle.ones_like(target)
    mask = scatter(
        paddle.zeros_like(logits), target, updates.astype('float32'))
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(axis=1, keepdim=True)
    t2 = (t * mask2).sum(axis=1, keepdim=True)
    rt = paddle.concat([t1, t2], axis=1)
    return rt


def scatter(x, index, updates):
    i, j = index.shape
    grid_x, grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j))
    index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1)
    updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
    updates = paddle.gather_nd(updates, index=updates_index)
    return paddle.scatter_nd_add(x, index, updates)
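For orientation, the forward pass above follows the decoupled formulation of the referenced paper: tckd_loss is the KL divergence over the two-bin target/non-target distributions built by cat_mask, and nckd_loss is the KL divergence over the non-target classes only (subtracting 1000.0 * gt_mask effectively removes the ground-truth logit from the softmax). Both terms are scaled by T^2 and divided by the batch size, then combined as (paper notation assumed):

$$\mathrm{DKD} = \alpha\cdot\mathrm{TCKD} + \beta\cdot\mathrm{NCKD}$$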
@@ -0,0 +1,54 @@
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:amp_train
amp_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o AMP.scale_loss=128 -o AMP.use_dynamic_loss_scaling=True -o AMP.level=O2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
@@ -0,0 +1,54 @@
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]