From 283ae9b32790d2a42265d7a60561ecd8889d4ebb Mon Sep 17 00:00:00 2001 From: WangChen0902 <827913668@qq.com> Date: Mon, 9 May 2022 22:55:01 +0800 Subject: [PATCH] add dkd (#1888) * add dkd * update dkd * update dkd * update dkd * update dkd * update dkd * update dkd and add tipc --- .../resnet34_distill_resnet18_dkd.yaml | 155 ++++++++++++++++++ ppcls/loss/__init__.py | 1 + ppcls/loss/distillationloss.py | 31 ++++ ppcls/loss/dkdloss.py | 61 +++++++ ...ll_resnet18_dkd_train_amp_infer_python.txt | 54 ++++++ ...istill_resnet18_dkd_train_infer_python.txt | 54 ++++++ 6 files changed, 356 insertions(+) create mode 100644 ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml create mode 100644 ppcls/loss/dkdloss.py create mode 100644 test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt create mode 100644 test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml new file mode 100644 index 000000000..8efa5c054 --- /dev/null +++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml @@ -0,0 +1,155 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "DistillationModel" + # if not null, its lengths should be same as models + pretrained_list: + # if not null, its lengths should be same as models + freeze_params_list: + - True + - False + models: + - Teacher: + name: ResNet34 + pretrained: True + + - Student: + name: ResNet18 + pretrained: False + + infer_model_name: "Student" + + +# loss function config 
for training/eval process +Loss: + Train: + - DistillationGTCELoss: + weight: 1.0 + model_names: ["Student"] + - DistillationDKDLoss: + weight: 1.0 + model_name_pairs: [["Student", "Teacher"]] + temperature: 1 + alpha: 1.0 + beta: 1.0 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + weight_decay: 1e-4 + lr: + name: MultiStepDecay + learning_rate: 0.2 + milestones: [30, 60, 90] + step_each_epoch: 1 + gamma: 0.1 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/train_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/val_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: "docs/images/inference_deployment/whl_demo.jpg" + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: DistillationPostProcess + func: Topk + topk: 5 + 
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" + +Metric: + Train: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] + Eval: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index c3281b0e5..c1f2f95df 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -23,6 +23,7 @@ from .distillationloss import DistillationDMLLoss from .distillationloss import DistillationDistanceLoss from .distillationloss import DistillationRKDLoss from .distillationloss import DistillationKLDivLoss +from .distillationloss import DistillationDKDLoss from .multilabelloss import MultiLabelLoss from .afdloss import AFDLoss diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py index 21e5ef371..c60a540db 100644 --- a/ppcls/loss/distillationloss.py +++ b/ppcls/loss/distillationloss.py @@ -21,6 +21,7 @@ from .dmlloss import DMLLoss from .distanceloss import DistanceLoss from .rkdloss import RKdAngle, RkdDistance from .kldivloss import KLDivLoss +from .dkdloss import DKDLoss class DistillationCELoss(CELoss): @@ -204,3 +205,33 @@ class DistillationKLDivLoss(KLDivLoss): for key in loss: loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key] return loss_dict + + +class DistillationDKDLoss(DKDLoss): + """ + DistillationDKDLoss + """ + + def __init__(self, + model_name_pairs=[], + key=None, + temperature=1.0, + alpha=1.0, + beta=1.0, + name="loss_dkd"): + super().__init__(temperature=temperature, alpha=alpha, beta=beta) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2, batch) + loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss + return loss_dict diff --git 
a/ppcls/loss/dkdloss.py b/ppcls/loss/dkdloss.py new file mode 100644 index 000000000..9ce2c56d9 --- /dev/null +++ b/ppcls/loss/dkdloss.py @@ -0,0 +1,61 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class DKDLoss(nn.Layer): + """ + DKDLoss + Reference: https://arxiv.org/abs/2203.08679 + Code was heavily based on https://github.com/megvii-research/mdistiller + """ + + def __init__(self, temperature=1.0, alpha=1.0, beta=1.0): + super().__init__() + self.temperature = temperature + self.alpha = alpha + self.beta = beta + + def forward(self, logits_student, logits_teacher, target): + gt_mask = _get_gt_mask(logits_student, target) + other_mask = 1 - gt_mask + pred_student = F.softmax(logits_student / self.temperature, axis=1) + pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1) + pred_student = cat_mask(pred_student, gt_mask, other_mask) + pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask) + log_pred_student = paddle.log(pred_student) + tckd_loss = (F.kl_div( + log_pred_student, pred_teacher, + reduction='sum') * (self.temperature**2) / target.shape[0]) + pred_teacher_part2 = F.softmax( + logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1) + log_pred_student_part2 = F.log_softmax( + logits_student / self.temperature - 1000.0 * gt_mask, axis=1) + nckd_loss = (F.kl_div( + log_pred_student_part2, pred_teacher_part2, + reduction='sum') * (self.temperature**2) / target.shape[0]) + return self.alpha * tckd_loss + self.beta * nckd_loss + + +def _get_gt_mask(logits, target): + target = target.reshape([-1]).unsqueeze(1) + updates = paddle.ones_like(target) + mask = scatter( + paddle.zeros_like(logits), target, updates.astype('float32')) + return mask + + +def cat_mask(t, mask1, mask2): + t1 = (t * mask1).sum(axis=1, keepdim=True) + t2 = (t * mask2).sum(axis=1, keepdim=True) + rt = paddle.concat([t1, t2], axis=1) + return rt + + +def scatter(x, index, updates): + i, j = index.shape + grid_x, grid_y = 
paddle.meshgrid(paddle.arange(i), paddle.arange(j)) + index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1) + updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1) + updates = paddle.gather_nd(updates, index=updates_index) + return paddle.scatter_nd_add(x, index, updates) diff --git a/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt b/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt new file mode 100644 index 000000000..ab9403947 --- /dev/null +++ b/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_amp_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:DistillationModel +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:amp_train +amp_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o AMP.scale_loss=128 -o AMP.use_dynamic_loss_scaling=True -o AMP.level=O2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml +quant_export:null +fpgm_export:null 
+distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:True|False +-o Global.cpu_num_threads:1|6 +-o Global.batch_size:1|16 +-o Global.use_tensorrt:True|False +-o Global.use_fp16:True|False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val +-o Global.save_log_path:null +-o Global.benchmark:True +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt b/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt new file mode 100644 index 000000000..4b216a9f0 --- /dev/null +++ b/test_tipc/config/Distillation/resnet34_distill_resnet18_dkd_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:DistillationModel +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== 
+eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:True|False +-o Global.cpu_num_threads:1|6 +-o Global.batch_size:1|16 +-o Global.use_tensorrt:True|False +-o Global.use_fp16:True|False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val +-o Global.save_log_path:null +-o Global.benchmark:True +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}]