Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix rotary embed version of attn pool. Bit of cleanup/naming

commit 57adc1acc8 (parent cdc7bcea69)
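Note: the two attention pool classes previously spelled the prefix-token option differently (avg_token in RotAttentionPool2d, learned_token in AttentionPool2d); both now take class_token: bool = False, where False prepends the mean of the spatial tokens and True prepends a learned cls_token embedding. A minimal standalone sketch of the shared logic (the helper name is ours, not the commit's):

from typing import Optional

import torch


def prepend_prefix_token(x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.Tensor:
    # x: (B, N, C) flattened spatial tokens; cls_token: (1, C) learned embedding or None.
    if cls_token is None:
        # class_token=False (the default): prefix with the mean of the spatial tokens.
        return torch.cat([x.mean(1, keepdim=True), x], dim=1)
    # class_token=True: broadcast the learned embedding to (B, 1, C) and prepend it.
    return torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)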
@@ -42,7 +42,7 @@ class RotAttentionPool2d(nn.Module):
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            avg_token: bool = True,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -63,6 +63,11 @@ class RotAttentionPool2d(nn.Module):
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
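A hedged usage sketch of the new flag; the first two arguments appear in the diff context, but the rest of the constructor signature is assumed from the class and may differ:

import torch
from timm.layers import RotAttentionPool2d

pool = RotAttentionPool2d(2048, embed_dim=512, class_token=True)  # learned class token
feats = torch.randn(2, 2048, 7, 7)  # NCHW feature map from a CNN stage
out = pool(feats)                   # pooled (2, 512) embedding with pool_type='token'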
@@ -109,7 +114,10 @@ class RotAttentionPool2d(nn.Module):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         if self.qkv is None:
             q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
             k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -130,7 +138,6 @@ class RotAttentionPool2d(nn.Module):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
             x = self._pool(x, H, W)
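The deleted x = x[:, 0] hard-coded taking the prefix token; selection is now deferred to self._pool(x, H, W), which can honor pool_type. A plausible shape of that dispatch (a sketch with our own function name, not the commit's code):

import torch


def pool_sequence(x: torch.Tensor, pool_type: str) -> torch.Tensor:
    # x: (B, 1 + N, C) sequence with the prefix token at index 0.
    if pool_type == 'token':
        return x[:, 0]           # take the prefix (mean or class) token
    if pool_type == 'avg':
        return x[:, 1:].mean(1)  # average the N spatial tokens instead
    return x                     # no pooling: keep the full sequence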
@@ -162,7 +169,7 @@ class AttentionPool2d(nn.Module):
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            learned_token: bool = False,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -184,10 +191,10 @@ class AttentionPool2d(nn.Module):
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        if learned_token:
-            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
         else:
-            self.token = None
+            self.cls_token = None
 
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -239,10 +246,10 @@ class AttentionPool2d(nn.Module):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        if self.token is not None:
-            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
-        else:
+        if self.cls_token is None:
             x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
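Note the interplay with the prefix token here: resample_abs_pos_embed(..., num_prefix_tokens=1) interpolates only the H*W grid portion of the position embedding and carries the prefix-token embedding through unchanged, so the same pos_embed works for either token flavor. A small sketch with assumed shapes:

import torch
from timm.layers import resample_abs_pos_embed

# Position embedding trained for a 7x7 grid plus one prefix token.
pos_embed = torch.randn(7 * 7 + 1, 256)

# Resample for a 12x12 feature map; the grid part is interpolated, the
# single prefix-token embedding is passed through as-is.
resampled = resample_abs_pos_embed(pos_embed.unsqueeze(0), (12, 12), num_prefix_tokens=1)
print(resampled.shape)  # torch.Size([1, 145, 256]) -> 12 * 12 + 1 tokens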
@@ -1945,7 +1945,6 @@ model_cfgs = dict(
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=512,
     ),
 
     resnet50x4_clip=ByoModelCfg(
@@ -1962,7 +1961,6 @@ model_cfgs = dict(
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=640,
     ),
 
     resnet50x16_clip=ByoModelCfg(
@@ -1979,7 +1977,6 @@ model_cfgs = dict(
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=768,
     ),
 
     resnet50x64_clip=ByoModelCfg(
@@ -1996,7 +1993,6 @@ model_cfgs = dict(
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=1024,
     ),
 
     resnet50_mlp=ByoModelCfg(
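The #head_hidden_size=... lines were dead commented-out config, so their removal is pure cleanup. For reference, a hedged example of instantiating one of these configs through the registry (assuming the names register as model entrypoints; pretrained weight availability is not implied):

import timm
import torch

model = timm.create_model('resnet50x4_clip', pretrained=False, num_classes=0)
x = torch.randn(1, 3, 288, 288)  # 288px is the native CLIP RN50x4 input size
feats = model(x)                 # pooled feature embedding
print(feats.shape)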
|