Add forward_intermediates API to Hiera for features_only=True support

Ross Wightman 2024-04-21 11:30:41 -07:00
parent d88bed6535
commit ef147fd2fb
2 changed files with 58 additions and 4 deletions


@@ -49,7 +49,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', 'hiera_*'
]
# transformer models don't support many of the spatial / feature based model functionalities
@@ -57,7 +57,7 @@ NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
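With 'hiera_*' added to FEAT_INTER_FILTERS, the features_only / FeatureGetterNet path should now work for Hiera backbones. A minimal usage sketch (illustration, not part of the commit; 'hiera_small_224' is assumed to be a registered hiera_* variant and shapes are indicative):

import torch
import timm

# Build a Hiera backbone as a multi-scale feature extractor.
model = timm.create_model('hiera_small_224', features_only=True, out_indices=(0, 1, 2, 3))
model.eval()

x = torch.randn(1, 3, 224, 224)
feats = model(x)  # list of NCHW feature maps, one per selected stage end
for f in feats:
    print(f.shape)  # strides 4, 8, 16, 32 relative to the 224x224 input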


@@ -37,6 +37,7 @@ from timm.layers import DropPath, Mlp, use_fused_attn
from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
def conv_nd(n: int) -> Type[nn.Module]:
@@ -517,7 +518,7 @@ class Hiera(nn.Module):
        # Transformer blocks
        cur_stage = 0
        self.blocks = nn.ModuleList()
+        self.feature_info = []
        for i in range(depth):
            dim_out = embed_dim
            # Mask unit or global attention.
@@ -543,8 +544,10 @@ class Hiera(nn.Module):
                window_size=flat_mu_size,
                use_mask_unit_attn=use_mask_unit_attn,
            )
            embed_dim = dim_out
+            if i in self.stage_ends:
+                self.feature_info += [
+                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
            self.blocks.append(block)

        self.norm = norm_layer(embed_dim)
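As a sanity check on the metadata built above, each feature_info entry uses reduction = 2**(stage + 2) and points at the last block of that stage. A small sketch of what this produces (the depths, embed_dim, and dim multiplier below are assumptions for illustration, not a specific hiera_* default):

# Illustrative reconstruction of the feature_info entries for a 4-stage config.
depths = (1, 2, 7, 2)           # assumed per-stage block counts
embed_dim, dim_mul = 96, 2.0    # assumed starting width and per-stage multiplier

stage_ends = [sum(depths[:i + 1]) - 1 for i in range(len(depths))]
dims = [int(embed_dim * dim_mul ** i) for i in range(len(depths))]

feature_info = [
    dict(num_chs=dims[s], reduction=2 ** (s + 2), module=f'blocks.{stage_ends[s]}')
    for s in range(len(depths))
]
print(feature_info)
# [{'num_chs': 96, 'reduction': 4, 'module': 'blocks.0'}, ...,
#  {'num_chs': 768, 'reduction': 32, 'module': 'blocks.11'}]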
@@ -616,6 +619,57 @@ class Hiera(nn.Module):
        x = x + pos_embed
        return x
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+
+        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
+        x, intermediates = self.forward_features(x, return_intermediates=True)
+        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
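For reference, a quick sketch of calling this API directly (the variant name is assumed; per feature_take_indices, an int selects the last n stage ends):

import torch
import timm

model = timm.create_model('hiera_small_224').eval()  # any hiera_* variant
x = torch.randn(1, 3, 224, 224)

# Only the last two stage-end features, already permuted to NCHW.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
for f in feats:
    print(f.shape)

# Or keep the final (unpooled) features alongside all intermediates.
final, feats = model.forward_intermediates(x, indices=None)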
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            # norm part of head for this model, equivalent to fc_norm in other vit.
+            self.norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
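And a sketch of pruning for a features-only deployment (again an assumed hiera_* variant name): blocks past the last requested stage are dropped and the classifier pieces become nn.Identity.

import torch
import timm

model = timm.create_model('hiera_small_224')  # assumed variant name

kept = model.prune_intermediate_layers(n=(0, 1, 2), prune_head=True)
print(kept)               # [0, 1, 2]
print(len(model.blocks))  # truncated to stage_ends[2] + 1 blocks

x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=(0, 1, 2), intermediates_only=True)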
    def forward_features(
            self,
            x: torch.Tensor,