Mirror of https://github.com/PaddlePaddle/PaddleClas.git, synced 2025-06-03 21:55:06 +08:00
parent 973cdef15c
commit 26207a8c77
@@ -15,6 +15,7 @@
   - [1.2.4 AFD](#1.2.4)
   - [1.2.5 DKD](#1.2.5)
   - [1.2.6 DIST](#1.2.6)
+  - [1.2.7 MGD](#1.2.7)
 - [2. Usage](#2)
   - [2.1 Environment setup](#2.1)
   - [2.2 Data preparation](#2.2)
@@ -24,8 +25,6 @@
   - [2.6 Model export and inference](#2.6)
 - [3. References](#3)

 <a name="1"></a>

 ## 1. Algorithm Introduction
@@ -512,6 +511,77 @@ Loss:
        weight: 1.0
```

<a name='1.2.7'></a>

#### 1.2.7 MGD

##### 1.2.7.1 MGD Algorithm Introduction

Paper information:

> [Masked Generative Distillation](https://arxiv.org/abs/2205.01529)
>
> Zhendong Yang, Zhe Li, Mingqi Shao, Dachuan Shi, Zehuan Yuan, Chun Yuan
>
> ECCV 2022

This method distills knowledge at the feature-map level. During distillation, the student's feature map is randomly masked, and the student is forced to generate all of the teacher's features from the remaining part, which improves the student's representation ability. The method reaches SOTA on feature-distillation tasks and has been widely validated as effective on detection, segmentation, and other tasks.
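The core of the method is the random feature mask. A minimal sketch of the per-(sample, channel) masking used by this implementation (tensor names and sizes are illustrative):

```python
import paddle

# Minimal sketch of MGD's random masking (names and sizes are illustrative).
# Roughly lambda_mgd of the (sample, channel) entries are zeroed; a small
# generation network must then reconstruct the teacher's full feature map.
lambda_mgd = 0.15
N, C, H, W = 2, 512, 7, 7
feat_s = paddle.rand([N, C, H, W])      # student feature map

mat = paddle.rand([N, C, 1, 1])         # one random value per (sample, channel)
mask = (mat >= lambda_mgd).astype("float32")
masked_feat = feat_s * mask             # mask broadcasts over H and W
```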

Results on the public ImageNet-1k dataset are shown below.

| Strategy | Backbone | Config file | Top-1 acc | Download link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| MGD | ResNet18 | [resnet34_distill_resnet18_mgd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml) | 71.86%(**+1.06%**) | - |

##### 1.2.7.2 MGD Configuration

The MGD configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen, and its pretrained weights must be loaded. In the `Loss` field, `DistillationPairLoss` (the MGD loss between the student's and teacher's feature maps) and `DistillationGTCELoss` (the student's CE loss against the ground-truth labels) must be defined as the training losses.

```yaml
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
    - True
    - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet34
        class_num: *class_num
        pretrained: True
        return_patterns: &t_stages ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"]
    - Student:
        name: ResNet18
        class_num: *class_num
        pretrained: False
        return_patterns: &s_stages ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"]

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationPairLoss:
        weight: 1.0
        model_name_pairs: [["Student", "Teacher"]]  # calculate MGD loss between Student and Teacher
        name: "loss_mgd"
        base_loss_name: MGDLoss  # MGD loss; the following fields are parameters of MGDLoss
        s_keys: ["blocks[7]"]    # feature map used to calculate MGD loss in the student model
        t_keys: ["blocks[15]"]   # feature map used to calculate MGD loss in the teacher model
        student_channels: 512    # channel num of the student feature map
        teacher_channels: 512    # channel num of the teacher feature map
  Eval:
    - CELoss:
        weight: 1.0
```
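To make the `s_keys`/`t_keys` fields concrete: when `return_patterns` is set, each sub-model of `DistillationModel` exposes its intermediate feature maps keyed by block name, and `DistillationPairLoss` pairs them up. A sketch of that structure with random stand-in tensors (not the real training loop):

```python
import paddle

# Illustrative structure of the output consumed by DistillationPairLoss
# when return_patterns is set (tensors are random stand-ins).
predicts = {
    "Student": {"blocks[7]": paddle.rand([8, 512, 7, 7])},    # ResNet18 last stage
    "Teacher": {"blocks[15]": paddle.rand([8, 512, 7, 7])},   # ResNet34 last stage
}
# Per the config above, the loss reads predicts["Student"]["blocks[7]"] and
# predicts["Teacher"]["blocks[15]"] and feeds the pair into MGDLoss.
```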

<a name="2"></a>

ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_mgd.yaml (new file, 159 lines)

@@ -0,0 +1,159 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/r34_r18_mgd
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
    - True
    - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet34
        class_num: *class_num
        pretrained: True
        return_patterns: &t_stages ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"]
    - Student:
        name: ResNet18
        class_num: *class_num
        pretrained: False
        return_patterns: &s_stages ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"]

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationPairLoss:
        weight: 1.0
        base_loss_name: MGDLoss
        model_name_pairs: [["Student", "Teacher"]]
        s_keys: ["blocks[7]"]
        t_keys: ["blocks[15]"]
        name: "loss_mgd"
        student_channels: 512
        teacher_channels: 512
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
ppcls/loss/__init__.py

@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss
 from .distillationloss import DistillationDKDLoss
 from .distillationloss import DistillationMultiLabelLoss
 from .distillationloss import DistillationDISTLoss
+from .distillationloss import DistillationPairLoss
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss

ppcls/loss/distillationloss.py

@@ -24,6 +24,7 @@ from .kldivloss import KLDivLoss
 from .dkdloss import DKDLoss
 from .dist_loss import DISTLoss
 from .multilabelloss import MultiLabelLoss
+from .mgd_loss import MGDLoss


 class DistillationCELoss(CELoss):
@@ -319,3 +320,46 @@ class DistillationDISTLoss(DISTLoss):
            loss = super().forward(out1, out2)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict


class DistillationPairLoss(nn.Layer):
    """
    DistillationPairLoss

    Wraps an arbitrary pairwise feature loss (e.g. MGDLoss) and applies it to
    feature maps extracted from (student, teacher) model pairs.
    """

    def __init__(self,
                 base_loss_name,
                 model_name_pairs=[],
                 s_keys=None,
                 t_keys=None,
                 name="loss",
                 **kwargs):
        super().__init__()
        # build the underlying loss from its class name; the remaining kwargs
        # (e.g. student_channels) are forwarded to its constructor
        self.loss_func = eval(base_loss_name)(**kwargs)
        if not isinstance(s_keys, list):
            s_keys = [s_keys]
        if not isinstance(t_keys, list):
            t_keys = [t_keys]
        self.s_keys = s_keys
        self.t_keys = t_keys
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            # pick the requested feature maps; a key of None means the model
            # output itself is used
            out1 = [out1[k] if k is not None else out1 for k in self.s_keys]
            out2 = [out2[k] if k is not None else out2 for k in self.t_keys]
            for feat_idx, (o1, o2) in enumerate(zip(out1, out2)):
                loss = self.loss_func.forward(o1, o2)
                if isinstance(loss, dict):
                    for k in loss:
                        loss_dict[
                            f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}_{k}"] = loss[k]
                else:
                    loss_dict[
                        f"{self.name}_{idx}_{feat_idx}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict
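A quick sketch of exercising this wrapper on its own, wiring in `MGDLoss` (defined in the new file below) exactly as the config above does; tensors are random stand-ins:

```python
import paddle
from ppcls.loss.distillationloss import DistillationPairLoss

# Illustrative standalone use of the wrapper.
loss_fn = DistillationPairLoss(
    base_loss_name="MGDLoss",
    model_name_pairs=[["Student", "Teacher"]],
    s_keys=["blocks[7]"],
    t_keys=["blocks[15]"],
    name="loss_mgd",
    student_channels=512,
    teacher_channels=512)

predicts = {
    "Student": {"blocks[7]": paddle.rand([8, 512, 7, 7])},
    "Teacher": {"blocks[15]": paddle.rand([8, 512, 7, 7])},
}
loss_dict = loss_fn(predicts, batch=None)  # batch is unused by this loss
print(loss_dict.keys())  # dict_keys(['loss_mgd_0_0_Student_Teacher'])
```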

ppcls/loss/mgd_loss.py (new file, 84 lines)

@@ -0,0 +1,84 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppcls.utils.initializer import kaiming_normal_


class MGDLoss(nn.Layer):
    """Paddle version of `Masked Generative Distillation`

    MGDLoss
    Reference: https://arxiv.org/abs/2205.01529
    Code was heavily based on https://github.com/yzd-v/MGD
    """

    def __init__(
            self,
            student_channels,
            teacher_channels,
            alpha_mgd=1.756,
            lambda_mgd=0.15, ):
        super().__init__()
        self.alpha_mgd = alpha_mgd      # global weight of the distillation loss
        self.lambda_mgd = lambda_mgd    # ratio of randomly masked channels

        # 1x1 conv to project the student features to the teacher's channel
        # count when they differ
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

        # small generation network that reconstructs the teacher features
        # from the masked student features
        self.generation = nn.Sequential(
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, pred_s, pred_t):
        """Forward function.
        Args:
            pred_s(Tensor): Bs*C*H*W, student's feature map
            pred_t(Tensor): Bs*C*H*W, teacher's feature map
        """
        assert pred_s.shape[-2:] == pred_t.shape[-2:]

        if self.align is not None:
            pred_s = self.align(pred_s)

        loss = self.get_dis_loss(pred_s, pred_t) * self.alpha_mgd

        return loss

    def get_dis_loss(self, pred_s, pred_t):
        loss_mse = nn.MSELoss(reduction='mean')
        N, C, _, _ = pred_t.shape

        # random per-(sample, channel) mask: entries below lambda_mgd are
        # zeroed, so roughly lambda_mgd of the channels are masked out
        mat = paddle.rand([N, C, 1, 1])
        mat = paddle.where(mat < self.lambda_mgd, 0, 1).astype("float32")

        masked_fea = paddle.multiply(pred_s, mat)
        new_fea = self.generation(masked_fea)

        dis_loss = loss_mse(new_fea, pred_t)
        return dis_loss
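A minimal direct use of `MGDLoss`, including the channel-alignment path that the config above does not exercise (shapes are illustrative):

```python
import paddle
from ppcls.loss.mgd_loss import MGDLoss

# Student and teacher feature maps must share spatial size; channels may
# differ, in which case the 1x1 align conv projects the student features.
mgd = MGDLoss(student_channels=256, teacher_channels=512)
feat_s = paddle.rand([4, 256, 7, 7])
feat_t = paddle.rand([4, 512, 7, 7])
loss = mgd(feat_s, feat_t)   # scalar tensor, already scaled by alpha_mgd
print(float(loss))
```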

ppcls/utils/initializer.py (new file, 318 lines)

@@ -0,0 +1,318 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
The copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file.
"""

import math
import numpy as np

import paddle
import paddle.nn as nn

__all__ = [
    'uniform_',
    'normal_',
    'constant_',
    'ones_',
    'zeros_',
    'xavier_uniform_',
    'xavier_normal_',
    'kaiming_uniform_',
    'kaiming_normal_',
    'linear_init_',
    'conv_init_',
    'reset_initialized_parameter',
]


def _no_grad_uniform_(tensor, a, b):
    with paddle.no_grad():
        tensor.set_value(
            paddle.uniform(
                shape=tensor.shape, dtype=tensor.dtype, min=a, max=b))
    return tensor


def _no_grad_normal_(tensor, mean=0., std=1.):
    with paddle.no_grad():
        tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape))
    return tensor


def _no_grad_fill_(tensor, value=0.):
    with paddle.no_grad():
        tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
    return tensor


def uniform_(tensor, a, b):
    """
    Modify tensor in place with values drawn from the uniform distribution U(a, b)
    Args:
        tensor (paddle.Tensor): paddle Tensor
        a (float|int): min value.
        b (float|int): max value.
    Return:
        tensor
    """
    return _no_grad_uniform_(tensor, a, b)


def normal_(tensor, mean=0., std=1.):
    """
    Modify tensor in place with values drawn from the normal distribution N(mean, std^2)
    Args:
        tensor (paddle.Tensor): paddle Tensor
        mean (float|int): mean value.
        std (float|int): std value.
    Return:
        tensor
    """
    return _no_grad_normal_(tensor, mean, std)


def constant_(tensor, value=0.):
    """
    Modify tensor in place by filling it with a constant value
    Args:
        tensor (paddle.Tensor): paddle Tensor
        value (float|int): value to fill tensor.
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, value)


def ones_(tensor):
    """
    Modify tensor in place by filling it with ones
    Args:
        tensor (paddle.Tensor): paddle Tensor
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, 1)


def zeros_(tensor):
    """
    Modify tensor in place by filling it with zeros
    Args:
        tensor (paddle.Tensor): paddle Tensor
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, 0)


def _calculate_fan_in_and_fan_out(tensor, reverse=False):
    """
    Calculate (fan_in, fan_out) for tensor

    Args:
        tensor (Tensor): paddle.Tensor
        reverse (bool): tensor data format order, False by default as [fout, fin, ...], e.g. conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] is True

    Return:
        Tuple[fan_in, fan_out]
    """
    if tensor.ndim < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
        )

    if reverse:
        num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1]
    else:
        num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0]

    receptive_field_size = 1
    if tensor.ndim > 2:
        receptive_field_size = np.prod(tensor.shape[2:])

    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def xavier_uniform_(tensor, gain=1., reverse=False):
    """
    Modify tensor in place using the Xavier uniform initialization
    Args:
        tensor (paddle.Tensor): paddle Tensor
        gain (float): scaling factor, 1. by default.
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    k = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -k, k)


def xavier_normal_(tensor, gain=1., reverse=False):
    """
    Modify tensor in place using the Xavier normal initialization
    Args:
        tensor (paddle.Tensor): paddle Tensor
        gain (float): scaling factor, 1. by default.
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    return _no_grad_normal_(tensor, 0, std)


# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
def _calculate_correct_fan(tensor, mode, reverse=False):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(
            mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)

    return fan_in if mode == 'fan_in' else fan_out


def _calculate_gain(nonlinearity, param=None):
    linear_fns = [
        'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
        'conv_transpose2d', 'conv_transpose3d'
    ]
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(
                param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(
                param))
        return math.sqrt(2.0 / (1 + negative_slope**2))
    elif nonlinearity == 'selu':
        return 3.0 / 4
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def kaiming_uniform_(tensor,
                     a=0,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     reverse=False):
    """
    Modify tensor in place using the Kaiming uniform initialization
    Args:
        tensor (paddle.Tensor): paddle Tensor
        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
        nonlinearity (str): nonlinearity method name
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan = _calculate_correct_fan(tensor, mode, reverse)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    k = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -k, k)


def kaiming_normal_(tensor,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    reverse=False):
    """
    Modify tensor in place using the Kaiming normal initialization
    Args:
        tensor (paddle.Tensor): paddle Tensor
        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
        nonlinearity (str): nonlinearity method name
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan = _calculate_correct_fan(tensor, mode, reverse)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return _no_grad_normal_(tensor, 0, std)


def linear_init_(module):
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    uniform_(module.bias, -bound, bound)


def conv_init_(module):
    bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
    uniform_(module.weight, -bound, bound)
    if module.bias is not None:
        uniform_(module.bias, -bound, bound)


def bias_init_with_prob(prior_prob=0.01):
    """initialize conv/fc bias value according to a given probability value."""
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init


@paddle.no_grad()
def reset_initialized_parameter(model, include_self=True):
    """
    Reset the initialized parameters using the following methods for [conv, linear, embedding, bn]

    Args:
        model (paddle.Layer): paddle Layer
        include_self (bool: True): include_self for the Layer.named_sublayers method; indicates whether the model itself is included
    Return:
        None
    """
    for _, m in model.named_sublayers(include_self=include_self):
        if isinstance(m, nn.Conv2D):
            k = float(m._groups) / (m._in_channels * m._kernel_size[0] *
                                    m._kernel_size[1])
            k = math.sqrt(k)
            _no_grad_uniform_(m.weight, -k, k)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_uniform_(m.bias, -k, k)

        elif isinstance(m, nn.Linear):
            k = math.sqrt(1. / m.weight.shape[0])
            _no_grad_uniform_(m.weight, -k, k)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_uniform_(m.bias, -k, k)

        elif isinstance(m, nn.Embedding):
            _no_grad_normal_(m.weight, mean=0., std=1.)

        elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
            _no_grad_fill_(m.weight, 1.)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_fill_(m.bias, 0)
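As a quick sanity check of `kaiming_normal_` (the initializer `MGDLoss` applies to its conv layers), a small illustrative example; the expected std follows from the formulas above:

```python
import paddle.nn as nn
from ppcls.utils.initializer import kaiming_normal_

# Re-initialize a conv weight the same way MGDLoss._init_weights does.
conv = nn.Conv2D(512, 512, kernel_size=3, padding=1)
kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

# fan_out = cout * kh * kw = 512 * 9; gain = sqrt(2) for relu, so the
# sample std should be close to sqrt(2 / 4608) ≈ 0.0208.
print(float(conv.weight.std()))
```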