Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit 007bc39323 (parent d9abfa48df)
Some halo and bottleneck attn code cleanup, add halonet50ts weights, use optimal crop ratios
@@ -3,7 +3,7 @@
 A flexible network w/ dataclass based config for stacking NN blocks including
 self-attention (or similar) layers.
 
-Currently used to implement experimential variants of:
+Currently used to implement experimental variants of:
  * Bottleneck Transformers
  * Lambda ResNets
  * HaloNets
@@ -46,15 +46,16 @@ default_cfgs = {
     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'sehalonet33ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'halonet50ts': _cfg(
-        url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
 
     'lambda_resnet26t': _cfg(
         url='',
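The crop_pct field in these default_cfgs sets the eval-time crop ratio: images are resized to roughly input_size / crop_pct and then center-cropped to input_size, so 0.94 means resize to about 272 and crop to 256. A minimal sketch of picking the new ratio up at inference time, assuming the halonet50ts weights above are downloadable and using the timm data helpers roughly as they existed around this release (exact signatures may differ by version):

    import timm
    from timm.data import resolve_data_config, create_transform

    # build the model; pretrained=True pulls the halonet50ts_256_ra3 weights referenced above
    model = timm.create_model('halonet50ts', pretrained=True)

    # resolve eval preprocessing from the model's default_cfg (input_size, crop_pct, mean/std, ...)
    config = resolve_data_config({}, model=model)
    print(config['input_size'], config['crop_pct'])  # (3, 256, 256) 0.94

    # eval transform: resize short side to ~256 / 0.94 ≈ 272, then center crop to 256
    transform = create_transform(**config)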
@@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module):
         x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
         q, k, v = torch.split(x, self.num_heads, dim=1)
 
-        attn_logits = (q @ k.transpose(-1, -2)) * self.scale
-        attn_logits = attn_logits + self.pos_embed(q)  # B, num_heads, H * W, H * W
-
-        attn_out = attn_logits.softmax(dim=-1)
-        attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
-        attn_out = self.pool(attn_out)
-        return attn_out
+        attn = (q @ k.transpose(-1, -2)) * self.scale
+        attn = attn + self.pos_embed(q)  # B, num_heads, H * W, H * W
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
+        out = self.pool(out)
+        return out
 
 
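After the rename, BottleneckAttn's forward is plain scaled dot-product attention over all H * W positions plus a relative position term from self.pos_embed. A standalone shape sketch of that flow, with made-up sizes and a random tensor standing in for the pos_embed output (not the timm module itself):

    import torch

    B, num_heads, dim_head, H, W = 2, 4, 64, 8, 8
    dim_out = num_heads * dim_head
    scale = dim_head ** -0.5

    # per-head q, k, v with spatial positions flattened: B, num_heads, H * W, dim_head
    q, k, v = (torch.randn(B, num_heads, H * W, dim_head) for _ in range(3))

    # stand-in for self.pos_embed(q): relative position logits with the same shape as attn
    rel_pos = torch.randn(B, num_heads, H * W, H * W)

    attn = (q @ k.transpose(-1, -2)) * scale   # B, num_heads, H * W, H * W
    attn = attn + rel_pos                      # add relative position logits
    attn = attn.softmax(dim=-1)                # normalize over key positions

    out = (attn @ v).transpose(-1, -2).reshape(B, dim_out, H, W)  # back to B, dim_out, H, W
    print(out.shape)  # torch.Size([2, 256, 8, 8])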
@@ -106,22 +106,23 @@ class HaloAttn(nn.Module):
         assert dim_out % num_heads == 0
         self.stride = stride
         self.num_heads = num_heads
-        self.dim_head = dim_head or dim // num_heads
-        self.dim_qk = num_heads * self.dim_head
-        self.dim_v = dim_out
+        self.dim_head_qk = dim_head or dim_out // num_heads
+        self.dim_head_v = dim_out // self.num_heads
+        self.dim_out_qk = num_heads * self.dim_head_qk
+        self.dim_out_v = num_heads * self.dim_head_v
         self.block_size = block_size
         self.halo_size = halo_size
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
-        self.scale = self.dim_head ** -0.5
+        self.scale = self.dim_head_qk ** -0.5
 
         # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
         # data in unfolded block form. I haven't wrapped my head around how that'd look.
-        self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
-        self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
+        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias)
+        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
 
         self.pos_embed = PosEmbedRel(
-            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
+            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
 
         self.reset_parameters()
 
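The renamed attributes make the q/k and v widths independent: the q/k head width now defaults to dim_out // num_heads (previously dim // num_heads), and the value path is sized purely from dim_out. A small worked sketch of the derived sizes, with hyperparameters (dim=256, dim_out=512, num_heads=8, block_size=8, halo_size=2) chosen only for illustration:

    dim, dim_out, num_heads, dim_head = 256, 512, 8, None
    block_size, halo_size, stride = 8, 2, 1

    dim_head_qk = dim_head or dim_out // num_heads   # 64: per-head width shared by q and k
    dim_head_v = dim_out // num_heads                # 64: per-head width of v
    dim_out_qk = num_heads * dim_head_qk             # 512: output channels of the q projection (k uses the same)
    dim_out_v = num_heads * dim_head_v               # 512: output channels of the v projection
    win_size = block_size + halo_size * 2            # 12: each 8x8 block attends over a 12x12 window
    scale = dim_head_qk ** -0.5                      # attention scale comes from the q/k head width

    # projection sizes matching the updated __init__:
    #   q:  1x1 conv, dim -> dim_out_qk, stride=stride
    #   kv: 1x1 conv, dim -> dim_out_qk + dim_out_v
    print(dim_head_qk, dim_head_v, dim_out_qk, dim_out_v, win_size)  # 64 64 512 512 12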
@@ -143,37 +144,42 @@ class HaloAttn(nn.Module):
 
         q = self.q(x)
         # unfold
-        q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
+        q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
         # B, num_heads * dim_head * block_size ** 2, num_blocks
-        q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
+        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
         # B * num_heads, num_blocks, block_size ** 2, dim_head
 
         kv = self.kv(x)
         # generate overlapping windows for kv
         kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
         kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
-            B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
-        # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
-        # if self.stride_tricks:
-        #     kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
-        #     kv = kv.as_strided((
-        #         B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
-        #         stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
-        # else:
-        #     kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
-        #     kv = kv.reshape(
-        #        B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
-        k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
-        # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
+            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
+        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
+        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
 
-        attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
-        attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, block_size ** 2, win_size ** 2
-        attn_out = attn_logits.softmax(dim=-1)
-        attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
+        attn = (q @ k.transpose(-1, -2)) * self.scale
+        attn = attn + self.pos_embed(q)  # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
+        attn = attn.softmax(dim=-1)
 
+        out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
         # fold
-        attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
-        attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
+        out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
+        out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride)
         # B, dim_out, H // stride, W // stride
-        return attn_out
+        return out
 
+
+""" Two alternatives for overlapping windows.
+
+`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
+
+if self.stride_tricks:
+    kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
+    kv = kv.as_strided((
+        B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
+        stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
+else:
+    kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+    kv = kv.reshape(
+        B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
+"""
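The docstring added at the end claims the `.unfold().unfold()` path and the F.unfold path produce the same overlapping kv windows. A quick self-contained check of that equivalence on a single toy tensor (sizes are arbitrary, and the reshape layout is simplified relative to the module's per-head version):

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 16, 8, 8
    block_size, halo_size = 4, 1
    win_size = block_size + halo_size * 2
    num_blocks = (H // block_size) * (W // block_size)

    kv = torch.randn(B, C, H, W)

    # alternative 1: pad, then double unfold over H and W with stride = block_size
    a = F.pad(kv, [halo_size] * 4)
    a = a.unfold(2, win_size, block_size).unfold(3, win_size, block_size)
    a = a.reshape(B, C, num_blocks, win_size * win_size)

    # alternative 2: F.unfold (im2col) with the same window, stride and padding
    b = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
    b = b.reshape(B, C, win_size * win_size, num_blocks).transpose(-1, -2)

    print(torch.allclose(a, b))  # True: both yield one win_size**2 window per block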