Fix issue where feature out_indices are returned out of order after wrapping with FeatureGetterNet, due to use of set()

This commit is contained in:
Ross Wightman 2024-07-22 13:33:30 -07:00
parent a1996ec0f4
commit d2240745d3
26 changed files with 85 additions and 86 deletions

View File

@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@ -26,44 +26,44 @@ __all__ = [
]
def _take_indices(
num_blocks: int,
n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def _take_indices_jit(
num_blocks: int,
n: Union[int, List[int], Tuple[int]],
) -> Tuple[List[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = [num_blocks - n + i for i in range(n)]
elif isinstance(n, tuple):
# splitting this up is silly, but needed for torchscript type resolution of n
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
else:
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
return take_indices, max(take_indices)
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
            None -> select all
            int -> select last n
            list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        assert indices >= 0
        # convert int -> last n indices
        take_indices = [num_features - indices + i for i in range(indices)]
    elif isinstance(indices, tuple):
        # duplicating this is silly, but needed for torchscript type resolution of indices
        take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
    else:
        take_indices = [num_features + idx if idx < 0 else idx for idx in indices]

    if not torch.jit.is_scripting() and as_set:
        # set output is opt-in only; the default list preserves the
        # caller-specified ordering of the selected indices
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
@ -464,7 +464,6 @@ class FeatureGetterNet(nn.ModuleDict):
out_indices,
prune_norm=not norm,
)
out_indices = list(out_indices)
self.feature_info = _get_feature_info(model, out_indices)
self.model = model
self.out_indices = out_indices

View File

@ -404,7 +404,7 @@ class Beit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -470,7 +470,7 @@ class Beit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1343,7 +1343,7 @@ class ByobNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1401,7 +1401,7 @@ class ByobNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -341,7 +341,7 @@ class Cait(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -398,7 +398,7 @@ class Cait(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -412,7 +412,7 @@ class ConvNeXt(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -460,7 +460,7 @@ class ConvNeXt(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -463,7 +463,7 @@ class EfficientFormer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -516,7 +516,7 @@ class EfficientFormer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -165,7 +165,7 @@ class EfficientNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -221,7 +221,7 @@ class EfficientNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,

View File

@ -589,7 +589,7 @@ class Eva(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -646,7 +646,7 @@ class Eva(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1251,7 +1251,7 @@ class FastVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1296,7 +1296,7 @@ class FastVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -669,7 +669,7 @@ class Hiera(nn.Module):
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
@ -722,7 +722,7 @@ class Hiera(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -638,7 +638,7 @@ class Levit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -687,7 +687,7 @@ class Levit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -1256,7 +1256,7 @@ class MaxxVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -1308,7 +1308,7 @@ class MaxxVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -265,7 +265,7 @@ class MlpMixer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -318,7 +318,7 @@ class MlpMixer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -170,7 +170,7 @@ class MobileNetV3(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -225,7 +225,7 @@ class MobileNetV3(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,

View File

@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -893,7 +893,7 @@ class MultiScaleVit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -520,7 +520,7 @@ class RegNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -567,7 +567,7 @@ class RegNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -548,7 +548,7 @@ class ResNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -595,7 +595,7 @@ class ResNet(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -611,7 +611,7 @@ class SwinTransformer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -660,7 +660,7 @@ class SwinTransformer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -661,7 +661,7 @@ class SwinTransformerV2(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -763,7 +763,7 @@ class SwinTransformerV2Cr(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -407,7 +407,7 @@ class Twins(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -461,7 +461,7 @@ class Twins(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -673,7 +673,7 @@ class VisionTransformer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -737,7 +737,7 @@ class VisionTransformer(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -391,7 +391,7 @@ class VisionTransformerRelPos(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
@ -455,7 +455,7 @@ class VisionTransformerRelPos(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -542,7 +542,7 @@ class VisionTransformerSAM(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -597,7 +597,7 @@ class VisionTransformerSAM(nn.Module):
def prune_intermediate_layers(
self,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -711,7 +711,7 @@ class VOLO(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -772,7 +772,7 @@ class VOLO(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):

View File

@ -442,7 +442,7 @@ class Xcit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -502,7 +502,7 @@ class Xcit(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):