diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 0de306e7..4063a93e 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -32,12 +32,13 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._features_fx import register_notrace_function
 
 
 def conv_nd(n: int) -> Type[nn.Module]:
@@ -48,13 +49,14 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
 
 
+@register_notrace_function
 def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
         return mask
 
-    assert len(mask.shape[2:]) == len(target_size)
+    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
     if mask.shape[2:] != target_size:
         return F.interpolate(mask.float(), size=target_size)
     return mask
@@ -622,6 +624,7 @@ class Hiera(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = True,
@@ -643,10 +646,31 @@ class Hiera(nn.Module):
         assert not norm, 'normalization of features not supported'
         assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
         take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
+        x = self._pos_embed(x)
+        x = self.unroll(x)
+
+        # Discard masked tokens
+        if mask is not None:
+            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
-        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
-        x, intermediates = self.forward_features(x, return_intermediates=True)
-        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
 
         if intermediates_only:
             return intermediates
 
@@ -673,18 +697,18 @@ class Hiera(nn.Module):
     def forward_features(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
+            mask: Optional[torch.Tensor] = None,
             return_intermediates: bool = False,
     ) -> torch.Tensor:
         """
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
-        x = self.patch_embed(
-            x,
-            mask=mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
-            if mask is not None else None,
-        )
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
         x = self._pos_embed(x)
         x = self.unroll(x)
 
@@ -718,19 +742,12 @@ class Hiera(nn.Module):
     def forward(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if return_intermediates:
-            x, intermediates = self.forward_features(x, mask=mask, return_intermediates=return_intermediates)
-            if mask is not None:
-                x = self.forward_head(x)
-            return x, intermediates
-        else:
-            x = self.forward_features(x, mask=mask)
-            if mask is None:
-                x = self.forward_head(x)
-            return x
+        x = self.forward_features(x, mask=mask)
+        if mask is None:
+            x = self.forward_head(x)
+        return x
 
 
 def _cfg(url='', **kwargs):