# PaddleClas/ppcls/arch/distill/afd_attention.py
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
class LinearBNReLU(nn.Layer):
    """Linear projection followed by BatchNorm1D, with an optional ReLU.

    Args:
        nin: input feature dimension of the linear layer.
        nout: output feature dimension (also the BN channel count).
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.linear = nn.Linear(nin, nout)
        self.bn = nn.BatchNorm1D(nout)
        self.relu = nn.ReLU()

    def forward(self, x, relu=True):
        # Always apply Linear -> BN; the ReLU is skipped when relu=False
        # (used e.g. for producing raw query/key embeddings).
        out = self.bn(self.linear(x))
        if relu:
            out = self.relu(out)
        return out
def unique_shape(s_shapes):
    """Deduplicate a list of shapes, keeping first-seen order.

    Args:
        s_shapes: list of shapes (each a list, compared by equality).

    Returns:
        Tuple of (indices, unique_shapes) where ``unique_shapes`` holds the
        distinct shapes in order of first appearance, and ``indices[i]`` is
        the number of distinct shapes seen up to and including position i,
        minus one (i.e. the index of the most recently discovered unique
        shape, not necessarily the index of ``s_shapes[i]`` itself).
    """
    unique_shapes = []
    indices = []
    for shape in s_shapes:
        if shape not in unique_shapes:
            unique_shapes.append(shape)
        # len - 1 reproduces the running counter of the original loop.
        indices.append(len(unique_shapes) - 1)
    return indices, unique_shapes
class LinearTransformTeacher(nn.Layer):
    """Build attention queries and spatial values from teacher features.

    Args:
        qk_dim: dimensionality of the query embeddings.
        t_shapes: per-layer teacher feature shapes as [C, H, W] lists.
        keys: keys used to pull teacher features out of the input dict.
    """

    def __init__(self, qk_dim, t_shapes, keys):
        super().__init__()
        self.teacher_keys = keys
        # Prepend a dummy batch dim so shape[1] is the channel count.
        self.t_shapes = [[1] + shape for shape in t_shapes]
        self.query_layer = nn.LayerList(
            [LinearBNReLU(shape[1], qk_dim) for shape in self.t_shapes])

    def forward(self, t_features_dict):
        feats = [t_features_dict[key] for key in self.teacher_keys]
        batch = feats[0].shape[0]
        # Per-channel global average over the spatial dims -> (bs, C).
        channel_mean = [f.mean(3).mean(2) for f in feats]
        # Squared-activation energy map flattened to (bs, H*W).
        spatial_mean = []
        for f in feats:
            _, h, w = f.shape[1:]
            spatial_mean.append(f.pow(2).mean(1).reshape([batch, h * w]))
        # One query per teacher layer; relu=False keeps raw embeddings.
        queries = [
            layer(f, relu=False)
            for f, layer in zip(channel_mean, self.query_layer)
        ]
        query = paddle.stack(queries, axis=1)
        value = [F.normalize(s, axis=1) for s in spatial_mean]
        return {"query": query, "value": value}
class LinearTransformStudent(nn.Layer):
    """Build bilinear attention keys and spatial values from student features.

    Args:
        qk_dim: dimensionality of the key embeddings.
        t_shapes: teacher feature shapes as [C, H, W] lists (targets for
            the spatial samplers).
        s_shapes: student feature shapes as [C, H, W] lists.
        keys: keys used to pull student features out of the input dict.
    """

    def __init__(self, qk_dim, t_shapes, s_shapes, keys):
        super().__init__()
        self.student_keys = keys
        # Prepend a dummy batch dim so shape[1] is the channel count.
        self.t_shapes = [[1] + shape for shape in t_shapes]
        self.s_shapes = [[1] + shape for shape in s_shapes]
        self.t = len(self.t_shapes)
        self.s = len(self.s_shapes)
        self.qk_dim = qk_dim
        # One sampler per distinct teacher spatial shape.
        self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
        self.relu = nn.ReLU()
        self.samplers = nn.LayerList(
            [Sample(shape) for shape in self.unique_t_shapes])
        self.key_layer = nn.LayerList(
            [LinearBNReLU(shape[1], self.qk_dim) for shape in self.s_shapes])
        # Maps each key to one embedding per teacher layer.
        self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))

    def forward(self, s_features_dict):
        feats = [s_features_dict[key] for key in self.student_keys]
        batch = feats[0].shape[0]
        # Per-channel global average over the spatial dims -> (bs, C).
        channel_mean = [f.mean(3).mean(2) for f in feats]
        # Each sampler pools every student map to one teacher spatial size.
        spatial_mean = [sampler(feats, batch) for sampler in self.samplers]
        per_layer_keys = [
            layer(f) for layer, f in zip(self.key_layer, channel_mean)
        ]
        key = paddle.stack(per_layer_keys, axis=1).reshape([-1, self.qk_dim])
        bilinear_key = self.bilinear(key, relu=False).reshape(
            [batch, self.s, self.t, self.qk_dim])
        value = [F.normalize(s, axis=2) for s in spatial_mean]
        return {"bilinear_key": bilinear_key, "value": value}
class Sample(nn.Layer):
    """Pool student activation-energy maps down to a teacher spatial size.

    Args:
        t_shape: target teacher shape as [N, C, H, W]; only H and W are
            used, as the pooling target resolution.
    """

    def __init__(self, t_shape):
        super().__init__()
        self.t_N, self.t_C, self.t_H, self.t_W = t_shape
        self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))

    def forward(self, g_s, bs):
        pooled = []
        for feat in g_s:
            # Squared-activation energy averaged over channels, then
            # resized to the teacher's (H, W) and flattened to (bs, H*W).
            energy = feat.pow(2).mean(1, keepdim=True)
            pooled.append(
                self.sample(energy).reshape([bs, self.t_H * self.t_W]))
        return paddle.stack(pooled, axis=1)