diff --git a/tests/test_models.py b/tests/test_models.py
index 21f37a76..fee2ecff 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -49,7 +49,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', 'hiera_*'
 ]
 
 # transformer models don't support many of the spatial / feature based model functionalities
@@ -57,7 +57,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 95b3cc7e..0de306e7 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -37,6 +37,7 @@ from timm.layers import DropPath, Mlp, use_fused_attn
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 
 
 def conv_nd(n: int) -> Type[nn.Module]:
@@ -517,7 +518,7 @@ class Hiera(nn.Module):
         # Transformer blocks
         cur_stage = 0
         self.blocks = nn.ModuleList()
-
+        self.feature_info = []
         for i in range(depth):
             dim_out = embed_dim
             # Mask unit or global attention.
@@ -543,8 +544,10 @@
                 window_size=flat_mu_size,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
-
             embed_dim = dim_out
+            if i in self.stage_ends:
+                self.feature_info += [
+                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
         self.norm = norm_layer(embed_dim)
@@ -616,6 +619,57 @@ class Hiera(nn.Module):
         x = x + pos_embed
         return x
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+
+        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
+        x, intermediates = self.forward_features(x, return_intermediates=True)
+        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            # norm part of head for this model, equivalent to fc_norm in other vit.
+            self.norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
+
     def forward_features(
         self,
         x: torch.Tensor,
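Usage note (not part of the diff): the sketch below shows roughly how the new Hiera interface could be exercised once these changes land. The model name 'hiera_small_224', the assumption of four stages, and the features_only line are illustrative guesses based on the registered hiera_* variants and the updated test filters, not something these hunks themselves guarantee.

# Minimal usage sketch, under the assumptions stated above.
import torch
import timm

model = timm.create_model('hiera_small_224', pretrained=False).eval()  # assumed registered hiera_* variant
x = torch.randn(1, 3, 224, 224)

# Int index: take the last two stage outputs as NCHW feature maps only.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
for f in feats:
    print(f.shape)

# Sequence index: keep both the final features and selected stage outputs
# (assumes four stages, so valid indices are 0..3).
final_feats, intermediates = model.forward_intermediates(x, indices=[0, 3])

# Keep only the blocks needed for the first two stages and drop the head norm / classifier.
kept = model.prune_intermediate_layers([0, 1], prune_head=True)

# The feature_info entries added in __init__ are what a features_only / FeatureGetterNet
# wrapper would consume, as the FEAT_INTER_FILTERS test change implies; whether
# features_only=True is wired up for hiera_* depends on builder changes outside this section.
fx = timm.create_model('hiera_small_224', features_only=True, pretrained=False)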