Merge pull request #2239 from huggingface/fix_out_indices_order

Fix issue where feature out_indices out of order after wrapping with FeatureGetterNet
pull/2243/head
Ross Wightman 2024-07-23 11:02:42 -07:00 committed by GitHub
commit f3c11dc3a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 87 additions and 87 deletions

View File

@ -11,13 +11,13 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import Format
from timm.layers import Format, _assert
__all__ = [
@ -26,44 +26,45 @@ __all__ = [
]
def _take_indices(
num_blocks: int,
n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def _take_indices_jit(
num_blocks: int,
n: Union[int, List[int], Tuple[int]],
) -> Tuple[List[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = [num_blocks - n + i for i in range(n)]
elif isinstance(n, tuple):
# splitting this up is silly, but needed for torchscript type resolution of n
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
else:
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
return take_indices, max(take_indices)
def feature_take_indices(
num_blocks: int,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
num_features: int,
indices: Optional[Union[int, List[int]]] = None,
as_set: bool = False,
) -> Tuple[List[int], int]:
""" Determine the absolute feature indices to 'take' from.
Note: This function can be called in forwar() so must be torchscript compatible,
which requires some incomplete typing and workaround hacks.
Args:
num_features: total number of features to select from
indices: indices to select,
None -> select all
int -> select last n
list/tuple of int -> return specified (-ve indices specify from end)
as_set: return as a set
Returns:
List (or set) of absolute (from beginning) indices, Maximum index
"""
if indices is None:
indices = num_blocks # all blocks if None
if torch.jit.is_scripting():
return _take_indices_jit(num_blocks, indices)
indices = num_features # all features if None
if isinstance(indices, int):
# convert int -> last n indices
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
take_indices = [num_features - indices + i for i in range(indices)]
else:
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
return _take_indices(num_blocks, indices)
take_indices: List[int] = []
for i in indices:
idx = num_features + i if i < 0 else i
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
take_indices.append(idx)
if not torch.jit.is_scripting() and as_set:
return set(take_indices), max(take_indices)
return take_indices, max(take_indices)
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
@ -464,7 +465,6 @@ class FeatureGetterNet(nn.ModuleDict):
out_indices,
prune_norm=not norm,
)
out_indices = list(out_indices)
self.feature_info = _get_feature_info(model, out_indices)
self.model = model
self.out_indices = out_indices

View File

@ -404,7 +404,7 @@ class Beit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -470,7 +470,7 @@ class Beit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1343,7 +1343,7 @@ class ByobNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1401,7 +1401,7 @@ class ByobNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -341,7 +341,7 @@ class Cait(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -398,7 +398,7 @@ class Cait(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -412,7 +412,7 @@ class ConvNeXt(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -460,7 +460,7 @@ class ConvNeXt(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -463,7 +463,7 @@ class EfficientFormer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -516,7 +516,7 @@ class EfficientFormer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -165,7 +165,7 @@ class EfficientNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -221,7 +221,7 @@ class EfficientNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,

View File

@ -589,7 +589,7 @@ class Eva(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -646,7 +646,7 @@ class Eva(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1251,7 +1251,7 @@ class FastVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1296,7 +1296,7 @@ class FastVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -669,7 +669,7 @@ class Hiera(nn.Module):
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
@ -722,7 +722,7 @@ class Hiera(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -638,7 +638,7 @@ class Levit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -687,7 +687,7 @@ class Levit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1256,7 +1256,7 @@ class MaxxVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1308,7 +1308,7 @@ class MaxxVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -265,7 +265,7 @@ class MlpMixer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -318,7 +318,7 @@ class MlpMixer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -170,7 +170,7 @@ class MobileNetV3(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -225,7 +225,7 @@ class MobileNetV3(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,

View File

@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -893,7 +893,7 @@ class MultiScaleVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -520,7 +520,7 @@ class RegNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -567,7 +567,7 @@ class RegNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -548,7 +548,7 @@ class ResNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -595,7 +595,7 @@ class ResNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -611,7 +611,7 @@ class SwinTransformer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -660,7 +660,7 @@ class SwinTransformer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -661,7 +661,7 @@ class SwinTransformerV2(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -763,7 +763,7 @@ class SwinTransformerV2Cr(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -407,7 +407,7 @@ class Twins(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -461,7 +461,7 @@ class Twins(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -673,7 +673,7 @@ class VisionTransformer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -737,7 +737,7 @@ class VisionTransformer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -391,7 +391,7 @@ class VisionTransformerRelPos(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -455,7 +455,7 @@ class VisionTransformerRelPos(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -542,7 +542,7 @@ class VisionTransformerSAM(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -597,7 +597,7 @@ class VisionTransformerSAM(nn.Module):
def prune_intermediate_layers(
self,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -711,7 +711,7 @@ class VOLO(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -772,7 +772,7 @@ class VOLO(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -442,7 +442,7 @@ class Xcit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -502,7 +502,7 @@ class Xcit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):