diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 40cea195..95b3cc7e 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -47,7 +47,7 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
 
 
-def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
+def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
@@ -59,23 +59,6 @@ def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tenso
     return mask
 
 
-def do_masked_conv(
-        x: torch.Tensor,
-        conv: nn.Module,
-        mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """Zero-out the masked regions of the input before conv.
-    Prevents leakage of masked regions when using overlapping kernels.
-    """
-    if conv is None:
-        return x
-    if mask is None:
-        return conv(x)
-
-    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
-    return conv(x * mask.bool())
-
-
 def undo_windowing(
         x: torch.Tensor,
         shape: List[int],
@@ -145,7 +128,6 @@ class Unroll(nn.Module):
            Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
        """
        B, _, C = x.shape
-
        cur_size = self.size
        x = x.view(*([B] + cur_size + [C]))
 
@@ -332,6 +314,7 @@ class HieraBlock(nn.Module):
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
+            use_expand_proj: bool = True,
             use_mask_unit_attn: bool = False,
     ):
         super().__init__()
@@ -341,8 +324,14 @@ class HieraBlock(nn.Module):
 
         self.norm1 = norm_layer(dim)
         if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
+            self.do_expand = True
+            if use_expand_proj:
+                self.proj = nn.Linear(dim, dim_out)
+            else:
+                assert dim_out == dim * 2
+                self.proj = None
         else:
+            self.do_expand = False
             self.proj = None
         self.attn = MaskUnitAttention(
             dim,
@@ -362,9 +351,17 @@ class HieraBlock(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
-        if self.proj is not None:
-            x = self.proj(x_norm)
-            x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+        if self.do_expand:
+            if self.proj is not None:
+                x = self.proj(x_norm)
+                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+            else:
+                x = torch.cat([
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
+                    ],
+                    dim=-1,
+                )
         x = x + self.drop_path1(self.attn(x_norm))
 
         # MLP
@@ -419,7 +416,11 @@ class PatchEmbed(nn.Module):
             x: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = do_masked_conv(x, self.proj, mask)
+        if mask is not None:
+            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
+            x = self.proj(x * mask.to(torch.bool))
+        else:
+            x = self.proj(x)
         if self.reshape:
             x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
         return x
@@ -570,10 +571,10 @@ class Hiera(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        if self.sep_pos_embed:
-            return ["pos_embed_spatial", "pos_embed_temporal"]
-        else:
+        if self.pos_embed is not None:
             return ["pos_embed"]
+        else:
+            return ["pos_embed_spatial", "pos_embed_temporal"]
 
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
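
Reviewer note (not part of the diff): with use_expand_proj=False, HieraBlock expands channels at a Q-pooling stage by concatenating a max-pool and an avg-pool over the q_stride groups instead of applying an nn.Linear, which is why the constructor asserts dim_out == dim * 2. A minimal sketch of that pooling step, using hypothetical toy sizes:

    import torch

    B, q_stride, N, dim = 2, 4, 16, 96      # hypothetical sizes
    x = torch.randn(B, q_stride * N, dim)   # tokens already unrolled so view() groups pooled windows
    pooled = torch.cat([
        x.view(B, q_stride, -1, dim).amax(dim=1),   # max-pool within each q_stride group
        x.view(B, q_stride, -1, dim).mean(dim=1),   # avg-pool within each q_stride group
    ], dim=-1)
    assert pooled.shape == (B, N, 2 * dim)          # channels double, hence dim_out == dim * 2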
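
Reviewer note (not part of the diff): PatchEmbed.forward now inlines what do_masked_conv used to do, i.e. resize the keep-mask to the input resolution and zero the masked regions before the overlapping patchify conv so masked content cannot leak into kept tokens. A standalone sketch of the same idea for the 2D case, assuming nearest-neighbour mask upsampling and hypothetical shapes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(2, 3, 224, 224)                              # [B, C, H, W] input, hypothetical sizes
    mask = torch.rand(2, 1, 7, 7) > 0.5                          # keep-mask at mask-unit resolution
    proj = nn.Conv2d(3, 96, kernel_size=7, stride=4, padding=3)  # overlapping patch-embed conv

    resized = F.interpolate(mask.float(), size=x.shape[2:])      # nearest-neighbour by default
    out = proj(x * resized.bool())                               # zero masked pixels before the conv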