# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class LinearBNReLU(nn.Layer):
    """Linear projection followed by 1-D batch norm and an optional ReLU."""

    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        # The final ReLU is skipped when producing attention queries/keys.
        if relu:
            return self.relu(self.bn(self.linear(x)))
        return self.bn(self.linear(x))


def unique_shape(s_shapes):
    # Map each shape to the index of its first occurrence so that layers
    # built for identical shapes can be shared.
    n_s = []
    unique_shapes = []
    n = -1
    for s_shape in s_shapes:
        if s_shape not in unique_shapes:
            unique_shapes.append(s_shape)
            n += 1
        n_s.append(n)
    return n_s, unique_shapes
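
# Illustration (hypothetical shapes): with three teacher stages of shape
# [1, 64, 56, 56], [1, 64, 56, 56] and [1, 128, 28, 28], unique_shape
# returns ([0, 0, 1], [[1, 64, 56, 56], [1, 128, 28, 28]]): duplicate
# shapes share an index, so the Sample layers built below can be reused.

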
class LinearTransformTeacher(nn.Layer):
    """Turns teacher feature maps into attention queries and spatial values."""

    def __init__(self, qk_dim, t_shapes, keys):
        super().__init__()
        self.teacher_keys = keys
        # Prepend a unit batch dimension: [C, H, W] -> [1, C, H, W].
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        # One channel-to-qk_dim projection per teacher stage.
        self.query_layer = nn.LayerList(
            [LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes])

    def forward(self, t_features_dict):
        g_t = [t_features_dict[key] for key in self.teacher_keys]
        bs = g_t[0].shape[0]
        # Average over the spatial dims: [bs, C, H, W] -> [bs, C].
        channel_mean = [f_t.mean(3).mean(2) for f_t in g_t]
        # Squared activations averaged over channels, flattened spatially:
        # [bs, C, H, W] -> [bs, H * W].
        spatial_mean = []
        for i in range(len(g_t)):
            c, h, w = g_t[i].shape[1:]
            spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w]))
        query = paddle.stack(
            [
                query_layer(
                    f_t, relu=False)
                for f_t, query_layer in zip(channel_mean, self.query_layer)
            ],
            axis=1)  # [bs, t, qk_dim]
        value = [F.normalize(f_s, axis=1) for f_s in spatial_mean]
        return {"query": query, "value": value}


class LinearTransformStudent(nn.Layer):
    """Turns student feature maps into attention keys and spatial values."""

    def __init__(self, qk_dim, t_shapes, s_shapes, keys):
        super().__init__()
        self.student_keys = keys
        self.t_shapes = [[1] + t_i for t_i in t_shapes]
        self.s_shapes = [[1] + s_i for s_i in s_shapes]
        self.t = len(self.t_shapes)
        self.s = len(self.s_shapes)
        self.qk_dim = qk_dim
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.relu = nn.ReLU()
        # One pooling layer per distinct teacher resolution, shared across
        # teacher stages whose shapes coincide.
        self.samplers = nn.LayerList(
            [Sample(t_shape) for t_shape in self.unique_t_shapes])
        self.key_layer = nn.LayerList([
            LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes
        ])
        # Expands each student key into one key per teacher stage.
        self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))

    def forward(self, s_features_dict):
        g_s = [s_features_dict[key] for key in self.student_keys]
        bs = g_s[0].shape[0]
        # [bs, C, H, W] -> [bs, C]
        channel_mean = [f_s.mean(3).mean(2) for f_s in g_s]
        # Each sampler pools every student map to one teacher resolution,
        # yielding a list of [bs, s, t_H * t_W] tensors.
        spatial_mean = [sampler(g_s, bs) for sampler in self.samplers]

        key = paddle.stack(
            [
                key_layer(f_s)
                for key_layer, f_s in zip(self.key_layer, channel_mean)
            ],
            axis=1).reshape([-1, self.qk_dim])  # [bs * s, qk_dim]
        bilinear_key = self.bilinear(
            key, relu=False).reshape([bs, self.s, self.t, self.qk_dim])
        value = [F.normalize(s_m, axis=2) for s_m in spatial_mean]
        return {"bilinear_key": bilinear_key, "value": value}


class Sample(nn.Layer):
    """Pools squared, channel-averaged student maps to a teacher's spatial size."""

    def __init__(self, t_shape):
        super().__init__()
        self.t_N, self.t_C, self.t_H, self.t_W = t_shape
        self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))

    def forward(self, g_s, bs):
        # For each student map: square, average over channels, pool to the
        # teacher's (t_H, t_W), then flatten to [bs, t_H * t_W].
        g_s = paddle.stack(
            [
                self.sample(f_s.pow(2).mean(
                    1, keepdim=True)).reshape([bs, self.t_H * self.t_W])
                for f_s in g_s
            ],
            axis=1)  # [bs, s, t_H * t_W]
        return g_s
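

# Minimal smoke test wiring the two transforms together. All keys, stage
# shapes, and qk_dim below are illustrative assumptions, not values taken
# from any training config.
if __name__ == "__main__":
    bs, qk_dim = 4, 128
    t_shapes = [[256, 14, 14], [512, 7, 7]]  # teacher stages as [C, H, W]
    s_shapes = [[128, 14, 14], [256, 7, 7]]  # student stages as [C, H, W]
    t_keys = ["t_stage3", "t_stage4"]  # hypothetical feature-dict keys
    s_keys = ["s_stage3", "s_stage4"]

    teacher = LinearTransformTeacher(qk_dim, t_shapes, t_keys)
    student = LinearTransformStudent(qk_dim, t_shapes, s_shapes, s_keys)

    t_feats = {
        k: paddle.randn([bs] + shape)
        for k, shape in zip(t_keys, t_shapes)
    }
    s_feats = {
        k: paddle.randn([bs] + shape)
        for k, shape in zip(s_keys, s_shapes)
    }

    t_out = teacher(t_feats)
    s_out = student(s_feats)
    print(t_out["query"].shape)  # [4, 2, 128] -> [bs, t, qk_dim]
    print(s_out["bilinear_key"].shape)  # [4, 2, 2, 128] -> [bs, s, t, qk_dim]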