mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some halo and bottleneck attn code cleanup, add halonet50ts weights, use optimal crop ratios
This commit is contained in:
parent
d9abfa48df
commit
007bc39323
@ -3,7 +3,7 @@
|
||||
A flexible network w/ dataclass based config for stacking NN blocks including
|
||||
self-attention (or similar) layers.
|
||||
|
||||
Currently used to implement experimential variants of:
|
||||
Currently used to implement experimental variants of:
|
||||
* Bottleneck Transformers
|
||||
* Lambda ResNets
|
||||
* HaloNets
|
||||
@ -46,15 +46,16 @@ default_cfgs = {
|
||||
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
'halonet26t': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
|
||||
'sehalonet33ts': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
|
||||
'halonet50ts': _cfg(
|
||||
url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
|
||||
'eca_halonext26ts': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
|
||||
|
||||
'lambda_resnet26t': _cfg(
|
||||
url='',
|
||||
|
@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module):
|
||||
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
|
||||
q, k, v = torch.split(x, self.num_heads, dim=1)
|
||||
|
||||
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
|
||||
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
|
||||
attn = (q @ k.transpose(-1, -2)) * self.scale
|
||||
attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
attn_out = attn_logits.softmax(dim=-1)
|
||||
attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
|
||||
attn_out = self.pool(attn_out)
|
||||
return attn_out
|
||||
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
|
||||
out = self.pool(out)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -106,22 +106,23 @@ class HaloAttn(nn.Module):
|
||||
assert dim_out % num_heads == 0
|
||||
self.stride = stride
|
||||
self.num_heads = num_heads
|
||||
self.dim_head = dim_head or dim // num_heads
|
||||
self.dim_qk = num_heads * self.dim_head
|
||||
self.dim_v = dim_out
|
||||
self.dim_head_qk = dim_head or dim_out // num_heads
|
||||
self.dim_head_v = dim_out // self.num_heads
|
||||
self.dim_out_qk = num_heads * self.dim_head_qk
|
||||
self.dim_out_v = num_heads * self.dim_head_v
|
||||
self.block_size = block_size
|
||||
self.halo_size = halo_size
|
||||
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
||||
self.scale = self.dim_head ** -0.5
|
||||
self.scale = self.dim_head_qk ** -0.5
|
||||
|
||||
# FIXME not clear if this stride behaviour is what the paper intended
|
||||
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
|
||||
# data in unfolded block form. I haven't wrapped my head around how that'd look.
|
||||
self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
|
||||
self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
|
||||
self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias)
|
||||
self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
|
||||
|
||||
self.pos_embed = PosEmbedRel(
|
||||
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
|
||||
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
@ -143,37 +144,42 @@ class HaloAttn(nn.Module):
|
||||
|
||||
q = self.q(x)
|
||||
# unfold
|
||||
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
|
||||
q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
|
||||
# B, num_heads * dim_head * block_size ** 2, num_blocks
|
||||
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
|
||||
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
|
||||
# B * num_heads, num_blocks, block_size ** 2, dim_head
|
||||
|
||||
kv = self.kv(x)
|
||||
# generate overlapping windows for kv
|
||||
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
|
||||
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
|
||||
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
|
||||
# NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
|
||||
# if self.stride_tricks:
|
||||
# kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
||||
# kv = kv.as_strided((
|
||||
# B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
||||
# stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
||||
# else:
|
||||
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
||||
# kv = kv.reshape(
|
||||
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
|
||||
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
|
||||
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
|
||||
B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
|
||||
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
|
||||
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
|
||||
|
||||
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
|
||||
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
|
||||
|
||||
attn_out = attn_logits.softmax(dim=-1)
|
||||
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
|
||||
attn = (q @ k.transpose(-1, -2)) * self.scale
|
||||
attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
|
||||
# fold
|
||||
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
|
||||
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
|
||||
out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
|
||||
out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride)
|
||||
# B, dim_out, H // stride, W // stride
|
||||
return attn_out
|
||||
return out
|
||||
|
||||
|
||||
""" Two alternatives for overlapping windows.
|
||||
|
||||
`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
|
||||
|
||||
if self.stride_tricks:
|
||||
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
||||
kv = kv.as_strided((
|
||||
B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
||||
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
||||
else:
|
||||
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
||||
kv = kv.reshape(
|
||||
B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user