618 lines
22 KiB
Python
618 lines
22 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Sequence
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import build_activation_layer, build_norm_layer
|
|
from mmcv.cnn.bricks.drop import DropPath
|
|
from mmcv.cnn.utils.weight_init import trunc_normal_
|
|
|
|
from mmcls.utils import get_root_logger
|
|
from ..builder import BACKBONES
|
|
from .base_backbone import BaseBackbone, BaseModule
|
|
from .vision_transformer import TransformerEncoderLayer
|
|
|
|
|
|
class ConvBlock(BaseModule):
    """Basic convolution block used in Conformer.

    This block includes three convolution modules, and supports three new
    functions:

    1. Returns the output of both the final layers and the second convolution
       module.
    2. Fuses the input of the second convolution module with an extra input
       feature map.
    3. Supports to add an extra convolution module to the identity connection.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels. The first two
            convolutions use ``out_channels // 4`` channels.
        stride (int): The stride of the second convolution module (and of the
            residual convolution, if enabled). Defaults to 1.
        groups (int): The groups of the second convolution module.
            Defaults to 1.
        drop_path_rate (float): The rate of the DropPath layer applied to the
            main branch before the residual addition. Defaults to 0.
        with_residual_conv (bool): Whether to add an extra convolution module
            to the identity connection. Defaults to False.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 groups=1,
                 drop_path_rate=0.,
                 with_residual_conv=False,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(ConvBlock, self).__init__(init_cfg=init_cfg)

        # Bottleneck design: the inner convolutions work on a reduced width.
        expansion = 4
        mid_channels = out_channels // expansion

        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act1 = build_activation_layer(act_cfg)

        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            groups=groups,
            padding=1,
            bias=False)
        self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act2 = build_activation_layer(act_cfg)

        self.conv3 = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
        self.act3 = build_activation_layer(act_cfg)

        if with_residual_conv:
            # Projects the identity so it matches the main branch's channels
            # and stride before the residual addition.
            self.residual_conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False)
            self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]

        self.with_residual_conv = with_residual_conv
        # Always a module (DropPath or a no-op Identity), never None.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def zero_init_last_bn(self):
        """Zero-initialize the weight of the last normalization layer."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, fusion_features=None, out_conv2=True):
        """Forward function.

        Args:
            x (Tensor): The input feature map.
            fusion_features (Tensor, optional): An extra feature map added to
                the input of the second convolution. Defaults to None.
            out_conv2 (bool): Whether to also return the activated output of
                the second convolution. Defaults to True.

        Returns:
            Tensor | tuple[Tensor, Tensor]: The block output, or
            ``(output, conv2_output)`` when ``out_conv2`` is True.
        """
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x) if fusion_features is None else self.conv2(
            x + fusion_features)
        x = self.bn2(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)

        # drop_path is Identity when the rate is 0, so it is always safe to
        # call; no None check is needed.
        x = self.drop_path(x)

        if self.with_residual_conv:
            identity = self.residual_conv(identity)
            identity = self.residual_bn(identity)

        x += identity
        x = self.act3(x)

        if out_conv2:
            return x, x2
        else:
            return x
|
|
|
|
|
|
class FCUDown(BaseModule):
    """CNN feature maps -> Transformer patch embeddings.

    A 1x1 convolution projects the CNN channels to the embedding width, an
    average pooling of size/stride ``down_stride`` downsamples the map, and
    the pooled map is flattened into a token sequence, normalized and
    activated. When ``with_cls_token`` is set, the first token of the
    transformer input ``x_t`` is prepended to the result.

    Args:
        in_channels (int): Channels of the incoming CNN feature map.
        out_channels (int): Embedding width of the produced tokens.
        down_stride (int): Kernel size and stride of the average pooling.
        with_cls_token (bool): Whether to prepend the first token of
            ``x_t``. Defaults to True.
        norm_cfg (dict): Config of the normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config of the activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): Extra initialization config.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 down_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.down_stride = down_stride
        self.with_cls_token = with_cls_token

        # Keep the submodule creation order stable (affects random init).
        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sample_pooling = nn.AvgPool2d(
            kernel_size=down_stride, stride=down_stride)

        self.ln = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, x_t):
        """Turn a CNN feature map into transformer tokens.

        Args:
            x (Tensor): CNN feature map, ``[N, C, H, W]``.
            x_t (Tensor): Current transformer tokens; only its first token
                is read, and only when ``with_cls_token`` is True.

        Returns:
            Tensor: Tokens of shape ``[N, L, out_channels]`` (with the class
            token prepended when enabled).
        """
        projected = self.conv_project(x)  # [N, C, H, W]
        pooled = self.sample_pooling(projected)
        tokens = pooled.flatten(2).transpose(1, 2)  # [N, h * w, C]
        tokens = self.act(self.ln(tokens))

        if not self.with_cls_token:
            return tokens

        cls_token = x_t[:, :1]  # keep the token axis: [N, 1, C]
        return torch.cat([cls_token, tokens], dim=1)
|
|
|
|
|
|
class FCUUp(BaseModule):
    """Transformer patch embeddings -> CNN feature maps.

    The patch tokens (the class token is dropped when ``with_cls_token`` is
    set) are reshaped into an ``H x W`` map, projected by a 1x1 convolution,
    normalized, activated and finally upsampled by ``up_stride`` with
    ``F.interpolate`` (default nearest mode).

    Args:
        in_channels (int): Embedding width of the incoming tokens.
        out_channels (int): Channels of the produced feature map.
        up_stride (int): Spatial upsampling factor.
        with_cls_token (bool): Whether the first token is a class token that
            must be skipped. Defaults to True.
        norm_cfg (dict): Config of the normalization layer.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): Config of the activation layer.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        init_cfg (dict, optional): Extra initialization config.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 up_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.up_stride = up_stride
        self.with_cls_token = with_cls_token

        # Keep the submodule creation order stable (affects random init).
        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, H, W):
        """Turn transformer tokens back into an upsampled feature map.

        Args:
            x (Tensor): Tokens of shape ``[N, L, C]``.
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            Tensor: Feature map of spatial size
            ``(H * up_stride, W * up_stride)``.
        """
        batch_size, _, channels = x.shape

        # e.g. [N, 197, 384] -> [N, 196, 384] -> [N, 384, 14, 14]
        patch_tokens = x[:, 1:] if self.with_cls_token else x
        feat = patch_tokens.transpose(1, 2).reshape(
            batch_size, channels, H, W)

        feat = self.conv_project(feat)
        feat = self.bn(feat)
        feat = self.act(feat)

        target_size = (H * self.up_stride, W * self.up_stride)
        return F.interpolate(feat, size=target_size)
|
|
|
|
|
|
class ConvTransBlock(BaseModule):
    """Basic module for Conformer.

    This module is a fusion of a CNN block and a transformer encoder block:
    the CNN branch feeds its intermediate feature map to the transformer
    branch (via ``FCUDown``), and the transformer output is fed back into the
    CNN branch (via ``FCUUp``) before a final fusion conv block.

    Args:
        in_channels (int): The number of input channels in conv blocks.
        out_channels (int): The number of output channels in conv blocks.
        embed_dims (int): The embedding dimension in transformer blocks.
        conv_stride (int): The stride of conv2d layers. Defaults to 1.
        groups (int): The groups of conv blocks. Defaults to 1.
        with_residual_conv (bool): Whether to add a conv-bn layer to the
            identity connect in the conv block. Defaults to False.
        down_stride (int): The stride of the downsample pooling layer.
            Defaults to 4.
        num_heads (int): The number of heads in transformer attention layers.
            Defaults to 12.
        mlp_ratio (float): The expansion ratio in transformer FFN module.
            Defaults to 4.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_rate (float): The dropout rate of the output projection and
            FFN in the transformer block. Defaults to 0.
        attn_drop_rate (float): The dropout rate after the attention
            calculation in the transformer block. Defaults to 0.
        drop_path_rate (float): The drop path rate in both the fusion conv
            block and the transformer block. Defaults to 0.
        last_fusion (bool): Whether this block is the last stage. If so,
            downsample the fusion feature map. Defaults to False.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 embed_dims,
                 conv_stride=1,
                 groups=1,
                 with_residual_conv=False,
                 down_stride=4,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 with_cls_token=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 last_fusion=False,
                 init_cfg=None):
        super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
        # Must match ConvBlock's bottleneck expansion: the coupling units
        # operate on the inner (conv2) width of the conv blocks.
        expansion = 4
        self.cnn_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            with_residual_conv=with_residual_conv,
            stride=conv_stride,
            groups=groups)

        # The last stage downsamples the fused map (stride 2), which needs a
        # strided residual conv so the identity shape still matches.
        self.fusion_block = ConvBlock(
            in_channels=out_channels,
            out_channels=out_channels,
            stride=2 if last_fusion else 1,
            with_residual_conv=bool(last_fusion),
            groups=groups,
            drop_path_rate=drop_path_rate)

        self.squeeze_block = FCUDown(
            in_channels=out_channels // expansion,
            out_channels=embed_dims,
            down_stride=down_stride,
            with_cls_token=with_cls_token)

        self.expand_block = FCUUp(
            in_channels=embed_dims,
            out_channels=out_channels // expansion,
            up_stride=down_stride,
            with_cls_token=with_cls_token)

        self.trans_block = TransformerEncoderLayer(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(embed_dims * mlp_ratio),
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        self.down_stride = down_stride
        self.embed_dim = embed_dims
        self.last_fusion = last_fusion

    def forward(self, cnn_input, trans_input):
        """Run one coupled CNN/transformer stage.

        Args:
            cnn_input (Tensor): Input feature map of the CNN branch.
            trans_input (Tensor): Input tokens of the transformer branch.

        Returns:
            tuple[Tensor, Tensor]: The CNN branch output feature map and the
            transformer branch output tokens.
        """
        x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)

        _, _, H, W = x_conv2.shape

        # Convert the feature map of conv2 to transformer embedding
        # and concat with class token.
        conv2_embedding = self.squeeze_block(x_conv2, trans_input)

        trans_output = self.trans_block(conv2_embedding + trans_input)

        # Convert the transformer output embedding to feature map
        trans_features = self.expand_block(trans_output, H // self.down_stride,
                                           W // self.down_stride)
        x = self.fusion_block(
            x, fusion_features=trans_features, out_conv2=False)

        return x, trans_output
|
|
|
|
|
|
@BACKBONES.register_module()
class Conformer(BaseBackbone):
    """Conformer backbone.

    A PyTorch implementation of : `Conformer: Local Features Coupling Global
    Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_

    Args:
        arch (str | dict): Conformer architecture. Either a key of
            ``arch_zoo`` ('t'/'tiny', 's'/'small', 'b'/'base', case
            insensitive) or a dict with exactly the keys ``embed_dims``,
            ``depths``, ``num_heads`` and ``channel_ratio``.
            Defaults to 'tiny'.
        patch_size (int): The patch size. Defaults to 16.
        base_channels (int): The base number of channels in CNN network.
            Defaults to 64.
        mlp_ratio (float): The expansion ratio of FFN network in transformer
            block. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv in the transformer blocks.
            Defaults to True.
        with_cls_token (bool): Whether use class token or not.
            Defaults to True.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_eval (bool): Stored as ``self.norm_eval``; not used by the code
            of this class. Defaults to True.
        frozen_stages (int): Stored as ``self.frozen_stages``; not used by
            the code of this class. Defaults to 0.
        out_indices (Sequence | int): Output from which stages. Negative
            values count back from ``depths`` (``-1`` means stage
            ``depths``, the last stage). Only the coupled stages (index 2 and
            above) produce outputs. The argument is copied, never mutated.
            Defaults to -1.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 384,
                         'channel_ratio': 1,
                         'num_heads': 6,
                         'depths': 12
                         }),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 384,
                         'channel_ratio': 4,
                         'num_heads': 6,
                         'depths': 12
                         }),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 576,
                         'channel_ratio': 6,
                         'num_heads': 9,
                         'depths': 12
                         }),
    }  # yapf: disable

    _version = 1

    def __init__(self,
                 arch='tiny',
                 patch_size=16,
                 base_channels=64,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 with_cls_token=True,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=0,
                 out_indices=-1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads', 'channel_ratio'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.num_features = self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.channel_ratio = self.arch_settings['channel_ratio']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private copy: normalizing negative indices in place must
        # not mutate the caller's list (e.g. a shared config value), and must
        # also work for immutable sequences such as tuples.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depths + index + 1
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # stochastic depth decay rule
        self.trans_dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
        ]

        # Stem stage: get the feature maps by conv block
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3,
            bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        # 1 stage
        stage1_channels = int(base_channels * self.channel_ratio)
        # The stem already downsampled by 4, so the remaining patchify
        # stride is patch_size // 4.
        trans_down_stride = patch_size // 4
        self.conv_1 = ConvBlock(
            in_channels=64,
            out_channels=stage1_channels,
            with_residual_conv=True,
            stride=1)
        self.trans_patch_conv = nn.Conv2d(
            64,
            self.embed_dims,
            kernel_size=trans_down_stride,
            stride=trans_down_stride,
            padding=0)

        self.trans_1 = TransformerEncoderLayer(
            embed_dims=self.embed_dims,
            num_heads=self.num_heads,
            feedforward_channels=int(self.embed_dims * mlp_ratio),
            drop_path_rate=self.trans_dpr[0],
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        # 2~4 stage
        init_stage = 2
        fin_stage = self.depths // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=stage1_channels,
                    out_channels=stage1_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=1,
                    with_residual_conv=False,
                    down_stride=trans_down_stride,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage2_channels = int(base_channels * self.channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + self.depths // 3  # 9
        for i in range(init_stage, fin_stage):
            # The first block of the stage halves the CNN resolution, so the
            # CNN <-> transformer resampling stride halves as well.
            if i == init_stage:
                conv_stride = 2
                in_channels = stage1_channels
            else:
                conv_stride = 1
                in_channels = stage2_channels

            with_residual_conv = True if i == init_stage else False
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage2_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 2,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + self.depths // 3  # 13
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage2_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage3_channels
                with_residual_conv = False

            last_fusion = (i == self.depths)

            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage3_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 4,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token,
                    last_fusion=last_fusion))
        self.fin_stage = fin_stage

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.trans_norm = nn.LayerNorm(self.embed_dims)

        if self.with_cls_token:
            trunc_normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        """Default per-module initialization applied via ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

        if hasattr(m, 'zero_init_last_bn'):
            m.zero_init_last_bn()

    def init_weights(self):
        """Initialize weights.

        The default initialization (``_init_weights``) is skipped when
        ``init_cfg`` is a ``Pretrained`` config.
        """
        super(Conformer, self).init_weights()
        logger = get_root_logger()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        else:
            logger.info(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            self.apply(self._init_weights)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input images with 3 channels, ``[N, 3, H, W]``.

        Returns:
            tuple[list[Tensor]]: One ``[conv_feature, trans_feature]`` pair
            per stage in ``out_indices``. ``conv_feature`` is the globally
            average-pooled CNN feature flattened to ``[N, C]``.
            ``trans_feature`` is the layer-normalized class token, or the
            mean of all layer-normalized tokens when no class token is used.
        """
        output = []
        B = x.shape[0]
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))

        # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
        x = self.conv_1(x_base, out_conv2=False)
        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        if self.with_cls_token:
            x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        for i in range(2, self.fin_stage):
            stage = getattr(self, f'conv_trans_{i}')
            x, x_t = stage(x, x_t)
            if i in self.out_indices:
                if self.with_cls_token:
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t)[:, 0]
                    ])
                else:
                    # if no class token, use the mean patch token
                    # as the transformer feature.
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t).mean(dim=1)
                    ])

        return tuple(output)
|