Mirror of https://github.com/huggingface/pytorch-image-models.git
Add forward_intermediates API to Hiera for features_only=True support
parent d88bed6535
commit ef147fd2fb
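With this commit, Hiera joins the models usable through timm's features_only wrapper. A minimal usage sketch under that assumption (the model name 'hiera_small_224' is an illustrative guess, not taken from this diff):

import torch
import timm

# features_only=True wraps the model so forward() returns a list of feature maps,
# one per feature_info entry, instead of classification logits.
model = timm.create_model('hiera_small_224', pretrained=False, features_only=True)
x = torch.randn(1, 3, 224, 224)
features = model(x)
for f in features:
    print(f.shape)  # NCHW tensors at increasing reduction (4, 8, 16, 32)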
@@ -49,7 +49,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', 'hiera_*'
 ]
 
 # transformer models don't support many of the spatial / feature based model functionalities

@@ -57,7 +57,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
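Adding 'hiera_*' to these glob-style filters enrolls every registered Hiera variant in the forward_intermediates / features_only test coverage. A small sketch of how such a pattern matches model names (fnmatch-style matching is an assumption about the test harness; the pattern semantics themselves are standard):

from fnmatch import fnmatch

# 'hiera_*' matches any model name beginning with 'hiera_'
print(fnmatch('hiera_small_224', 'hiera_*'))        # True
print(fnmatch('vit_base_patch16_224', 'hiera_*'))   # False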
@@ -37,6 +37,7 @@ from timm.layers import DropPath, Mlp, use_fused_attn
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 
 
 def conv_nd(n: int) -> Type[nn.Module]:
@@ -517,7 +518,7 @@ class Hiera(nn.Module):
         # Transformer blocks
         cur_stage = 0
         self.blocks = nn.ModuleList()
-
+        self.feature_info = []
         for i in range(depth):
             dim_out = embed_dim
             # Mask unit or global attention.
@@ -543,8 +544,10 @@ class Hiera(nn.Module):
                 window_size=flat_mu_size,
                 use_mask_unit_attn=use_mask_unit_attn,
             )
-
             embed_dim = dim_out
+            if i in self.stage_ends:
+                self.feature_info += [
+                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
         self.norm = norm_layer(embed_dim)
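The reduction recorded at each stage end follows directly from the stage index; a quick illustrative check, assuming the usual four Hiera stages with a stride-4 patch embed and 2x pooling per stage transition:

# reduction = 2 ** (cur_stage + 2)
for stage in range(4):
    print(f'stage {stage}: reduction {2 ** (stage + 2)}')  # 4, 8, 16, 32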
@@ -616,6 +619,57 @@ class Hiera(nn.Module):
         x = x + pos_embed
         return x
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert not norm, 'normalization of features not supported'
+        assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+
+        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
+        x, intermediates = self.forward_features(x, return_intermediates=True)
+        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_head:
+            # norm part of head for this model, equivalent to fc_norm in other vit.
+            self.norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
+
     def forward_features(
             self,
             x: torch.Tensor,
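A minimal sketch of calling the new methods directly (again assuming the illustrative 'hiera_small_224' model name; the argument names follow the signatures added above):

import torch
import timm

model = timm.create_model('hiera_small_224', pretrained=False)  # assumed example model name
x = torch.randn(1, 3, 224, 224)

# Final features plus the last two stage outputs as NCHW feature maps.
final, intermediates = model.forward_intermediates(x, indices=2)

# Or only the intermediates, skipping the final output.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)

# Drop blocks and head layers not needed for those intermediates.
kept = model.prune_intermediate_layers(n=2, prune_head=True)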