parent: b27acf6ace
commit: 7595ba6d70
@@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
 from ppcls.arch.slim import prune_model, quantize_model
+from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTransformTeacher
 
-__all__ = ["build_model", "RecModel", "DistillationModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
 
 
 def build_model(config):
@@ -132,3 +133,24 @@ class DistillationModel(nn.Layer):
             else:
                 result_dict[model_name] = self.model_list[idx](x, label)
         return result_dict
+
+
+class AttentionModel(DistillationModel):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None,
+                 **kargs):
+        super().__init__(models, pretrained_list, freeze_params_list, **kargs)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        out = x
+        for idx, model_name in enumerate(self.model_name_list):
+            if label is None:
+                out = self.model_list[idx](out)
+                result_dict.update(out)
+            else:
+                out = self.model_list[idx](out, label)
+                result_dict.update(out)
+        return result_dict
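In contrast to DistillationModel, which feeds the same input x to every submodel, AttentionModel threads each submodel's output into the next one and merges every returned dict into result_dict. A minimal sketch of that data flow, using two hypothetical toy layers (names and shapes are illustrative, not part of this commit):

import paddle
import paddle.nn as nn

# Hypothetical stand-ins for a (backbone, transform) pair.
class ToyBackbone(nn.Layer):
    def forward(self, x):
        feat = x.mean(axis=[2, 3])              # NCHW -> NC
        return {"logits": feat, "feat": feat}

class ToyTransform(nn.Layer):
    def forward(self, inputs):                  # receives the previous model's dict
        return {"query": inputs["feat"] * 2.0}

x = paddle.rand([4, 8, 7, 7])
out, result_dict = x, {}
for m in [ToyBackbone(), ToyTransform()]:       # what AttentionModel.forward does
    out = m(out)
    result_dict.update(out)
print(sorted(result_dict))                      # ['feat', 'logits', 'query']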
@@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer):
         self.quanter = None
 
     def _return_dict_hook(self, layer, input, output):
-        res_dict = {"output": output}
+        res_dict = {"logits": output}
         # 'list' is needed to avoid error raised by popping self.res_dict
         for res_key in list(self.res_dict):
             # clear the res_dict because the forward process may change according to input
@@ -0,0 +1,123 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np


class LinearBNReLU(nn.Layer):
    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))


def unique_shape(s_shapes):
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes


class LinearTransformTeacher(nn.Layer):
    def __init__(self, qk_dim, t_shapes, keys):
        super().__init__()
        self.teacher_keys = keys
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.query_layer = nn.LayerList(
            [LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes])

    def forward(self, t_features_dict):
        g_t = [t_features_dict[key] for key in self.teacher_keys]
        bs = g_t[0].shape[0]
        channel_mean = [f_t.mean(3).mean(2) for f_t in g_t]
        spatial_mean = []
        for i in range(len(g_t)):
            c, h, w = g_t[i].shape[1:]
            spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w]))
        query = paddle.stack(
            [
                query_layer(
                    f_t, relu=False)
                for f_t, query_layer in zip(channel_mean, self.query_layer)
            ],
            axis=1)
        value = [F.normalize(f_s, axis=1) for f_s in spatial_mean]
        return {"query": query, "value": value}


class LinearTransformStudent(nn.Layer):
    def __init__(self, qk_dim, t_shapes, s_shapes, keys):
        super().__init__()
        self.student_keys = keys
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t = len(self.t_shapes)
        self.s = len(self.s_shapes)
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.relu = nn.ReLU()
        self.samplers = nn.LayerList(
            [Sample(t_shape) for t_shape in self.unique_t_shapes])
        self.key_layer = nn.LayerList([
            LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes
        ])
        self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))

    def forward(self, s_features_dict):
        g_s = [s_features_dict[key] for key in self.student_keys]
        bs = g_s[0].shape[0]
        channel_mean = [f_s.mean(3).mean(2) for f_s in g_s]
        spatial_mean = [sampler(g_s, bs) for sampler in self.samplers]

        key = paddle.stack(
            [
                key_layer(f_s)
                for key_layer, f_s in zip(self.key_layer, channel_mean)
            ],
            axis=1).reshape([-1, self.qk_dim])  # Bs x h
        bilinear_key = self.bilinear(
            key, relu=False).reshape([bs, self.s, self.t, self.qk_dim])
        value = [F.normalize(s_m, axis=2) for s_m in spatial_mean]
        return {"bilinear_key": bilinear_key, "value": value}


class Sample(nn.Layer):
    def __init__(self, t_shape):
        super().__init__()
        self.t_N, self.t_C, self.t_H, self.t_W = t_shape
        self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))

    def forward(self, g_s, bs):
        g_s = paddle.stack(
            [
                self.sample(f_s.pow(2).mean(
                    1, keepdim=True)).reshape([bs, self.t_H * self.t_W])
                for f_s in g_s
            ],
            axis=1)
        return g_s
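To see the tensor shapes these transforms trade in, here is a small smoke test with two dummy feature maps per network; qk_dim and all shapes are illustrative, not the values from the config below:

import paddle

t_shapes = [[8, 14, 14], [16, 7, 7]]   # illustrative teacher (C, H, W) shapes
s_shapes = [[4, 14, 14], [8, 7, 7]]    # illustrative student (C, H, W) shapes
t_keys, s_keys = ["t0", "t1"], ["s0", "s1"]

teacher_tf = LinearTransformTeacher(qk_dim=32, t_shapes=t_shapes, keys=t_keys)
student_tf = LinearTransformStudent(
    qk_dim=32, t_shapes=t_shapes, s_shapes=s_shapes, keys=s_keys)

bs = 2
t_feats = {k: paddle.rand([bs] + s) for k, s in zip(t_keys, t_shapes)}
s_feats = {k: paddle.rand([bs] + s) for k, s in zip(s_keys, s_shapes)}

t_out = teacher_tf(t_feats)
s_out = student_tf(s_feats)
print(t_out["query"].shape)          # [2, 2, 32]    = bs x num_t x qk_dim
print(s_out["bilinear_key"].shape)   # [2, 2, 2, 32] = bs x num_s x num_t x qk_dim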
@@ -0,0 +1,202 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  models:
    - Teacher:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - True
          - False
        models:
          - ResNet34:
              name: ResNet34
              pretrained: True
              return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
                                        "blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
                                        "blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
          - LinearTransformTeacher:
              name: LinearTransformTeacher
              qk_dim: 128
              keys: *t_keys
              t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
                                   [128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
                                   [256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
                                   [256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]

    - Student:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - False
          - False
        models:
          - ResNet18:
              name: ResNet18
              pretrained: False
              return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
          - LinearTransformStudent:
              name: LinearTransformStudent
              qk_dim: 128
              keys: *s_keys
              s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
                                   [256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
              t_shapes: *t_shapes

  infer_model_name: "Student"


# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
        key: logits
    - DistillationKLDivLoss:
        weight: 0.9
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 4
        key: logits
    - AFDLoss:
        weight: 50.0
        model_name_pair: ["Student", "Teacher"]
        student_keys: ["bilinear_key", "value"]
        teacher_keys: ["query", "value"]
        s_shapes: *s_shapes
        t_shapes: *t_shapes
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]


Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.1
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: DistillationPostProcess
    func: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
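Note how the YAML anchors (&t_keys, &t_shapes, &s_keys, &s_shapes) let the transform branches and AFDLoss reuse the exact lists declared on the backbones. A minimal sketch of loading such a file and building the architecture; the local path is illustrative, and build_model is assumed to take the Arch sub-config as in the hunk above:

import yaml
from ppcls.arch import build_model

with open("resnet34_distill_resnet18_afd.yaml") as f:  # illustrative path
    config = yaml.safe_load(f)

# Anchors are resolved at parse time, so keys/t_shapes under
# LinearTransformTeacher are the same lists as ResNet34's return_patterns.
assert config["Arch"]["models"][0]["Teacher"]["models"][1][
    "LinearTransformTeacher"]["keys"][0] == "blocks[0]"

model = build_model(config["Arch"])  # DistillationModel wrapping two AttentionModels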
@@ -46,6 +46,8 @@ class Topk(object):
         return class_id_map
 
     def __call__(self, x, file_names=None, multilabel=False):
+        if isinstance(x, dict):
+            x = x['logits']
         assert isinstance(x, paddle.Tensor)
         if file_names is not None:
             assert x.shape[0] == len(file_names)
@@ -459,5 +459,7 @@ class ExportModel(TheseusLayer):
         if self.infer_output_key is not None:
             x = x[self.infer_output_key]
         if self.out_act is not None:
+            if isinstance(x, dict):
+                x = x["logits"]
             x = self.out_act(x)
         return x
@@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0):
             if isinstance(out, dict):
                 if "Student" in out:
                     out = out["Student"]
+                    if isinstance(out, dict):
+                        out = out["logits"]
                 elif "logits" in out:
                     out = out["logits"]
                 else:
@@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
 from .distillationloss import DistillationDistanceLoss
 from .distillationloss import DistillationRKDLoss
+from .distillationloss import DistillationKLDivLoss
 from .multilabelloss import MultiLabelLoss
+from .afdloss import AFDLoss
 
 from .deephashloss import DSHSDLoss, LCDSHLoss
 
@@ -0,0 +1,132 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings
warnings.filterwarnings('ignore')


class LinearBNReLU(nn.Layer):
    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))


def unique_shape(s_shapes):
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes


class AFDLoss(nn.Layer):
    """
    AFDLoss
    https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf
    https://github.com/clovaai/attention-feature-distillation
    """

    def __init__(self,
                 model_name_pair=["Student", "Teacher"],
                 student_keys=["bilinear_key", "value"],
                 teacher_keys=["query", "value"],
                 s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160],
                           [512, 2, 160]],
                 t_shapes=[[640, 48], [320, 96], [160, 192]],
                 qk_dim=128,
                 name="loss_afd"):
        super().__init__()
        assert isinstance(model_name_pair, list)
        self.model_name_pair = model_name_pair
        self.student_keys = student_keys
        self.teacher_keys = teacher_keys
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes,
                                   self.n_t, self.unique_t_shapes)
        self.name = name

    def forward(self, predicts, batch):
        s_features_dict = predicts[self.model_name_pair[0]]
        t_features_dict = predicts[self.model_name_pair[1]]

        g_s = [s_features_dict[key] for key in self.student_keys]
        g_t = [t_features_dict[key] for key in self.teacher_keys]

        loss = self.attention(g_s, g_t)
        sum_loss = sum(loss)

        loss_dict = dict()
        loss_dict[self.name] = sum_loss

        return loss_dict


class Attention(nn.Layer):
    def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes):
        super().__init__()
        self.qk_dim = qk_dim
        self.n_t = n_t
        # self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
        # self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)

        self.p_t = self.create_parameter(
            shape=[len(t_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())
        self.p_s = self.create_parameter(
            shape=[len(s_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())

    def forward(self, g_s, g_t):
        bilinear_key, h_hat_s_all = g_s
        query, h_t_all = g_t

        p_logit = paddle.matmul(self.p_t, self.p_s.t())

        logit = paddle.add(
            paddle.einsum('bstq,btq->bts', bilinear_key, query),
            p_logit) / np.sqrt(self.qk_dim)
        atts = F.softmax(logit, axis=2)  # b x t x s

        loss = []

        for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)):
            h_hat_s = h_hat_s_all[n]
            diff = self.cal_diff(h_hat_s, h_t, atts[:, i])
            loss.append(diff)
        return loss

    def cal_diff(self, v_s, v_t, att):
        diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2)
        diff = paddle.multiply(diff, att).sum(1).mean()
        return diff
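AFDLoss wires the two transforms together: the student's bilinear_key is matched against the teacher's query to produce a b x t x s attention map, which weights how strongly each student value imitates each teacher value. A runnable sketch with random stand-ins for the transform outputs (all shapes illustrative, not from the config):

import paddle

bs, num_s, num_t, qk_dim = 2, 2, 2, 32
hws = [14 * 14, 7 * 7]                   # H*W of the two (unique) teacher maps

predicts = {
    "Student": {
        "bilinear_key": paddle.rand([bs, num_s, num_t, qk_dim]),
        # one bs x num_s x (H*W) tensor per unique teacher shape
        "value": [paddle.rand([bs, num_s, hw]) for hw in hws],
    },
    "Teacher": {
        "query": paddle.rand([bs, num_t, qk_dim]),
        # one bs x (H*W) tensor per teacher feature map
        "value": [paddle.rand([bs, hw]) for hw in hws],
    },
}

loss_fn = AFDLoss(
    s_shapes=[[4, 14, 14], [8, 7, 7]],   # illustrative
    t_shapes=[[8, 14, 14], [16, 7, 7]],  # illustrative
    qk_dim=qk_dim)
print(loss_fn(predicts, batch=None))     # {'loss_afd': Tensor(...)}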
@@ -14,11 +14,13 @@
 
 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F
 
 from .celoss import CELoss
 from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
+from .kldivloss import KLDivLoss
 
 
 class DistillationCELoss(CELoss):
@@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer):
                 student_out, teacher_out)
 
         return loss_dict
+
+
+class DistillationKLDivLoss(KLDivLoss):
+    """
+    DistillationKLDivLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 temperature=4,
+                 key=None,
+                 name="loss_kl"):
+        super().__init__(temperature=temperature)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            for key in loss:
+                loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
+        return loss_dict
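With model_name_pairs=[["Student", "Teacher"]] and key="logits", as in the config above, the wrapper indexes each model's output dict and defers to KLDivLoss (added below); the returned key is suffixed with the pair names. A quick sketch with random logits:

import paddle

predicts = {
    "Student": {"logits": paddle.rand([4, 10])},
    "Teacher": {"logits": paddle.rand([4, 10])},
}
loss_fn = DistillationKLDivLoss(
    model_name_pairs=[["Student", "Teacher"]], temperature=4, key="logits")
print(loss_fn(predicts, batch=None))
# {'loss_kldiv_Student_Teacher': Tensor(...)}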
@@ -0,0 +1,33 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class KLDivLoss(nn.Layer):
    """
    Distilling the Knowledge in a Neural Network
    """

    def __init__(self, temperature=4):
        super(KLDivLoss, self).__init__()
        self.T = temperature

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, axis=1)
        p_t = F.softmax(y_t / self.T, axis=1)
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
        return {"loss_kldiv": loss}
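The temperature T softens both distributions before the KL term, and the T**2 factor keeps gradient magnitudes comparable across temperatures (Hinton et al., "Distilling the Knowledge in a Neural Network"). A quick sanity check with made-up logits:

import paddle

kd = KLDivLoss(temperature=4)
y_s = paddle.to_tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])  # student logits
y_t = paddle.to_tensor([[3.0, 0.0, -2.0], [0.0, 0.5, 0.5]])  # teacher logits
print(kd(y_s, y_t)["loss_kldiv"])   # non-negative; 0 only if softened dists match
print(float(kd(y_t, y_t)["loss_kldiv"]))  # ~0.0 for identical logits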