commit 7595ba6d70 (parent b27acf6ace)
@@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
 from ppcls.arch.slim import prune_model, quantize_model
+from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTransformTeacher

-__all__ = ["build_model", "RecModel", "DistillationModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]


 def build_model(config):
@@ -132,3 +133,24 @@ class DistillationModel(nn.Layer):
             else:
                 result_dict[model_name] = self.model_list[idx](x, label)
         return result_dict
+
+
+class AttentionModel(DistillationModel):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None,
+                 **kargs):
+        super().__init__(models, pretrained_list, freeze_params_list, **kargs)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        out = x
+        for idx, model_name in enumerate(self.model_name_list):
+            if label is None:
+                out = self.model_list[idx](out)
+                result_dict.update(out)
+            else:
+                out = self.model_list[idx](out, label)
+                result_dict.update(out)
+        return result_dict
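
Unlike DistillationModel, which feeds the same input x to every submodel, the new AttentionModel pipes each submodel's output into the next and merges every output dict into result_dict. Each submodel must therefore return a dict: the backbone does so via the TheseusLayer hook, and the linear transforms added below do so explicitly. A toy sketch of that contract, with hypothetical stand-in layers:

    import paddle
    import paddle.nn as nn

    class ToyBackbone(nn.Layer):      # stand-in for a hooked backbone
        def forward(self, x):
            return {"logits": x.mean(1), "feat": x}

    class ToyTransform(nn.Layer):     # stand-in for LinearTransformTeacher
        def forward(self, d):
            return {"query": d["feat"] * 2}

    # mirrors AttentionModel.forward with label=None
    out, result_dict = paddle.randn([2, 4]), {}
    for m in [ToyBackbone(), ToyTransform()]:
        out = m(out)
        result_dict.update(out)
    print(sorted(result_dict.keys()))  # ['feat', 'logits', 'query']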
@@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer):
         self.quanter = None

     def _return_dict_hook(self, layer, input, output):
-        res_dict = {"output": output}
+        res_dict = {"logits": output}
         # 'list' is needed to avoid error raised by popping self.res_dict
         for res_key in list(self.res_dict):
             # clear the res_dict because the forward process may change according to input
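
With this rename, a TheseusLayer built with return_patterns returns a dict whose final prediction sits under the key "logits", next to the hooked intermediate features; the Topk, ExportModel, and evaluation changes below all unwrap that key. A minimal sketch (the ResNet34 import path and the return_patterns keyword being accepted directly by the backbone are assumptions about the PaddleClas API):

    import paddle
    from ppcls.arch.backbone import ResNet34  # assumed import path

    # assumed: the backbone registers _return_dict_hook when return_patterns is set
    model = ResNet34(return_patterns=["blocks[0]"])
    out = model(paddle.rand([1, 3, 224, 224]))
    print(sorted(out.keys()))  # expected: ['blocks[0]', 'logits']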
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddle
+import numpy as np
+
+
+class LinearBNReLU(nn.Layer):
+    def __init__(self, nin, nout):
+        super().__init__()
+        self.linear = nn.Linear(nin, nout)
+        self.bn = nn.BatchNorm1D(nout)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, relu=True):
+        if relu:
+            return self.relu(self.bn(self.linear(x)))
+        return self.bn(self.linear(x))
+
+
+def unique_shape(s_shapes):
+    n_s = []
+    unique_shapes = []
+    n = -1
+    for s_shape in s_shapes:
+        if s_shape not in unique_shapes:
+            unique_shapes.append(s_shape)
+            n += 1
+        n_s.append(n)
+    return n_s, unique_shapes
+
+
+class LinearTransformTeacher(nn.Layer):
+    def __init__(self, qk_dim, t_shapes, keys):
+        super().__init__()
+        self.teacher_keys = keys
+        self.t_shapes = [[1] + t_i for t_i in t_shapes]
+        self.query_layer = nn.LayerList(
+            [LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes])
+
+    def forward(self, t_features_dict):
+        g_t = [t_features_dict[key] for key in self.teacher_keys]
+        bs = g_t[0].shape[0]
+        channel_mean = [f_t.mean(3).mean(2) for f_t in g_t]
+        spatial_mean = []
+        for i in range(len(g_t)):
+            c, h, w = g_t[i].shape[1:]
+            spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w]))
+        query = paddle.stack(
+            [
+                query_layer(
+                    f_t, relu=False)
+                for f_t, query_layer in zip(channel_mean, self.query_layer)
+            ],
+            axis=1)
+        value = [F.normalize(f_s, axis=1) for f_s in spatial_mean]
+        return {"query": query, "value": value}
+
+
+class LinearTransformStudent(nn.Layer):
+    def __init__(self, qk_dim, t_shapes, s_shapes, keys):
+        super().__init__()
+        self.student_keys = keys
+        self.t_shapes = [[1] + t_i for t_i in t_shapes]
+        self.s_shapes = [[1] + s_i for s_i in s_shapes]
+        self.t = len(self.t_shapes)
+        self.s = len(self.s_shapes)
+        self.qk_dim = qk_dim
+        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
+        self.relu = nn.ReLU()
+        self.samplers = nn.LayerList(
+            [Sample(t_shape) for t_shape in self.unique_t_shapes])
+        self.key_layer = nn.LayerList([
+            LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes
+        ])
+        self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))
+
+    def forward(self, s_features_dict):
+        g_s = [s_features_dict[key] for key in self.student_keys]
+        bs = g_s[0].shape[0]
+        channel_mean = [f_s.mean(3).mean(2) for f_s in g_s]
+        spatial_mean = [sampler(g_s, bs) for sampler in self.samplers]
+
+        key = paddle.stack(
+            [
+                key_layer(f_s)
+                for key_layer, f_s in zip(self.key_layer, channel_mean)
+            ],
+            axis=1).reshape([-1, self.qk_dim])  # Bs x h
+        bilinear_key = self.bilinear(
+            key, relu=False).reshape([bs, self.s, self.t, self.qk_dim])
+        value = [F.normalize(s_m, axis=2) for s_m in spatial_mean]
+        return {"bilinear_key": bilinear_key, "value": value}
+
+
+class Sample(nn.Layer):
+    def __init__(self, t_shape):
+        super().__init__()
+        self.t_N, self.t_C, self.t_H, self.t_W = t_shape
+        self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))
+
+    def forward(self, g_s, bs):
+        g_s = paddle.stack(
+            [
+                self.sample(f_s.pow(2).mean(
+                    1, keepdim=True)).reshape([bs, self.t_H * self.t_W])
+                for f_s in g_s
+            ],
+            axis=1)
+        return g_s
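
A minimal sketch of how these transforms consume hooked features. The feature-map sizes mirror the first two ResNet34 blocks in the config that follows; the input tensors are random dummies:

    import paddle
    from ppcls.arch.distill.afd_attention import LinearTransformTeacher

    feats = {"blocks[0]": paddle.randn([2, 64, 56, 56]),
             "blocks[1]": paddle.randn([2, 128, 28, 28])}
    teacher_t = LinearTransformTeacher(qk_dim=128,
                                       t_shapes=[[64, 56, 56], [128, 28, 28]],
                                       keys=["blocks[0]", "blocks[1]"])
    out = teacher_t(feats)
    print(out["query"].shape)     # [2, 2, 128]: one query vector per hooked block
    print(out["value"][0].shape)  # [2, 3136]: normalized squared spatial mean (56*56)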
@@ -0,0 +1,202 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+  name: "DistillationModel"
+  # if not null, its length should be the same as that of models
+  pretrained_list:
+  # if not null, its length should be the same as that of models
+  freeze_params_list:
+  models:
+    - Teacher:
+        name: AttentionModel
+        pretrained_list:
+        freeze_params_list:
+          - True
+          - False
+        models:
+          - ResNet34:
+              name: ResNet34
+              pretrained: True
+              return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
+                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
+                                        "blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
+                                        "blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
+          - LinearTransformTeacher:
+              name: LinearTransformTeacher
+              qk_dim: 128
+              keys: *t_keys
+              t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
+                                   [128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
+                                   [256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
+                                   [256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
+
+    - Student:
+        name: AttentionModel
+        pretrained_list:
+        freeze_params_list:
+          - False
+          - False
+        models:
+          - ResNet18:
+              name: ResNet18
+              pretrained: False
+              return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
+                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
+          - LinearTransformStudent:
+              name: LinearTransformStudent
+              qk_dim: 128
+              keys: *s_keys
+              s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
+                                   [256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
+              t_shapes: *t_shapes
+
+  infer_model_name: "Student"
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+        key: logits
+    - DistillationKLDivLoss:
+        weight: 0.9
+        model_name_pairs: [["Student", "Teacher"]]
+        temperature: 4
+        key: logits
+    - AFDLoss:
+        weight: 50.0
+        model_name_pair: ["Student", "Teacher"]
+        student_keys: ["bilinear_key", "value"]
+        teacher_keys: ["query", "value"]
+        s_shapes: *s_shapes
+        t_shapes: *t_shapes
+  Eval:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  weight_decay: 1e-4
+  lr:
+    name: MultiStepDecay
+    learning_rate: 0.1
+    milestones: [30, 60, 90]
+    step_each_epoch: 1
+    gamma: 0.1
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+            interpolation: bicubic
+            backend: pil
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+            interpolation: bicubic
+            backend: pil
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+        interpolation: bicubic
+        backend: pil
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: DistillationPostProcess
+    func: Topk
+    topk: 5
+    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+  Train:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
+  Eval:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
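
The teacher list above hooks 16 blocks, but only four distinct feature shapes appear in &t_shapes; unique_shape in afd_attention.py collapses them so that LinearTransformStudent builds one AdaptiveAvgPool2D sampler per distinct shape rather than per block. A quick check of that bookkeeping (import path taken from the module added above):

    from ppcls.arch.distill.afd_attention import unique_shape

    t_shapes = ([[64, 56, 56]] * 3 + [[128, 28, 28]] * 4 +
                [[256, 14, 14]] * 6 + [[512, 7, 7]] * 3)  # the 16 entries of &t_shapes
    n_t, uniques = unique_shape(t_shapes)
    print(n_t)           # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3]
    print(len(uniques))  # 4 -> four Sample layers in LinearTransformStudent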
@@ -46,6 +46,8 @@ class Topk(object):
         return class_id_map

     def __call__(self, x, file_names=None, multilabel=False):
+        if isinstance(x, dict):
+            x = x['logits']
         assert isinstance(x, paddle.Tensor)
         if file_names is not None:
             assert x.shape[0] == len(file_names)
@@ -459,5 +459,7 @@ class ExportModel(TheseusLayer):
         if self.infer_output_key is not None:
             x = x[self.infer_output_key]
         if self.out_act is not None:
+            if isinstance(x, dict):
+                x = x["logits"]
             x = self.out_act(x)
         return x
@@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0):
         if isinstance(out, dict):
             if "Student" in out:
                 out = out["Student"]
+                if isinstance(out, dict):
+                    out = out["logits"]
             elif "logits" in out:
                 out = out["logits"]
             else:
@@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
 from .distillationloss import DistillationDistanceLoss
 from .distillationloss import DistillationRKDLoss
+from .distillationloss import DistillationKLDivLoss
 from .multilabelloss import MultiLabelLoss
+from .afdloss import AFDLoss

 from .deephashloss import DSHSDLoss, LCDSHLoss

@@ -0,0 +1,132 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddle
+import numpy as np
+import matplotlib.pyplot as plt
+import cv2
+import warnings
+warnings.filterwarnings('ignore')
+
+
+class LinearBNReLU(nn.Layer):
+    def __init__(self, nin, nout):
+        super().__init__()
+        self.linear = nn.Linear(nin, nout)
+        self.bn = nn.BatchNorm1D(nout)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, relu=True):
+        if relu:
+            return self.relu(self.bn(self.linear(x)))
+        return self.bn(self.linear(x))
+
+
+def unique_shape(s_shapes):
+    n_s = []
+    unique_shapes = []
+    n = -1
+    for s_shape in s_shapes:
+        if s_shape not in unique_shapes:
+            unique_shapes.append(s_shape)
+            n += 1
+        n_s.append(n)
+    return n_s, unique_shapes
+
+
+class AFDLoss(nn.Layer):
+    """
+    AFDLoss
+    https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf
+    https://github.com/clovaai/attention-feature-distillation
+    """
+
+    def __init__(self,
+                 model_name_pair=["Student", "Teacher"],
+                 student_keys=["bilinear_key", "value"],
+                 teacher_keys=["query", "value"],
+                 s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160],
+                           [512, 2, 160]],
+                 t_shapes=[[640, 48], [320, 96], [160, 192]],
+                 qk_dim=128,
+                 name="loss_afd"):
+        super().__init__()
+        assert isinstance(model_name_pair, list)
+        self.model_name_pair = model_name_pair
+        self.student_keys = student_keys
+        self.teacher_keys = teacher_keys
+        self.s_shapes = [[1] + s_i for s_i in s_shapes]
+        self.t_shapes = [[1] + t_i for t_i in t_shapes]
+        self.qk_dim = qk_dim
+        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
+        self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes,
+                                   self.n_t, self.unique_t_shapes)
+        self.name = name
+
+    def forward(self, predicts, batch):
+        s_features_dict = predicts[self.model_name_pair[0]]
+        t_features_dict = predicts[self.model_name_pair[1]]
+
+        g_s = [s_features_dict[key] for key in self.student_keys]
+        g_t = [t_features_dict[key] for key in self.teacher_keys]
+
+        loss = self.attention(g_s, g_t)
+        sum_loss = sum(loss)
+
+        loss_dict = dict()
+        loss_dict[self.name] = sum_loss
+
+        return loss_dict
+
+
+class Attention(nn.Layer):
+    def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes):
+        super().__init__()
+        self.qk_dim = qk_dim
+        self.n_t = n_t
+        # self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
+        # self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)
+
+        self.p_t = self.create_parameter(
+            shape=[len(t_shapes), qk_dim],
+            default_initializer=nn.initializer.XavierNormal())
+        self.p_s = self.create_parameter(
+            shape=[len(s_shapes), qk_dim],
+            default_initializer=nn.initializer.XavierNormal())
+
+    def forward(self, g_s, g_t):
+        bilinear_key, h_hat_s_all = g_s
+        query, h_t_all = g_t
+
+        p_logit = paddle.matmul(self.p_t, self.p_s.t())
+
+        logit = paddle.add(
+            paddle.einsum('bstq,btq->bts', bilinear_key, query),
+            p_logit) / np.sqrt(self.qk_dim)
+        atts = F.softmax(logit, axis=2)  # b x t x s
+
+        loss = []
+
+        for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)):
+            h_hat_s = h_hat_s_all[n]
+            diff = self.cal_diff(h_hat_s, h_t, atts[:, i])
+            loss.append(diff)
+        return loss
+
+    def cal_diff(self, v_s, v_t, att):
+        diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2)
+        diff = paddle.multiply(diff, att).sum(1).mean()
+        return diff
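
A smoke test for Attention with random tensors, shaped the way LinearTransformStudent and LinearTransformTeacher emit them. The toy sizes, the single unique teacher shape, and the n_t mapping are chosen for brevity; the import path follows the "from .afdloss import AFDLoss" line in ppcls/loss/__init__.py:

    import paddle
    from ppcls.loss.afdloss import Attention

    bs, s, t, qk, hw = 2, 8, 16, 128, 49        # batch, student/teacher layers, qk_dim, 7*7
    attn = Attention(qk_dim=qk,
                     t_shapes=[[1, 512, 7, 7]] * t,
                     s_shapes=[[1, 512, 7, 7]] * s,
                     n_t=[0] * t,               # every teacher layer maps to one unique shape
                     unique_t_shapes=[[1, 512, 7, 7]])
    g_s = [paddle.randn([bs, s, t, qk]),        # bilinear_key from LinearTransformStudent
           [paddle.randn([bs, s, hw])]]         # one student "value" per unique teacher shape
    g_t = [paddle.randn([bs, t, qk]),           # query from LinearTransformTeacher
           [paddle.randn([bs, hw]) for _ in range(t)]]
    losses = attn(g_s, g_t)                     # one scalar per teacher layer
    print(len(losses), float(sum(losses)))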
@@ -14,11 +14,13 @@

 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F

 from .celoss import CELoss
 from .dmlloss import DMLLoss
 from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
+from .kldivloss import KLDivLoss


 class DistillationCELoss(CELoss):
@@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer):
                 student_out, teacher_out)

         return loss_dict
+
+
+class DistillationKLDivLoss(KLDivLoss):
+    """
+    DistillationKLDivLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 temperature=4,
+                 key=None,
+                 name="loss_kl"):
+        super().__init__(temperature=temperature)
+        assert isinstance(model_name_pairs, list)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            for key in loss:
+                loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
+        return loss_dict
@@ -0,0 +1,33 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class KLDivLoss(nn.Layer):
+    """
+    Distilling the Knowledge in a Neural Network
+    """
+
+    def __init__(self, temperature=4):
+        super(KLDivLoss, self).__init__()
+        self.T = temperature
+
+    def forward(self, y_s, y_t):
+        p_s = F.log_softmax(y_s / self.T, axis=1)
+        p_t = F.softmax(y_t / self.T, axis=1)
+        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
+        return {"loss_kldiv": loss}
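
A quick usage sketch with random logits. The T squared factor follows Hinton et al.: gradients through the softened targets scale roughly as 1/T^2, so multiplying by T^2 keeps the distillation term on the same footing as the hard-label loss when the temperature changes.

    import paddle
    from ppcls.loss.kldivloss import KLDivLoss

    loss_fn = KLDivLoss(temperature=4)
    y_s = paddle.randn([8, 1000])   # student logits for a batch of 8
    y_t = paddle.randn([8, 1000])   # teacher logits (detached in a real pipeline)
    print(loss_fn(y_s, y_t)["loss_kldiv"])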