Bit more Hiera fiddling

Ross Wightman 2024-04-21 09:36:57 -07:00
parent 8a54d2a930
commit d88bed6535


@@ -47,7 +47,7 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
-def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
+def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
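Note: the annotation swap above (torch.Size -> List[int]) likely keeps the helper TorchScript-friendly, since List[int] is what a tensor's shape[2:] scripts to. A minimal sketch of what such a resize helper does, assuming nearest-neighbour interpolation as in the upstream Hiera code; the name and body here are illustrative, not this file's exact implementation:

import torch
import torch.nn.functional as F
from typing import List, Optional


def resize_mask_sketch(target_size: List[int], mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Resize a [B, C, (t), (h), w] mask to the requested spatial size.
    if mask is None:
        return mask
    if list(mask.shape[2:]) != list(target_size):
        return F.interpolate(mask.float(), size=target_size)
    return mask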
@@ -59,23 +59,6 @@ def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
     return mask
-def do_masked_conv(
-        x: torch.Tensor,
-        conv: nn.Module,
-        mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """Zero-out the masked regions of the input before conv.
-    Prevents leakage of masked regions when using overlapping kernels.
-    """
-    if conv is None:
-        return x
-    if mask is None:
-        return conv(x)
-    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
-    return conv(x * mask.bool())
 def undo_windowing(
         x: torch.Tensor,
         shape: List[int],
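Note: do_masked_conv() is removed here and its logic inlined into PatchEmbed.forward() below. Its docstring captures why it existed: with overlapping kernels (stride < kernel_size), an output whose centre sits in an unmasked region still has masked pixels in its receptive field, so the input is zeroed first. A toy illustration of that leakage; the layer and sizes are made up:

import torch
import torch.nn as nn

conv = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(1, 1, 8)
keep = torch.ones(1, 1, 8)
keep[..., 4:] = 0                 # pretend the right half is masked out
leaky = conv(x)                   # output at index 3 still sees the masked x[..., 4]
safe = conv(x * keep.bool())      # zeroing first removes that contribution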
@@ -145,7 +128,6 @@ class Unroll(nn.Module):
         Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
         """
         B, _, C = x.shape
         cur_size = self.size
         x = x.view(*([B] + cur_size + [C]))
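Note: the docstring's claim that a [B, 4, N//4, C].max(1) over the unrolled tokens behaves like MaxPoolNd can be sanity-checked on a plain 2D grid; the grouping below is only illustrative and not the exact interleaved permutation Unroll applies across stages:

import torch
import torch.nn.functional as F

B, H, W, C = 1, 4, 4, 8
x = torch.randn(B, H, W, C)
# Put the 4 members of each 2x2 window on their own axis, then flatten the blocks.
win = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 2, 4, 1, 3, 5)   # [B, 2, 2, H/2, W/2, C]
unrolled = win.reshape(B, 4, -1, C)                                   # [B, 4, N//4, C]
pooled = unrolled.amax(dim=1).view(B, H // 2, W // 2, C)
ref = F.max_pool2d(x.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
assert torch.allclose(pooled, ref)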
@@ -332,6 +314,7 @@ class HieraBlock(nn.Module):
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
+            use_expand_proj: bool = True,
             use_mask_unit_attn: bool = False,
     ):
         super().__init__()
@@ -341,8 +324,14 @@ class HieraBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
+            self.do_expand = True
+            if use_expand_proj:
+                self.proj = nn.Linear(dim, dim_out)
+            else:
+                assert dim_out == dim * 2
+                self.proj = None
         else:
+            self.do_expand = False
             self.proj = None
         self.attn = MaskUnitAttention(
             dim,
@@ -362,9 +351,17 @@
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
-        if self.proj is not None:
-            x = self.proj(x_norm)
-            x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+        if self.do_expand:
+            if self.proj is not None:
+                x = self.proj(x_norm)
+                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+            else:
+                x = torch.cat([
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
+                    ],
+                    dim=-1,
+                )
         x = x + self.drop_path1(self.attn(x_norm))
         # MLP
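Note: the new else branch is what use_expand_proj=False enables: instead of a Linear expansion, channels are doubled by concatenating max- and average-pooled tokens, which is why __init__ asserts dim_out == dim * 2. A quick shape check with illustrative sizes:

import torch

B, N, C, q_stride = 2, 64, 96, 4
x = torch.randn(B, N, C)
expanded = torch.cat([
    x.view(B, q_stride, -1, C).amax(dim=1),   # max-pool over each q_stride group
    x.view(B, q_stride, -1, C).mean(dim=1),   # avg-pool over the same group
], dim=-1)
print(expanded.shape)                         # torch.Size([2, 16, 192]): N // q_stride tokens, 2 * C channels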
@@ -419,7 +416,11 @@ class PatchEmbed(nn.Module):
             x: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = do_masked_conv(x, self.proj, mask)
+        if mask is not None:
+            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
+            x = self.proj(x * mask.to(torch.bool))
+        else:
+            x = self.proj(x)
         if self.reshape:
             x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
         return x
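Note: this inlines the removed do_masked_conv() (the conv-is-None branch is dropped because PatchEmbed always has a proj, and mask.bool() becomes mask.to(torch.bool)). A self-contained sketch of the masked path; the conv hyper-parameters and mask shape are illustrative, not necessarily the model's config:

import torch
import torch.nn as nn
import torch.nn.functional as F

proj = nn.Conv2d(3, 96, kernel_size=7, stride=4, padding=3)
x = torch.randn(2, 3, 224, 224)
mask = torch.rand(2, 1, 7, 7) > 0.5                    # one keep/drop flag per mask unit
mask = F.interpolate(mask.float(), size=x.shape[2:])   # what get_resized_mask does
x = proj(x * mask.to(torch.bool))                      # masked pixels contribute zeros to the conv
tokens = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)   # [B, N, C], as in the reshape branch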
@@ -570,10 +571,10 @@ class Hiera(nn.Module):
     @torch.jit.ignore
     def no_weight_decay(self):
-        if self.sep_pos_embed:
-            return ["pos_embed_spatial", "pos_embed_temporal"]
-        else:
+        if self.pos_embed is not None:
             return ["pos_embed"]
+        else:
+            return ["pos_embed_spatial", "pos_embed_temporal"]
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """