Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit 2734bb76ce (parent f93083e2b2)

    Remove key_padding masking, sequence isolation is enough.
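The idea behind the change: the packed (NaViT-style) model tags every token with a sequence id, and padding tokens carry seq_id == 0. The pairwise-equality mask built in the forward pass already confines each token, padding included, to tokens of its own sequence, so the extra key padding mask was redundant, and intersecting the two produced attention rows with no valid keys at all. A minimal sketch of the isolation mask (the seq_ids values here are made up for illustration):

import torch

# One packed row: two real sequences (ids 1 and 2) followed by padding (id 0).
seq_ids = torch.tensor([[1, 1, 2, 2, 0, 0]])

# Pairwise equality: token i may attend to token j iff they share a sequence id.
attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)  # (B, N, N) bool
print(attn_mask[0].int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1]])
# Padding attends only to other padding, so no attention row is ever fully masked.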
@@ -230,7 +230,8 @@ class Attention(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            attn += attn_mask
+            if attn_mask is not None:
+                attn += attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
@@ -292,7 +293,7 @@ class Block(nn.Module):
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x, attn_mask: Optional[torch.Tensor]):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
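With the new default of None on Block.forward (and the matching guard in Attention above), the same blocks serve both the packed path and plain unmasked attention. A minimal pure-torch sketch of the guarded-additive-mask pattern; the function name and scaling are illustrative, not the timm code:

from typing import Optional

import torch

def masked_attention(q, k, v, attn_mask: Optional[torch.Tensor] = None):
    # Scaled dot-product attention with an optional additive float mask,
    # mirroring the patched non-fused branch above.
    attn = (q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1)
    if attn_mask is not None:
        attn = attn + attn_mask  # 0 where allowed, finfo.min where masked
    return attn.softmax(dim=-1) @ v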
@@ -720,8 +721,11 @@ class VisionTransformerPacked(nn.Module):
 
         if attn_mask is None:
             attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
-            key_padding_mask = (seq_ids != 0).unsqueeze(1)
-            attn_mask = attn_mask & key_padding_mask
+            # NOTE: not applying key padding mask as padding tokens are already isolated to
+            # themselves via the above mask (padding has seq_id == 0). Doing an additional
+            # key padding mask results in fully masked rows which causes numerical issues.
+            # key_padding_mask = (seq_ids != 0).unsqueeze(1)
+            # attn_mask = attn_mask & key_padding_mask
             attn_mask = attn_mask.unsqueeze(1)
 
         if attn_mask.dtype == torch.bool:
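The numerical issue the new NOTE describes is easy to reproduce: a row with no valid keys has no finite logit left after masking, and softmax over it yields NaNs that then propagate through the network. Shown here with an -inf additive mask, the canonical failure; with the finfo.min fill used below, a fully masked row instead degrades to a uniform distribution over keys that were all supposed to be masked, which is wrong either way:

import torch

# A row where every key is disallowed.
row = torch.full((1, 4), float('-inf'))
print(row.softmax(dim=-1))  # tensor([[nan, nan, nan, nan]])

Keeping padding tokens attending to each other (seq_id == 0 matches itself) guarantees such rows never arise.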
@@ -729,11 +733,12 @@ class VisionTransformerPacked(nn.Module):
             min_val = torch.finfo(dtype).min
             attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)
 
-        # if self.grad_checkpointing and not torch.jit.is_scripting():
-        #     tokens = checkpoint_seq(self.blocks, tokens)
-        # else:
         for b in self.blocks:
-            tokens = b(tokens, attn_mask=attn_mask)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                tokens = torch.utils.checkpoint.checkpoint(
+                    b, tokens, use_reentrant=False, attn_mask=attn_mask)
+            else:
+                tokens = b(tokens, attn_mask=attn_mask)
         tokens = self.norm(tokens)
 
         device = tokens.device
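This hunk also revives gradient checkpointing: the commented-out checkpoint_seq call was not passing the mask through, so the loop now checkpoints each block individually. use_reentrant=False matters here because only the non-reentrant implementation of torch.utils.checkpoint.checkpoint accepts keyword arguments for the wrapped callable. A self-contained sketch with a toy module standing in for a transformer block:

import torch
from torch.utils.checkpoint import checkpoint

class ToyBlock(torch.nn.Module):
    # Stand-in for a transformer block taking an optional mask kwarg.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x, attn_mask=None):
        out = self.proj(x)
        return out if attn_mask is None else out + attn_mask

block = ToyBlock()
x = torch.randn(2, 16, requires_grad=True)
mask = torch.zeros(2, 16)

# Keyword args (attn_mask=...) are forwarded to the block; this requires
# use_reentrant=False, as the reentrant variant rejects extra kwargs.
y = checkpoint(block, x, use_reentrant=False, attn_mask=mask)
y.sum().backward()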
@@ -743,7 +748,7 @@ class VisionTransformerPacked(nn.Module):
         seq_lens = seq_lens.reshape(-1)
         valid_rows = seq_lens > 0
         if self.attn_pool is not None:
-            unpack_mask = unpack_mask & key_padding_mask
+            # unpack_mask = unpack_mask & key_padding_mask
             unpack_mask = unpack_mask.unsqueeze(1)
             unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
                 ~unpack_mask, torch.finfo(tokens.dtype).min)
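This line had to change along with the earlier hunk: key_padding_mask is no longer defined once the @@ -720,8 block comments it out, so leaving the intersection here in place would now raise a NameError.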
@@ -767,6 +772,7 @@ class VisionTransformerPacked(nn.Module):
         if isinstance(x, (list, tuple)):
             x = torch.stack([t.mean(dim=0) for t in x], 0)
         else:
+            # x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
             x = x.mean(dim=1)
         x = self.fc_norm(x)
         x = self.head_drop(x)
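The commented line added in this hunk records a length-aware alternative to plain mean pooling: x.mean(dim=1) averages over every token slot, padding included, whereas dividing the per-row sum by seq_lens averages only over real tokens. A hedged sketch of the difference, under the assumption that padded slots should not contribute (names mirror the surrounding code but the shapes are made up):

import torch

x = torch.randn(2, 6, 8)                      # (rows, token slots, channels)
seq_ids = torch.tensor([[1, 1, 2, 0, 0, 0],
                        [1, 1, 1, 1, 2, 0]])
valid = (seq_ids != 0).unsqueeze(-1).float()  # 1.0 for real tokens, 0.0 for padding
seq_lens = valid.sum(dim=1)                   # real-token count per row

pooled_all = x.mean(dim=1)                        # padding dilutes the average
pooled_valid = (x * valid).sum(dim=1) / seq_lens  # average over real tokens only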
@@ -801,6 +807,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'navit_medium_patch16_384': _cfg(),
     'navit_base_patch32_224': _cfg(),
     'navit_base_patch32_384': _cfg(),
     'navit_base_patch16_224': _cfg(),
@@ -821,6 +828,16 @@ def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
     )
 
 
+@register_model
+def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
+    model_args = dict(
+        img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        fc_norm=False, init_values=1e-5, qkv_bias=False)
+    model = _create_vision_transformer_packed(
+        'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
     model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
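With the registration above in place (this lives on a development branch, so the sketch assumes that branch of timm is installed), the new variant can be constructed by name like any other timm model; its _cfg() lists no pretrained weights yet:

import timm

model = timm.create_model('navit_medium_patch16_384', pretrained=False)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count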