Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Bit more Hiera fiddling
This commit is contained in:
parent 8a54d2a930
commit d88bed6535
@@ -47,7 +47,7 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
 
 
-def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
+def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
@@ -59,23 +59,6 @@ def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
     return mask
 
 
-def do_masked_conv(
-        x: torch.Tensor,
-        conv: nn.Module,
-        mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """Zero-out the masked regions of the input before conv.
-    Prevents leakage of masked regions when using overlapping kernels.
-    """
-    if conv is None:
-        return x
-    if mask is None:
-        return conv(x)
-
-    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
-    return conv(x * mask.bool())
-
-
 def undo_windowing(
     x: torch.Tensor,
     shape: List[int],
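Note: the helper removed here resized the coarse keep-mask to the conv input's resolution and zeroed masked regions before the convolution; the same logic is inlined into PatchEmbed.forward further down. A minimal standalone sketch of that trick, with hypothetical names, toy sizes, and an assumed nearest-neighbor resize of the mask:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

def resize_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
    # mask: [B, 1, h, w] keep-mask at mask-unit resolution (True/1 = keep)
    if list(mask.shape[2:]) == list(target_size):
        return mask
    # nearest-neighbor interpolation keeps the mask binary at the new resolution
    return F.interpolate(mask.float(), size=target_size)

def masked_conv(x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if mask is None:
        return conv(x)
    mask = resize_mask(list(x.shape[2:]), mask)
    # zero out masked pixels so overlapping kernels cannot see their content
    return conv(x * mask.bool())

x = torch.randn(2, 3, 224, 224)
keep = torch.rand(2, 1, 7, 7) > 0.6            # keep ~40% of the 32x32-pixel mask units
proj = nn.Conv2d(3, 96, kernel_size=7, stride=4, padding=3)
print(masked_conv(x, proj, keep).shape)        # torch.Size([2, 96, 56, 56])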
@@ -145,7 +128,6 @@ class Unroll(nn.Module):
         Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
         """
         B, _, C = x.shape
-
         cur_size = self.size
         x = x.view(*([B] + cur_size + [C]))
 
@@ -332,6 +314,7 @@ class HieraBlock(nn.Module):
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
+            use_expand_proj: bool = True,
             use_mask_unit_attn: bool = False,
     ):
         super().__init__()
@@ -341,8 +324,14 @@ class HieraBlock(nn.Module):
 
         self.norm1 = norm_layer(dim)
         if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
+            self.do_expand = True
+            if use_expand_proj:
+                self.proj = nn.Linear(dim, dim_out)
+            else:
+                assert dim_out == dim * 2
+                self.proj = None
         else:
+            self.do_expand = False
             self.proj = None
         self.attn = MaskUnitAttention(
             dim,
@@ -362,9 +351,17 @@ class HieraBlock(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
        x_norm = self.norm1(x)
-        if self.proj is not None:
-            x = self.proj(x_norm)
-            x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+        if self.do_expand:
+            if self.proj is not None:
+                x = self.proj(x_norm)
+                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+            else:
+                x = torch.cat([
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
+                    ],
+                    dim=-1,
+                )
         x = x + self.drop_path1(self.attn(x_norm))
 
         # MLP
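As a shape-level illustration of the new option: with use_expand_proj=True the block expands dim -> dim_out with a learned Linear and max-pools over the q_stride window; with use_expand_proj=False it gets the channel doubling for free by concatenating max- and average-pooled features (hence the assert dim_out == dim * 2). A standalone sketch with toy sizes, not the HieraBlock code itself:

import torch
import torch.nn as nn

B, N, dim, q_stride = 2, 64, 96, 4
dim_out = dim * 2
x = torch.randn(B, N, dim)

# learned expansion (use_expand_proj=True): project, then max-pool over the q_stride window
proj = nn.Linear(dim, dim_out)
x1 = proj(x).view(B, q_stride, -1, dim_out).amax(dim=1)

# parameter-free expansion (use_expand_proj=False): concat max- and avg-pooled features
x2 = torch.cat([
    x.view(B, q_stride, -1, dim).amax(dim=1),
    x.view(B, q_stride, -1, dim).mean(dim=1),
], dim=-1)

print(x1.shape, x2.shape)  # both torch.Size([2, 16, 192])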
@@ -419,7 +416,11 @@ class PatchEmbed(nn.Module):
             x: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = do_masked_conv(x, self.proj, mask)
+        if mask is not None:
+            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
+            x = self.proj(x * mask.to(torch.bool))
+        else:
+            x = self.proj(x)
         if self.reshape:
             x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
         return x
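The docstring of the removed helper claimed it "prevents leakage of masked regions when using overlapping kernels"; the inlined version above preserves that property, since masked pixels are zeroed before the conv and so cannot influence any output position. A quick hedged sanity check with toy sizes (an illustration, not timm code):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
proj = nn.Conv2d(3, 8, kernel_size=7, stride=4, padding=3)   # overlapping kernel (stride < kernel size)

x = torch.randn(1, 3, 64, 64)
keep = torch.rand(1, 1, 4, 4) > 0.5                          # keep-mask over 16x16-pixel mask units
keep_px = F.interpolate(keep.float(), size=x.shape[2:]).bool()

# change the image only inside masked-out units
x_altered = torch.where(keep_px, x, torch.randn_like(x))

out_a = proj(x * keep_px)
out_b = proj(x_altered * keep_px)
print(torch.allclose(out_a, out_b))  # True: masked content cannot leak into the embedding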
@@ -570,10 +571,10 @@ class Hiera(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        if self.sep_pos_embed:
-            return ["pos_embed_spatial", "pos_embed_temporal"]
-        else:
+        if self.pos_embed is not None:
             return ["pos_embed"]
+        else:
+            return ["pos_embed_spatial", "pos_embed_temporal"]
 
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
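no_weight_decay() now keys off whether the single pos_embed exists rather than a sep_pos_embed flag. For context, a generic sketch of how such a name list is typically consumed when building optimizer parameter groups; this is an illustration under assumed conventions, not timm's optimizer factory:

import torch

def param_groups(model: torch.nn.Module, weight_decay: float = 0.05):
    skip = set(model.no_weight_decay()) if hasattr(model, 'no_weight_decay') else set()
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1-d params (norms, biases) and explicitly listed names are excluded from weight decay
        if p.ndim <= 1 or name.endswith('.bias') or any(name.startswith(s) for s in skip):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.},
    ]

# optimizer = torch.optim.AdamW(param_groups(model), lr=1e-3)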