mirror of https://github.com/huggingface/pytorch-image-models.git
Remove key_padding masking, sequence isolation is enough.
parent f93083e2b2
commit 2734bb76ce
@@ -230,7 +230,8 @@ class Attention(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            attn += attn_mask
+            if attn_mask is not None:
+                attn += attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
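The guarded attn += attn_mask above assumes an additive mask: zeros where attention is allowed, a large negative value where it is blocked, matching the bool-to-float conversion later in this diff. A minimal standalone sketch of that convention (the function name and shapes are illustrative, not from the repo):

import torch

def naive_masked_attention(q, k, v, attn_mask=None, scale=None):
    # q, k, v: (B, num_heads, N, head_dim); attn_mask broadcastable to (B, 1, N, N)
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)   # raw attention logits
    if attn_mask is not None:                  # mirrors the new None check above
        attn = attn + attn_mask                # 0 = keep, finfo(dtype).min = masked
    attn = attn.softmax(dim=-1)
    return attn @ v

# With no mask the call still works, which is the case the new None check (and the
# default attn_mask=None further down) enables.
q = k = v = torch.randn(2, 4, 8, 16)
out = naive_masked_attention(q, k, v)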
@@ -292,7 +293,7 @@ class Block(nn.Module):
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x, attn_mask: Optional[torch.Tensor]):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
@@ -720,8 +721,11 @@ class VisionTransformerPacked(nn.Module):
 
         if attn_mask is None:
             attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
-            key_padding_mask = (seq_ids != 0).unsqueeze(1)
-            attn_mask = attn_mask & key_padding_mask
+            # NOTE: not applying key padding mask as padding tokens are already isolated to
+            # themselves via the above mask (padding has seq_id == 0). Doing an additional
+            # key padding mask results in fully masked rows which causes numerical issues.
+            # key_padding_mask = (seq_ids != 0).unsqueeze(1)
+            # attn_mask = attn_mask & key_padding_mask
             attn_mask = attn_mask.unsqueeze(1)
 
         if attn_mask.dtype == torch.bool:
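The NOTE lines added here carry the rationale for the whole commit: in the packed layout padding tokens have seq_id == 0, so the pairwise seq_id comparison already confines them to attending only to other padding slots, while additionally ANDing in a key padding mask empties their rows entirely. A small standalone illustration of that difference (the tensor values are made up for the example):

import torch

# One packed row: two real sequences (seq_id 1 and 2) followed by two padding slots (seq_id 0).
seq_ids = torch.tensor([[1, 1, 2, 2, 0, 0]])

# Sequence-isolation mask: token i may attend to token j only when their seq_ids match.
# Padding rows keep True entries at the other padding slots, so no row is left empty.
attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)   # (1, 6, 6) bool
print(attn_mask[0, 4])        # tensor([False, False, False, False,  True,  True])

# ANDing in the removed key padding mask empties every padding row, which is the
# "fully masked rows" condition the NOTE flags as numerically problematic.
key_padding_mask = (seq_ids != 0).unsqueeze(1)             # (1, 1, 6)
combined = attn_mask & key_padding_mask
print(combined[0, 4])         # tensor([False, False, False, False, False, False])
print(combined[0, 4].any())   # tensor(False): this row has nowhere left to attend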
@@ -729,11 +733,12 @@ class VisionTransformerPacked(nn.Module):
             min_val = torch.finfo(dtype).min
             attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)
 
-        # if self.grad_checkpointing and not torch.jit.is_scripting():
-        #     tokens = checkpoint_seq(self.blocks, tokens)
-        # else:
         for b in self.blocks:
-            tokens = b(tokens, attn_mask=attn_mask)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                tokens = torch.utils.checkpoint.checkpoint(
+                    b, tokens, use_reentrant=False, attn_mask=attn_mask)
+            else:
+                tokens = b(tokens, attn_mask=attn_mask)
         tokens = self.norm(tokens)
 
         device = tokens.device
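The checkpointing rewrite replaces the commented-out checkpoint_seq call, which had no way to thread the mask through, with per-block checkpointing. Passing attn_mask as a keyword argument relies on the non-reentrant path (use_reentrant=False), which forwards keyword arguments to the wrapped module. A reduced sketch of the same pattern with a toy block (ToyBlock and use_checkpointing are stand-ins, not repo code):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Stand-in for timm's Block: it only needs to accept (tokens, attn_mask=...).
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attn_mask=None):
        return self.proj(x) if attn_mask is None else self.proj(x) + attn_mask.mean()

blocks = nn.ModuleList([ToyBlock(32) for _ in range(2)])
tokens = torch.randn(2, 8, 32, requires_grad=True)
attn_mask = torch.zeros(2, 1, 8, 8)
use_checkpointing = True  # stands in for self.grad_checkpointing

for b in blocks:
    if use_checkpointing and not torch.jit.is_scripting():
        # use_reentrant=False is what lets the attn_mask keyword reach b.forward
        tokens = torch.utils.checkpoint.checkpoint(
            b, tokens, use_reentrant=False, attn_mask=attn_mask)
    else:
        tokens = b(tokens, attn_mask=attn_mask)

tokens.sum().backward()  # each block's forward is recomputed during backward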
@@ -743,7 +748,7 @@ class VisionTransformerPacked(nn.Module):
         seq_lens = seq_lens.reshape(-1)
         valid_rows = seq_lens > 0
         if self.attn_pool is not None:
-            unpack_mask = unpack_mask & key_padding_mask
+            # unpack_mask = unpack_mask & key_padding_mask
             unpack_mask = unpack_mask.unsqueeze(1)
             unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
                 ~unpack_mask, torch.finfo(tokens.dtype).min)
@@ -767,6 +772,7 @@ class VisionTransformerPacked(nn.Module):
         if isinstance(x, (list, tuple)):
             x = torch.stack([t.mean(dim=0) for t in x], 0)
         else:
+            # x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
             x = x.mean(dim=1)
         x = self.fc_norm(x)
         x = self.head_drop(x)
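The commented-out sum / seq_lens line is the length-normalized alternative to the plain mean used here: after unpacking, padded positions are zeros, so dividing the sum by the true sequence lengths averages over real tokens only, while mean(dim=1) also counts the padded slots. A small sketch of the difference (the tensor values are made up for the example):

import torch

# Unpacked features: batch of 2, padded to length 4, feature dim 3.
# Row 0 has 2 real tokens, row 1 has 4; padded slots are zero.
x = torch.zeros(2, 4, 3)
x[0, :2] = 1.0
x[1, :4] = 1.0
seq_lens = torch.tensor([2, 4])

mean_all = x.mean(dim=1)                             # divides by the padded length (4)
mean_valid = x.sum(dim=1) / seq_lens.reshape(-1, 1)  # divides by the true length

print(mean_all[:, 0])    # tensor([0.5000, 1.0000])
print(mean_valid[:, 0])  # tensor([1., 1.])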
@@ -801,6 +807,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'navit_medium_patch16_384': _cfg(),
     'navit_base_patch32_224': _cfg(),
     'navit_base_patch32_384': _cfg(),
     'navit_base_patch16_224': _cfg(),
@@ -821,6 +828,16 @@ def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
     )
 
 
+@register_model
+def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
+    model_args = dict(
+        img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        fc_norm=False, init_values=1e-5, qkv_bias=False)
+    model = _create_vision_transformer_packed(
+        'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
     model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
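Once registered, the new variant is reachable through the usual timm factory. A minimal usage sketch, assuming this development branch of timm is installed; the new entry's _cfg() above has no weight URL, so pretrained stays False:

import timm

model = timm.create_model('navit_medium_patch16_384', pretrained=False)
print(sum(p.numel() for p in model.parameters()))  # rough size check of the medium config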