# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings

warnings.filterwarnings('ignore')

class LinearBNReLU(nn.Layer):
    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))
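
# Usage sketch (illustrative only; the sizes below are arbitrary and not taken
# from any PaddleOCR config):
#   proj = LinearBNReLU(nin=512, nout=128)
#   feat = paddle.rand([4, 512])        # [batch, nin]
#   out = proj(feat)                    # Linear -> BatchNorm1D -> ReLU, shape [4, 128]
#   out_raw = proj(feat, relu=False)    # same projection, ReLU skipped
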
def unique_shape(s_shapes):
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes
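
# Behaviour sketch (illustrative): for each shape, n_s records the running index
# of the most recently added distinct shape, e.g.
#   unique_shape([[1, 640, 48], [1, 640, 48], [1, 320, 96]])
#   -> ([0, 0, 1], [[1, 640, 48], [1, 320, 96]])
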
class AFDLoss(nn.Layer):
    """
    AFDLoss
    https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf
    https://github.com/clovaai/attention-feature-distillation
    """

    def __init__(self,
                 model_name_pair=["Student", "Teacher"],
                 student_keys=["bilinear_key", "value"],
                 teacher_keys=["query", "value"],
                 s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160],
                           [512, 2, 160]],
                 t_shapes=[[640, 48], [320, 96], [160, 192]],
                 qk_dim=128,
                 name="loss_afd"):
        super().__init__()
        assert isinstance(model_name_pair, list)
        self.model_name_pair = model_name_pair
        self.student_keys = student_keys
        self.teacher_keys = teacher_keys
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes,
                                   self.n_t, self.unique_t_shapes)
        self.name = name

    def forward(self, predicts, batch):
        s_features_dict = predicts[self.model_name_pair[0]]
        t_features_dict = predicts[self.model_name_pair[1]]

        g_s = [s_features_dict[key] for key in self.student_keys]
        g_t = [t_features_dict[key] for key in self.teacher_keys]

        loss = self.attention(g_s, g_t)
        sum_loss = sum(loss)

        loss_dict = dict()
        loss_dict[self.name] = sum_loss

        return loss_dict
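
# Interface sketch (assumption, inferred from the defaults and forward() above):
# `predicts` is a dict keyed by the names in `model_name_pair`, and each entry
# provides the features named by `student_keys` / `teacher_keys`, e.g.
#   predicts = {
#       "Student": {"bilinear_key": <features>, "value": <features>},
#       "Teacher": {"query": <features>, "value": <features>},
#   }
#   loss_dict = afd_loss(predicts, batch)   # {"loss_afd": summed attention loss}
# `batch` is unused in this loss.
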
class Attention(nn.Layer):
    def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes):
        super().__init__()
        self.qk_dim = qk_dim
        self.n_t = n_t
        # self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
        # self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)

        # Learnable per-layer embeddings; their product adds a layer-pair bias
        # to the attention logits in forward().
        self.p_t = self.create_parameter(
            shape=[len(t_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())
        self.p_s = self.create_parameter(
            shape=[len(s_shapes), qk_dim],
            default_initializer=nn.initializer.XavierNormal())

    def forward(self, g_s, g_t):
        bilinear_key, h_hat_s_all = g_s
        query, h_t_all = g_t

        # [num_teacher_layers, num_student_layers] bias from the learnable embeddings.
        p_logit = paddle.matmul(self.p_t, self.p_s.t())

        # Scaled dot-product logits between student keys and teacher queries.
        logit = paddle.add(
            paddle.einsum('bstq,btq->bts', bilinear_key, query),
            p_logit) / np.sqrt(self.qk_dim)
        atts = F.softmax(logit, axis=2)  # b x t x s

        loss = []

        for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)):
            h_hat_s = h_hat_s_all[n]
            diff = self.cal_diff(h_hat_s, h_t, atts[:, i])
            loss.append(diff)
        return loss

    def cal_diff(self, v_s, v_t, att):
        # Per-student-layer mean-squared difference to the teacher feature,
        # weighted by attention, summed over student layers and averaged over
        # the batch.
        diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2)
        diff = paddle.multiply(diff, att).sum(1).mean()
        return diff
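
# Shape notes (assumption, inferred from the einsum pattern and cal_diff above):
#   bilinear_key : [b, s, t, qk_dim]  student-side keys, one per (student, teacher) layer pair
#   query        : [b, t, qk_dim]     teacher-side queries
#   atts         : [b, t, s]          softmax over student layers for every teacher layer
# where b is the batch size, s the number of student layers and t the number of
# teacher layers.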