Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Halo, bottleneck attn, lambda layer additions and cleanup along w/ experimental model defs
* align interfaces of halo, bottleneck attn and lambda layer
* add qk_ratio to all of above, control q/k dim relative to output dim
* add experimental haloregnetz, and trionet (lambda + halo + bottle) models
parent e0b3a3fab3
commit e2b8d44ff0
timm/models/byoanet.py

@@ -66,6 +66,13 @@ default_cfgs = {
     'lambda_resnet26rpt_256': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+
+    'haloregnetz_b': _cfg(
+        url='',
+        input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
+    'trionet50ts_256': _cfg(
+        url='',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -232,6 +239,46 @@ model_cfgs = dict(
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=None)
     ),
+
+    # experimental
+    haloregnetz_b=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+            interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
+            ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+        self_attn_layer='halo',
+        self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
+    ),
+
+    # experimental
+    trionet50ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+            interleave_blocks(
+                types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
+                self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
+            interleave_blocks(
+                types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
+                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
+            interleave_blocks(
+                types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
+                self_attn_layer='bottleneck', self_attn_kwargs=dict()),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        act_layer='silu',
+    ),
 )
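For readers unfamiliar with the ByobNet config helpers: the every=3 interleave in the haloregnetz_b stage above mixes the two block types through the 12-deep stage. The sketch below is illustrative only; it mimics the assumed pattern rather than importing the repo's interleave_blocks helper, so treat the exact indices as an assumption.

# Illustrative stand-in (not repo code) for how interleave_blocks(every=3, d=12) is
# assumed to mix block types: every third block switches to the self-attention type.
def interleave_pattern(types, d, every):
    picks = set(range(every, d, every))          # assumed: blocks 3, 6, 9 of the 12
    return [types[1] if i in picks else types[0] for i in range(d)]

print(interleave_pattern(('bottle', 'self_attn'), d=12, every=3))
# 12 entries; entries 3, 6 and 9 are 'self_attn', the rest 'bottle'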
@@ -327,3 +374,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs):
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def haloregnetz_b(pretrained=False, **kwargs):
+    """ Halo + RegNetZ
+    """
+    return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def trionet50ts_256(pretrained=False, **kwargs):
+    """ TrioNet (lambda + halo + bottleneck attn) w/ a ResNet50-t backbone, SiLU act.
+    """
+    return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs)
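With the registrations above in place, the new experimental models can be built through the normal timm factory. A minimal sketch, assuming a timm build that includes this commit; neither model ships pretrained weights here (url='' in default_cfgs), so pretrained stays False:

import torch
import timm

halo_regnetz = timm.create_model('haloregnetz_b', pretrained=False)
trionet = timm.create_model('trionet50ts_256', pretrained=False)

print(halo_regnetz(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
print(trionet(torch.randn(1, 3, 256, 256)).shape)       # fixed 256x256 input per its default cfg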
timm/models/byobnet.py

@@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module):
             self.self_attn.reset_parameters()
 
     def forward(self, x):
-        shortcut = self.shortcut(x)
-
+        shortcut = x
         x = self.conv1_1x1(x)
         x = self.conv2_kxk(x)
         x = self.self_attn(x)
         x = self.post_attn(x)
         x = self.conv3_1x1(x)
         x = self.drop_path(x)
-
-        x = self.act(x + shortcut)
-        return x
+        if self.shortcut is not None:
+            x = x + self.shortcut(shortcut)
+        return self.act(x)
 
 
 _block_registry = dict(
     basic=BasicBlock,
timm/models/layers/bottleneck_attn.py

@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .helpers import to_2tuple
+from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
@@ -66,10 +66,10 @@ class PosEmbedRel(nn.Module):
         self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
 
     def forward(self, q):
-        B, num_heads, HW, _ = q.shape
+        B, HW, _ = q.shape
 
         # relative logits in width dimension.
-        q = q.reshape(B * num_heads, self.height, self.width, -1)
+        q = q.reshape(B, self.height, self.width, -1)
         rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
 
         # relative logits in height dimension.
@@ -77,35 +77,56 @@ class PosEmbedRel(nn.Module):
         rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
 
         rel_logits = rel_logits_h + rel_logits_w
-        rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
+        rel_logits = rel_logits.reshape(B, HW, HW)
         return rel_logits
 
 
 class BottleneckAttn(nn.Module):
     """ Bottleneck Attention
     Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+
+    The internal dimensions of the attention module are controlled by the interaction of several arguments.
+      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+      * the query and key (qk) dimensions are determined by
+        * num_heads * dim_head if dim_head is not None
+        * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
+      * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
+
+    Args:
+        dim (int): input dimension to the module
+        dim_out (int): output dimension of the module, same as dim if not set
+        stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
+        num_heads (int): parallel attention heads (default: 4)
+        dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+        qkv_bias (bool): add bias to q, k, and v projections
     """
-    def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
+    def __init__(
+            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
+            qk_ratio=1.0, qkv_bias=False):
         super().__init__()
         assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
         dim_out = dim_out or dim
         assert dim_out % num_heads == 0
         self.num_heads = num_heads
         self.dim_out = dim_out
-        self.dim_head = dim_out // num_heads
-        self.scale = self.dim_head ** -0.5
+        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+        self.dim_head_v = dim_out // self.num_heads
+        self.dim_out_qk = num_heads * self.dim_head_qk
+        self.dim_out_v = num_heads * self.dim_head_v
+        self.scale = self.dim_head_qk ** -0.5
 
-        self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
+        self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
 
         # NOTE I'm only supporting relative pos embedding for now
-        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
+        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
         trunc_normal_(self.pos_embed.height_rel, std=self.scale)
         trunc_normal_(self.pos_embed.width_rel, std=self.scale)
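To make the new qk_ratio plumbing concrete, here is a worked example of the dimension arithmetic in __init__ above. The make_divisible stand-in is simplified (it drops timm's round-limit adjustment), so the numbers are illustrative:

# Worked example (illustrative, not repo code) of how qk_ratio sets the q/k width
# relative to the output width, following the formulas in __init__ above.

def make_divisible(v, divisor=8):
    # simplified stand-in for timm's make_divisible helper
    return max(divisor, int(v + divisor / 2) // divisor * divisor)

dim_out, num_heads, qk_ratio = 512, 4, 0.33
dim_head_qk = make_divisible(dim_out * qk_ratio, divisor=8) // num_heads  # 168 // 4 = 42
dim_head_v = dim_out // num_heads                                         # 128
dim_out_qk = num_heads * dim_head_qk                                      # 168
dim_out_v = num_heads * dim_head_v                                        # 512
print(dim_head_qk, dim_out_qk, dim_out_v)
# qkv projection width = 2 * dim_out_qk + dim_out_v = 848 channels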
@@ -114,15 +135,20 @@ class BottleneckAttn(nn.Module):
         assert H == self.pos_embed.height
         assert W == self.pos_embed.width
 
-        x = self.qkv(x)  # B, 3 * num_heads * dim_head, H, W
-        x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
-        q, k, v = torch.split(x, self.num_heads, dim=1)
+        x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
 
-        attn = (q @ k.transpose(-1, -2)) * self.scale
-        attn = attn + self.pos_embed(q)  # B, num_heads, H * W, H * W
+        # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
+        # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
+        q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
+        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
+        k = k.reshape(B * self.num_heads, self.dim_head_qk, -1)  # no transpose, for q @ k
+        v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
+
+        attn = (q @ k) * self.scale
+        attn = attn + self.pos_embed(q)  # B * num_heads, H * W, H * W
         attn = attn.softmax(dim=-1)
 
-        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
+        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
         out = self.pool(out)
         return out
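A quick smoke test of the reworked forward pass; this sketch assumes BottleneckAttn is importable from timm.models.layers (its usual export path) and only checks shapes:

import torch
from timm.models.layers import BottleneckAttn  # assumed export path

# feat_size is required because the relative position embedding is fixed-size.
attn = BottleneckAttn(dim=256, dim_out=256, feat_size=(16, 16), num_heads=4, qk_ratio=0.5)
x = torch.randn(2, 256, 16, 16)
print(attn(x).shape)  # torch.Size([2, 256, 16, 16]); the v projection sets the output width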
timm/models/layers/halo_attn.py

@@ -22,6 +22,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .helpers import make_divisible
 from .weight_init import trunc_normal_
@@ -98,31 +99,62 @@ class HaloAttn(nn.Module):
 
     Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
         - https://arxiv.org/abs/2103.12731
+
+    The internal dimensions of the attention module are controlled by the interaction of several arguments.
+      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+      * the query and key (qk) dimensions are determined by
+        * num_heads * dim_head if dim_head is not None
+        * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
+      * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
+
+    Args:
+        dim (int): input dimension to the module
+        dim_out (int): output dimension of the module, same as dim if not set
+        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
+        stride: output stride of the module, query downscaled if > 1 (default: 1).
+        num_heads: parallel attention heads (default: 8).
+        dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+        block_size (int): size of blocks. (default: 8)
+        halo_size (int): size of halo overlap. (default: 3)
+        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+        qkv_bias (bool) : add bias to q, k, and v projections
+        avg_down (bool): use average pool downsample instead of strided query blocks
+
     """
     def __init__(
-            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
+            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
+            qk_ratio=1.0, qkv_bias=False, avg_down=False):
         super().__init__()
         dim_out = dim_out or dim
         assert dim_out % num_heads == 0
-        self.stride = stride
+        assert stride in (1, 2)
         self.num_heads = num_heads
-        self.dim_head_qk = dim_head or dim_out // num_heads
+        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
         self.dim_head_v = dim_out // self.num_heads
         self.dim_out_qk = num_heads * self.dim_head_qk
         self.dim_out_v = num_heads * self.dim_head_v
-        self.block_size = block_size
+        self.scale = self.dim_head_qk ** -0.5
+        self.block_size = self.block_size_ds = block_size
         self.halo_size = halo_size
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
-        self.scale = self.dim_head_qk ** -0.5
+        self.block_stride = 1
+        use_avg_pool = False
+        if stride > 1:
+            use_avg_pool = avg_down or block_size % stride != 0
+            self.block_stride = 1 if use_avg_pool else stride
+            self.block_size_ds = self.block_size // self.block_stride
 
         # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
         # data in unfolded block form. I haven't wrapped my head around how that'd look.
-        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias)
+        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
         self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
 
         self.pos_embed = PosEmbedRel(
-            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
+            block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
 
         self.reset_parameters()
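The new stride handling decides between a strided query projection (block_stride > 1) and an average-pool downsample. A small standalone trace of those branches, copied from the __init__ logic above for illustration:

# Illustrative trace (not repo code) of the stride/avg-pool decision above.
def halo_stride_plan(stride, block_size, avg_down=False):
    block_stride, use_avg_pool, block_size_ds = 1, False, block_size
    if stride > 1:
        use_avg_pool = avg_down or block_size % stride != 0
        block_stride = 1 if use_avg_pool else stride
        block_size_ds = block_size // block_stride
    return block_stride, block_size_ds, use_avg_pool

print(halo_stride_plan(2, 8))                 # (2, 4, False) -> strided q conv
print(halo_stride_plan(2, 7))                 # (1, 7, True)  -> avg pool; block_size=7 as in haloregnetz_b
print(halo_stride_plan(2, 8, avg_down=True))  # (1, 8, True)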
@@ -140,11 +172,12 @@ class HaloAttn(nn.Module):
         num_h_blocks = H // self.block_size
         num_w_blocks = W // self.block_size
         num_blocks = num_h_blocks * num_w_blocks
-        bs_stride = self.block_size // self.stride
 
         q = self.q(x)
         # unfold
-        q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
+        q = q.reshape(
+            -1, self.dim_head_qk,
+            num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
         # B, num_heads * dim_head * block_size ** 2, num_blocks
         q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
         # B * num_heads, num_blocks, block_size ** 2, dim_head
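The query unfold above can be hard to visualize; the following shape trace replays the same ops with concrete sizes (stride-1 case, so block_size_ds equals block_size). Purely illustrative:

import torch

B, num_heads, dim_head_qk = 2, 8, 16
H = W = 16
block_size_ds = 8                                   # block_stride == 1 here
num_h_blocks = num_w_blocks = H // block_size_ds    # 2 blocks per spatial dim
num_blocks = num_h_blocks * num_w_blocks            # 4

q = torch.randn(B, num_heads * dim_head_qk, H, W)   # output of the 1x1 q projection
q = q.reshape(-1, dim_head_qk, num_h_blocks, block_size_ds, num_w_blocks, block_size_ds)
q = q.permute(0, 1, 3, 5, 2, 4)                     # gather block pixels together, blocks last
q = q.reshape(B * num_heads, dim_head_qk, -1, num_blocks).transpose(1, 3)
print(q.shape)  # (B * num_heads, num_blocks, block_size_ds ** 2, dim_head_qk) = (16, 4, 64, 16)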
@@ -163,9 +196,11 @@ class HaloAttn(nn.Module):
 
         out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
         # fold
-        out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
-        out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride)
-        # B, dim_out, H // stride, W // stride
+        out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
+        out = out.permute(0, 3, 1, 4, 2).contiguous().view(
+            B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
+        # B, dim_out, H // block_stride, W // block_stride
+        out = self.pool(out)
         return out
timm/models/layers/lambda_layer.py

@@ -24,7 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from .helpers import to_2tuple
+from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
@@ -44,28 +44,46 @@ class LambdaLayer(nn.Module):
         - https://arxiv.org/abs/2102.08602
 
     NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
+
+    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
+      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+      * the query (q) and key (k) dimension are determined by
+        * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None
+        * q = num_heads * dim_head, k = dim_head
+      * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set
+
+    Args:
+        dim (int): input dimension to the module
+        dim_out (int): output dimension of the module, same as dim if not set
+        feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
+        stride (int): output stride of the module, avg pool used if stride == 2
+        num_heads (int): parallel attention heads.
+        dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
+        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+        qkv_bias (bool): add bias to q, k, and v projections
     """
     def __init__(
-            self,
-            dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
+            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
+            qk_ratio=1.0, qkv_bias=False):
         super().__init__()
-        self.dim = dim
-        self.dim_out = dim_out or dim
-        self.dim_k = dim_head  # query depth 'k'
+        dim_out = dim_out or dim
+        assert dim_out % num_heads == 0, ' should be divided by num_heads'
+        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
         self.num_heads = num_heads
-        assert self.dim_out % num_heads == 0, ' should be divided by num_heads'
-        self.dim_v = self.dim_out // num_heads  # value depth 'v'
+        self.dim_v = dim_out // num_heads
 
         self.qkv = nn.Conv2d(
             dim,
-            num_heads * dim_head + dim_head + self.dim_v,
+            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
             kernel_size=1, bias=qkv_bias)
-        self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
+        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
         self.norm_v = nn.BatchNorm2d(self.dim_v)
 
         if r is not None:
             # local lambda convolution for pos
-            self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
+            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
             self.pos_emb = None
             self.rel_pos_indices = None
         else:
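For the lambda layer, the key is single-headed, so the qkv projection is num_heads * dim_qk + dim_qk + dim_v channels wide. A hedged sketch with concrete numbers (assumes LambdaLayer is exported from timm.models.layers):

import torch
from timm.models.layers import LambdaLayer  # assumed export path

# dim_head=None hands control to qk_ratio: dim_qk = make_divisible(256 * 0.25, 8) // 4 = 16
lam = LambdaLayer(dim=256, dim_out=256, num_heads=4, dim_head=None, r=9, qk_ratio=0.25)
print(lam.qkv.out_channels)                    # 4 * 16 + 16 + 64 = 144
print(lam(torch.randn(2, 256, 16, 16)).shape)  # torch.Size([2, 256, 16, 16])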
@@ -74,7 +92,7 @@ class LambdaLayer(nn.Module):
             feat_size = to_2tuple(feat_size)
             rel_size = [2 * s - 1 for s in feat_size]
             self.conv_lambda = None
-            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k))
+            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
             self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
@@ -82,9 +100,9 @@ class LambdaLayer(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self):
-        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
         if self.conv_lambda is not None:
-            trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
         if self.pos_emb is not None:
             trunc_normal_(self.pos_emb, std=.02)
@@ -93,17 +111,17 @@ class LambdaLayer(nn.Module):
         M = H * W
         qkv = self.qkv(x)
         q, k, v = torch.split(qkv, [
-            self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
-        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2)  # B, num_heads, M, K
+            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
+        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
         v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
-        k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1)  # B, K, M
+        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M
 
         content_lam = k @ v  # B, K, V
         content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
 
         if self.pos_emb is None:
             position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
-            position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
+            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
         else:
             # FIXME relative pos embedding path not fully verified
             pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
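The content path above is just two batched matmuls; a minimal shape sketch (illustrative values, mirroring the comments in forward):

import torch

B, num_heads, M, K, V = 2, 4, 64, 16, 64    # M = H * W, K = dim_qk, V = dim_v
q = torch.randn(B, num_heads, M, K)         # B, num_heads, M, K
k = torch.randn(B, K, M).softmax(dim=-1)    # B, K, M
v = torch.randn(B, M, V)                    # B, M, V

content_lam = k @ v                         # B, K, V
content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
print(content_out.shape)                    # torch.Size([2, 4, 64, 64])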