Fix rotary embed version of attn pool. Bit of cleanup/naming

This commit is contained in:
Ross Wightman 2024-06-11 23:49:17 -07:00
parent cdc7bcea69
commit 57adc1acc8
2 changed files with 17 additions and 14 deletions

View File

@ -42,7 +42,7 @@ class RotAttentionPool2d(nn.Module):
qkv_bias: bool = True, qkv_bias: bool = True,
qkv_separate: bool = False, qkv_separate: bool = False,
pool_type: str = 'token', pool_type: str = 'token',
avg_token: bool = True, class_token: bool = False,
drop_rate: float = 0., drop_rate: float = 0.,
): ):
super().__init__() super().__init__()
@ -63,6 +63,11 @@ class RotAttentionPool2d(nn.Module):
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn() self.fused_attn = use_fused_attn()
if class_token:
self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
else:
self.cls_token = None
if qkv_separate: if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@ -109,7 +114,10 @@ class RotAttentionPool2d(nn.Module):
B, _, H, W = x.shape B, _, H, W = x.shape
N = H * W N = H * W
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1) if self.cls_token is None:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
else:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
if self.qkv is None: if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@ -130,7 +138,6 @@ class RotAttentionPool2d(nn.Module):
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = attn @ v x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1) x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = x[:, 0]
x = self.drop(x) x = self.drop(x)
if pre_logits: if pre_logits:
x = self._pool(x, H, W) x = self._pool(x, H, W)
@ -162,7 +169,7 @@ class AttentionPool2d(nn.Module):
qkv_bias: bool = True, qkv_bias: bool = True,
qkv_separate: bool = False, qkv_separate: bool = False,
pool_type: str = 'token', pool_type: str = 'token',
learned_token: bool = False, class_token: bool = False,
drop_rate: float = 0., drop_rate: float = 0.,
): ):
super().__init__() super().__init__()
@ -184,10 +191,10 @@ class AttentionPool2d(nn.Module):
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn() self.fused_attn = use_fused_attn()
if learned_token: if class_token:
self.token = nn.Parameter(torch.zeros(1, embed_dim)) self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
else: else:
self.token = None self.cls_token = None
if qkv_separate: if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@ -239,10 +246,10 @@ class AttentionPool2d(nn.Module):
B, _, H, W = x.shape B, _, H, W = x.shape
N = H * W N = H * W
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
if self.token is not None: if self.cls_token is None:
x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
else:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1) x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
else:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed x = x + pos_embed

View File

@ -1945,7 +1945,6 @@ model_cfgs = dict(
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_type='attn_abs', head_type='attn_abs',
#head_hidden_size=512,
), ),
resnet50x4_clip=ByoModelCfg( resnet50x4_clip=ByoModelCfg(
@ -1962,7 +1961,6 @@ model_cfgs = dict(
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_type='attn_abs', head_type='attn_abs',
#head_hidden_size=640,
), ),
resnet50x16_clip=ByoModelCfg( resnet50x16_clip=ByoModelCfg(
@ -1979,7 +1977,6 @@ model_cfgs = dict(
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_type='attn_abs', head_type='attn_abs',
#head_hidden_size=768,
), ),
resnet50x64_clip=ByoModelCfg( resnet50x64_clip=ByoModelCfg(
@ -1996,7 +1993,6 @@ model_cfgs = dict(
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_type='attn_abs', head_type='attn_abs',
#head_hidden_size=1024,
), ),
resnet50_mlp=ByoModelCfg( resnet50_mlp=ByoModelCfg(