From 26207a8c77e5ed27ff4f9e00d698db475cadb4d4 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 19 Jul 2022 10:50:51 +0800
Subject: [PATCH] add mgd loss (#2161)

* add mgd loss

* add init

* fix doc
---
 .../knowledge_distillation.md                 |  74 +++-
 .../resnet34_distill_resnet18_mgd.yaml        | 159 +++++++++
 ppcls/loss/__init__.py                        |   1 +
 ppcls/loss/distillationloss.py                |  44 +++
 ppcls/loss/mgd_loss.py                        |  84 +++++
 ppcls/utils/initializer.py                    | 318 ++++++++++++++++++
 6 files changed, 678 insertions(+), 2 deletions(-)
 create mode 100644 ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml
 create mode 100644 ppcls/loss/mgd_loss.py
 create mode 100644 ppcls/utils/initializer.py

diff --git a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
index 18bb25f2e..5a3584376 100644
--- a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
+++ b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
@@ -15,6 +15,7 @@
   - [1.2.4 AFD](#1.2.4)
   - [1.2.5 DKD](#1.2.5)
   - [1.2.6 DIST](#1.2.6)
+  - [1.2.7 MGD](#1.2.7)
 - [2. Usage](#2)
   - [2.1 Environment Setup](#2.1)
   - [2.2 Data Preparation](#2.2)
@@ -24,8 +25,6 @@
   - [2.6 Model Export and Inference](#2.6)
 - [3. References](#3)
-
-
 
 ## 1. Algorithm Introduction
@@ -512,6 +511,77 @@ Loss:
     weight: 1.0
 ```
+
+
+#### 1.2.7 MGD
+
+##### 1.2.7.1 MGD Algorithm Introduction
+
+Paper:
+
+> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)
+>
+> Zhendong Yang, Zhe Li, Mingqi Shao, Dachuan Shi, Zehuan Yuan, Chun Yuan
+>
+> ECCV 2022
+
+This method distills feature maps. During distillation, the student's features are randomly masked, and the student is forced to reconstruct all of the teacher's features from the remaining partial features, which improves the student's representation ability. The method reaches SOTA on feature distillation tasks and has been widely validated as effective on detection, segmentation, and other tasks.
+
+Results on the public ImageNet1k dataset are shown below.
+
+| Strategy | Backbone | Config | Top-1 acc | Download |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| MGD | ResNet18 | [resnet34_distill_resnet18_mgd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml) | 71.86% (**+1.06%**) | - |
+
+
+##### 1.2.7.2 MGD Configuration
+
+The MGD configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen, and its pretrained weights must be loaded. In the `Loss` field, `DistillationPairLoss` (the MGD loss between student and teacher feature maps) and `DistillationGTCELoss` (the CE loss between the student's output and the ground-truth labels) must be defined as the training losses.
+
+```yaml
+Arch:
+  name: "DistillationModel"
+  class_num: &class_num 1000
+  # if not null, its length should be the same as models
+  pretrained_list:
+  # if not null, its length should be the same as models
+  freeze_params_list:
+  - True
+  - False
+  infer_model_name: "Student"
+  models:
+    - Teacher:
+        name: ResNet34
+        class_num: *class_num
+        pretrained: True
+        return_patterns: &t_stages ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"]
+    - Student:
+        name: ResNet18
+        class_num: *class_num
+        pretrained: False
+        return_patterns: &s_stages ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"]
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationPairLoss:
+        weight: 1.0
+        model_name_pairs: [["Student", "Teacher"]] # calculate MGD loss between Student and Teacher
+        name: "loss_mgd"
+        base_loss_name: MGDLoss # MGD loss; the parameters below are passed to MGDLoss
+        s_keys: ["blocks[7]"] # feature map used to calculate MGD loss in the student model
+        t_keys: ["blocks[15]"] # feature map used to calculate MGD loss in the teacher model
+        student_channels: 512 # channel num of the student feature map
+        teacher_channels: 512 # channel num of the teacher feature map
+  Eval:
+    - CELoss:
+        weight: 1.0
+```
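The core of MGD is compact enough to sketch directly. Below is a minimal, self-contained Paddle sketch of the masking-and-generation objective described above, with random tensors standing in for real feature maps; the actual implementation is the `MGDLoss` layer added in `ppcls/loss/mgd_loss.py` later in this patch.

```python
import paddle
import paddle.nn as nn

# Toy student/teacher feature maps (N, C, H, W); real features come from
# the backbone stages selected by `return_patterns` above.
s_feat = paddle.rand([2, 512, 7, 7])
t_feat = paddle.rand([2, 512, 7, 7])

lambda_mgd = 0.15  # ratio of feature positions to zero out
mask = (paddle.rand([2, 512, 1, 1]) >= lambda_mgd).astype("float32")

# The generation head tries to recover the full teacher feature from the
# masked student feature; plain MSE is the distillation objective.
generation = nn.Sequential(
    nn.Conv2D(512, 512, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2D(512, 512, kernel_size=3, padding=1))
loss = nn.MSELoss()(generation(s_feat * mask), t_feat)
print(loss)
```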
diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml
new file mode 100644
index 000000000..5f2260e9c
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml
@@ -0,0 +1,159 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/r34_r18_mgd
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  to_static: False
+
+# model architecture
+Arch:
+  name: "DistillationModel"
+  class_num: &class_num 1000
+  # if not null, its length should be the same as models
+  pretrained_list:
+  # if not null, its length should be the same as models
+  freeze_params_list:
+  - True
+  - False
+  infer_model_name: "Student"
+  models:
+    - Teacher:
+        name: ResNet34
+        class_num: *class_num
+        pretrained: True
+        return_patterns: &t_stages ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"]
+    - Student:
+        name: ResNet18
+        class_num: *class_num
+        pretrained: False
+        return_patterns: &s_stages ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"]
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationPairLoss:
+        weight: 1.0
+        base_loss_name: MGDLoss
+        model_name_pairs: [["Student", "Teacher"]]
+        s_keys: ["blocks[7]"]
+        t_keys: ["blocks[15]"]
+        name: "loss_mgd"
+        student_channels: 512
+        teacher_channels: 512
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  weight_decay: 1e-4
+  lr:
+    name: Piecewise
+    learning_rate: 0.1
+    decay_epochs: [30, 60, 90]
+    values: [0.1, 0.01, 0.001, 0.0001]
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
+  Eval:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
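Assuming the repository's usual training entry point (`tools/train.py`), a typical launch command for this config would look like the following; the GPU ids are illustrative.

```shell
python3 -m paddle.distributed.launch \
    --gpus="0,1,2,3" \
    tools/train.py \
    -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml
```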
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 5a62e0156..489aea7fb 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss
 from .distillationloss import DistillationDKDLoss
 from .distillationloss import DistillationMultiLabelLoss
 from .distillationloss import DistillationDISTLoss
+from .distillationloss import DistillationPairLoss
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss

diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index 8537fc548..5a924afe7 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -24,6 +24,7 @@ from .kldivloss import KLDivLoss
 from .dkdloss import DKDLoss
 from .dist_loss import DISTLoss
 from .multilabelloss import MultiLabelLoss
+from .mgd_loss import MGDLoss
 
 
 class DistillationCELoss(CELoss):
@@ -319,3 +320,46 @@ class DistillationDISTLoss(DISTLoss):
         loss = super().forward(out1, out2)
         loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
         return loss_dict
+
+
+class DistillationPairLoss(nn.Layer):
+    """
+    DistillationPairLoss: compute a base loss (e.g. MGDLoss) between paired
+    feature maps selected by key from the outputs of a model pair.
+    """
+
+    def __init__(self,
+                 base_loss_name,
+                 model_name_pairs=[],
+                 s_keys=None,
+                 t_keys=None,
+                 name="loss",
+                 **kwargs):
+        super().__init__()
+        # instantiate the base loss by class name; extra kwargs are forwarded to it
+        self.loss_func = eval(base_loss_name)(**kwargs)
+        if not isinstance(s_keys, list):
+            s_keys = [s_keys]
+        if not isinstance(t_keys, list):
+            t_keys = [t_keys]
+        self.s_keys = s_keys
+        self.t_keys = t_keys
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            # select the feature maps to distill; None means use the raw output
+            out1 = [out1[k] if k is not None else out1 for k in self.s_keys]
+            out2 = [out2[k] if k is not None else out2 for k in self.t_keys]
+            for feat_idx, (o1, o2) in enumerate(zip(out1, out2)):
+                loss = self.loss_func.forward(o1, o2)
+                if isinstance(loss, dict):
+                    for k in loss:
+                        loss_dict[
+                            f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}_{k}"] = loss[
+                                k]
+                else:
+                    loss_dict[
+                        f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}"] = loss
+        return loss_dict
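To make the key-selection contract of `DistillationPairLoss` concrete, here is a small sketch with mock model outputs. It assumes each entry of `predicts` is a dict of named feature maps, as produced by the `return_patterns` mechanism; shapes and key names are illustrative.

```python
import paddle
from ppcls.loss.distillationloss import DistillationPairLoss

loss_fn = DistillationPairLoss(
    base_loss_name="MGDLoss",
    model_name_pairs=[["Student", "Teacher"]],
    s_keys=["blocks[7]"],
    t_keys=["blocks[15]"],
    name="loss_mgd",
    student_channels=512,   # forwarded to MGDLoss via **kwargs
    teacher_channels=512)

predicts = {
    "Student": {"blocks[7]": paddle.rand([2, 512, 7, 7])},
    "Teacher": {"blocks[15]": paddle.rand([2, 512, 7, 7])},
}
# the labels (`batch`) are unused by this loss, so None is fine here
loss_dict = loss_fn(predicts, None)
print(loss_dict)  # {'loss_mgd_0_0_Student_Teacher': Tensor(...)}
```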
diff --git a/ppcls/loss/mgd_loss.py b/ppcls/loss/mgd_loss.py
new file mode 100644
index 000000000..799a91431
--- /dev/null
+++ b/ppcls/loss/mgd_loss.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppcls.utils.initializer import kaiming_normal_
+
+
+class MGDLoss(nn.Layer):
+    """Paddle version of `Masked Generative Distillation`
+    MGDLoss
+    Reference: https://arxiv.org/abs/2205.01529
+    Code was heavily based on https://github.com/yzd-v/MGD
+    """
+
+    def __init__(
+            self,
+            student_channels,
+            teacher_channels,
+            alpha_mgd=1.756,
+            lambda_mgd=0.15, ):
+        super().__init__()
+        self.alpha_mgd = alpha_mgd
+        self.lambda_mgd = lambda_mgd
+
+        # 1x1 conv to align the student's channel count with the teacher's
+        if student_channels != teacher_channels:
+            self.align = nn.Conv2D(
+                student_channels,
+                teacher_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0)
+        else:
+            self.align = None
+
+        # generation head: recovers the teacher feature from the masked student feature
+        self.generation = nn.Sequential(
+            nn.Conv2D(
+                teacher_channels, teacher_channels, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2D(
+                teacher_channels, teacher_channels, kernel_size=3, padding=1))
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+
+    def forward(self, pred_s, pred_t):
+        """Forward function.
+        Args:
+            pred_s(Tensor): Bs*C*H*W, student's feature map
+            pred_t(Tensor): Bs*C*H*W, teacher's feature map
+        """
+        assert pred_s.shape[-2:] == pred_t.shape[-2:]
+
+        if self.align is not None:
+            pred_s = self.align(pred_s)
+
+        loss = self.get_dis_loss(pred_s, pred_t) * self.alpha_mgd
+
+        return loss
+
+    def get_dis_loss(self, pred_s, pred_t):
+        loss_mse = nn.MSELoss(reduction='mean')
+        N, C, _, _ = pred_t.shape
+        # random mask: zero out roughly a ratio of lambda_mgd of the entries
+        mat = paddle.rand([N, C, 1, 1])
+        mat = paddle.where(mat < self.lambda_mgd, 0, 1).astype("float32")
+        masked_fea = paddle.multiply(pred_s, mat)
+        new_fea = self.generation(masked_fea)
+        dis_loss = loss_mse(new_fea, pred_t)
+        return dis_loss
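A quick standalone sketch of `MGDLoss`, exercising the 1x1 `align` branch that is created when the student and teacher channel counts differ; the shapes are illustrative.

```python
import paddle
from ppcls.loss.mgd_loss import MGDLoss

# student has 256 channels, teacher 512: a 1x1 conv aligns them internally
loss_fn = MGDLoss(student_channels=256, teacher_channels=512)
pred_s = paddle.rand([2, 256, 7, 7])
pred_t = paddle.rand([2, 512, 7, 7])
loss = loss_fn(pred_s, pred_t)  # scalar tensor, already scaled by alpha_mgd
print(float(loss))
```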
+""" + +import math +import numpy as np + +import paddle +import paddle.nn as nn + +__all__ = [ + 'uniform_', + 'normal_', + 'constant_', + 'ones_', + 'zeros_', + 'xavier_uniform_', + 'xavier_normal_', + 'kaiming_uniform_', + 'kaiming_normal_', + 'linear_init_', + 'conv_init_', + 'reset_initialized_parameter', +] + + +def _no_grad_uniform_(tensor, a, b): + with paddle.no_grad(): + tensor.set_value( + paddle.uniform( + shape=tensor.shape, dtype=tensor.dtype, min=a, max=b)) + return tensor + + +def _no_grad_normal_(tensor, mean=0., std=1.): + with paddle.no_grad(): + tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape)) + return tensor + + +def _no_grad_fill_(tensor, value=0.): + with paddle.no_grad(): + tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype)) + return tensor + + +def uniform_(tensor, a, b): + """ + Modified tensor inspace using uniform_ + Args: + tensor (paddle.Tensor): paddle Tensor + a (float|int): min value. + b (float|int): max value. + Return: + tensor + """ + return _no_grad_uniform_(tensor, a, b) + + +def normal_(tensor, mean=0., std=1.): + """ + Modified tensor inspace using normal_ + Args: + tensor (paddle.Tensor): paddle Tensor + mean (float|int): mean value. + std (float|int): std value. + Return: + tensor + """ + return _no_grad_normal_(tensor, mean, std) + + +def constant_(tensor, value=0.): + """ + Modified tensor inspace using constant_ + Args: + tensor (paddle.Tensor): paddle Tensor + value (float|int): value to fill tensor. + Return: + tensor + """ + return _no_grad_fill_(tensor, value) + + +def ones_(tensor): + """ + Modified tensor inspace using ones_ + Args: + tensor (paddle.Tensor): paddle Tensor + Return: + tensor + """ + return _no_grad_fill_(tensor, 1) + + +def zeros_(tensor): + """ + Modified tensor inspace using zeros_ + Args: + tensor (paddle.Tensor): paddle Tensor + Return: + tensor + """ + return _no_grad_fill_(tensor, 0) + + +def _calculate_fan_in_and_fan_out(tensor, reverse=False): + """ + Calculate (fan_in, _fan_out) for tensor + + Args: + tensor (Tensor): paddle.Tensor + reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. e.g. : conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] is True + + Return: + Tuple[fan_in, fan_out] + """ + if tensor.ndim < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + if reverse: + num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1] + else: + num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0] + + receptive_field_size = 1 + if tensor.ndim > 2: + receptive_field_size = np.prod(tensor.shape[2:]) + + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_(tensor, gain=1., reverse=False): + """ + Modified tensor inspace using xavier_uniform_ + Args: + tensor (paddle.Tensor): paddle Tensor + gain (float): super parameter, 1. default. + reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. 
+
+
+def xavier_uniform_(tensor, gain=1., reverse=False):
+    """
+    Modify tensor in place with the Xavier uniform method
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        gain (float): scaling factor, 1. by default.
+        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
+    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+    k = math.sqrt(3.0) * std
+    return _no_grad_uniform_(tensor, -k, k)
+
+
+def xavier_normal_(tensor, gain=1., reverse=False):
+    """
+    Modify tensor in place with the Xavier normal method
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        gain (float): scaling factor, 1. by default.
+        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
+    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+    return _no_grad_normal_(tensor, 0, std)
+
+
+# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
+def _calculate_correct_fan(tensor, mode, reverse=False):
+    mode = mode.lower()
+    valid_modes = ['fan_in', 'fan_out']
+    if mode not in valid_modes:
+        raise ValueError("Mode {} not supported, please use one of {}".format(
+            mode, valid_modes))
+
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)
+
+    return fan_in if mode == 'fan_in' else fan_out
+
+
+def _calculate_gain(nonlinearity, param=None):
+    linear_fns = [
+        'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
+        'conv_transpose2d', 'conv_transpose3d'
+    ]
+    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+        return 1
+    elif nonlinearity == 'tanh':
+        return 5.0 / 3
+    elif nonlinearity == 'relu':
+        return math.sqrt(2.0)
+    elif nonlinearity == 'leaky_relu':
+        if param is None:
+            negative_slope = 0.01
+        elif not isinstance(param, bool) and isinstance(
+                param, int) or isinstance(param, float):
+            # True/False are instances of int, hence check above
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(
+                param))
+        return math.sqrt(2.0 / (1 + negative_slope**2))
+    elif nonlinearity == 'selu':
+        return 3.0 / 4
+    else:
+        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+
+def kaiming_uniform_(tensor,
+                     a=0,
+                     mode='fan_in',
+                     nonlinearity='leaky_relu',
+                     reverse=False):
+    """
+    Modify tensor in place with the Kaiming uniform method
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        a (float): negative slope used with 'leaky_relu', 0 by default
+        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
+        nonlinearity (str): nonlinearity method name
+        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan = _calculate_correct_fan(tensor, mode, reverse)
+    gain = _calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    k = math.sqrt(3.0) * std
+    return _no_grad_uniform_(tensor, -k, k)
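As a sanity check of the Kaiming uniform bound (gain * sqrt(3 / fan)): for a [64, 3, 7, 7] conv weight with mode='fan_out' and nonlinearity='relu', the bound is sqrt(2) * sqrt(3) / sqrt(64*7*7) ≈ 0.0437.

```python
import math
import paddle
from ppcls.utils.initializer import kaiming_uniform_

w = paddle.empty([64, 3, 7, 7])
kaiming_uniform_(w, mode='fan_out', nonlinearity='relu')
# gain = sqrt(2), fan_out = 64*7*7 = 3136, bound = sqrt(3) * gain / sqrt(fan)
bound = math.sqrt(3.0) * math.sqrt(2.0) / math.sqrt(64 * 7 * 7)
assert float(w.abs().max()) <= bound + 1e-6
```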
+
+
+def kaiming_normal_(tensor,
+                    a=0,
+                    mode='fan_in',
+                    nonlinearity='leaky_relu',
+                    reverse=False):
+    """
+    Modify tensor in place with the Kaiming normal method
+    Args:
+        tensor (paddle.Tensor): paddle Tensor
+        a (float): negative slope used with 'leaky_relu', 0 by default
+        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
+        nonlinearity (str): nonlinearity method name
+        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
+    Return:
+        tensor
+    """
+    fan = _calculate_correct_fan(tensor, mode, reverse)
+    gain = _calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    return _no_grad_normal_(tensor, 0, std)
+
+
+def linear_init_(module):
+    bound = 1 / math.sqrt(module.weight.shape[0])
+    uniform_(module.weight, -bound, bound)
+    uniform_(module.bias, -bound, bound)
+
+
+def conv_init_(module):
+    bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
+    uniform_(module.weight, -bound, bound)
+    if module.bias is not None:
+        uniform_(module.bias, -bound, bound)
+
+
+def bias_init_with_prob(prior_prob=0.01):
+    """initialize conv/fc bias value according to a given probability value."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
+
+
+@paddle.no_grad()
+def reset_initialized_parameter(model, include_self=True):
+    """
+    Reset the initialized parameters of [conv, linear, embedding, bn] sublayers
+
+    Args:
+        model (paddle.nn.Layer): paddle Layer
+        include_self (bool: True): passed to Layer.named_sublayers; whether to include the layer itself
+    Return:
+        None
+    """
+    for _, m in model.named_sublayers(include_self=include_self):
+        if isinstance(m, nn.Conv2D):
+            k = float(m._groups) / (m._in_channels * m._kernel_size[0] *
+                                    m._kernel_size[1])
+            k = math.sqrt(k)
+            _no_grad_uniform_(m.weight, -k, k)
+            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
+                _no_grad_uniform_(m.bias, -k, k)
+
+        elif isinstance(m, nn.Linear):
+            k = math.sqrt(1. / m.weight.shape[0])
+            _no_grad_uniform_(m.weight, -k, k)
+            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
+                _no_grad_uniform_(m.bias, -k, k)
+
+        elif isinstance(m, nn.Embedding):
+            _no_grad_normal_(m.weight, mean=0., std=1.)
+
+        elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
+            _no_grad_fill_(m.weight, 1.)
+            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
+                _no_grad_fill_(m.bias, 0)
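Finally, a minimal usage sketch of `reset_initialized_parameter`, applied to a toy model (the architecture is illustrative): conv and linear weights get a uniform re-init, batch-norm weights and biases are reset to 1 and 0.

```python
import paddle.nn as nn
from ppcls.utils.initializer import reset_initialized_parameter

model = nn.Sequential(
    nn.Conv2D(3, 16, 3, padding=1),
    nn.BatchNorm2D(16),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10))
reset_initialized_parameter(model)  # conv/linear -> uniform, bn -> (1, 0)
```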