From d2240745d3614b085f3cba43293575993d7c0848 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Jul 2024 13:33:30 -0700 Subject: [PATCH 1/2] Fix issue where feature out_indices out of order after wrapping with FeatureGetterNet due to use of set() --- timm/models/_features.py | 71 ++++++++++++------------ timm/models/beit.py | 4 +- timm/models/byobnet.py | 4 +- timm/models/cait.py | 4 +- timm/models/convnext.py | 4 +- timm/models/efficientformer.py | 4 +- timm/models/efficientnet.py | 4 +- timm/models/eva.py | 4 +- timm/models/fastvit.py | 4 +- timm/models/hiera.py | 4 +- timm/models/levit.py | 4 +- timm/models/maxxvit.py | 4 +- timm/models/mlp_mixer.py | 4 +- timm/models/mobilenetv3.py | 4 +- timm/models/mvitv2.py | 4 +- timm/models/regnet.py | 4 +- timm/models/resnet.py | 4 +- timm/models/swin_transformer.py | 4 +- timm/models/swin_transformer_v2.py | 4 +- timm/models/swin_transformer_v2_cr.py | 4 +- timm/models/twins.py | 4 +- timm/models/vision_transformer.py | 4 +- timm/models/vision_transformer_relpos.py | 4 +- timm/models/vision_transformer_sam.py | 4 +- timm/models/volo.py | 4 +- timm/models/xcit.py | 4 +- 26 files changed, 85 insertions(+), 86 deletions(-) diff --git a/timm/models/_features.py b/timm/models/_features.py index 12f0ab37..cdd7216a 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman from collections import OrderedDict, defaultdict from copy import deepcopy from functools import partial -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -26,44 +26,44 @@ __all__ = [ ] -def _take_indices( - num_blocks: int, - n: Optional[Union[int, List[int], Tuple[int]]], -) -> Tuple[Set[int], int]: - if isinstance(n, int): - assert n >= 0 - take_indices = {x for x in range(num_blocks - n, num_blocks)} - else: - take_indices = {num_blocks + idx if idx 
< 0 else idx for idx in n} - return take_indices, max(take_indices) - - -def _take_indices_jit( - num_blocks: int, - n: Union[int, List[int], Tuple[int]], -) -> Tuple[List[int], int]: - if isinstance(n, int): - assert n >= 0 - take_indices = [num_blocks - n + i for i in range(n)] - elif isinstance(n, tuple): - # splitting this up is silly, but needed for torchscript type resolution of n - take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] - else: - take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] - return take_indices, max(take_indices) - - def feature_take_indices( - num_blocks: int, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + num_features: int, + indices: Optional[Union[int, List[int]]] = None, + as_set: bool = False, ) -> Tuple[List[int], int]: + """ Determine the absolute feature indices to 'take' from. + + Note: This function can be called in forward() so must be torchscript compatible, + which requires some incomplete typing and workaround hacks. 
+ + Args: + num_features: total number of features to select from + indices: indices to select, + None -> select all + int -> select last n + list/tuple of int -> return specified (-ve indices specify from end) + as_set: return as a set + + Returns: + List (or set) of absolute (from beginning) indices, Maximum index + """ if indices is None: - indices = num_blocks # all blocks if None - if torch.jit.is_scripting(): - return _take_indices_jit(num_blocks, indices) + indices = num_features # all features if None + + if isinstance(indices, int): + assert indices >= 0 + # convert int -> last n indices + take_indices = [num_features - indices + i for i in range(indices)] + elif isinstance(indices, tuple): + # duplicating this is silly, but needed for torchscript type resolution of n + take_indices = [num_features + idx if idx < 0 else idx for idx in indices] else: - # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno - return _take_indices(num_blocks, indices) + take_indices = [num_features + idx if idx < 0 else idx for idx in indices] + + if not torch.jit.is_scripting() and as_set: + return set(take_indices), max(take_indices) + + return take_indices, max(take_indices) def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: @@ -464,7 +464,6 @@ class FeatureGetterNet(nn.ModuleDict): out_indices, prune_norm=not norm, ) - out_indices = list(out_indices) self.feature_info = _get_feature_info(model, out_indices) self.model = model self.out_indices = out_indices diff --git a/timm/models/beit.py b/timm/models/beit.py index 57007cd7..5e9d118a 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -404,7 +404,7 @@ class Beit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, @@ -470,7 +470,7 @@ class 
Beit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 32ecfa44..03fcb082 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1343,7 +1343,7 @@ class ByobNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -1401,7 +1401,7 @@ class ByobNet(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/cait.py b/timm/models/cait.py index 78a7adc9..28e14ec7 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -341,7 +341,7 @@ class Cait(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -398,7 +398,7 @@ class Cait(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index a09653cf..b3f3350e 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -412,7 +412,7 @@ class ConvNeXt(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -460,7 +460,7 @@ class ConvNeXt(nn.Module): def 
prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 513f7e44..a759dc41 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -463,7 +463,7 @@ class EfficientFormer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -516,7 +516,7 @@ class EfficientFormer(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 6ed9407d..6ca1e447 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -165,7 +165,7 @@ class EfficientNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -221,7 +221,7 @@ class EfficientNet(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, extra_blocks: bool = False, diff --git a/timm/models/eva.py b/timm/models/eva.py index f31fd08f..c8a430c4 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -589,7 +589,7 @@ class Eva(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, 
stop_early: bool = False, @@ -646,7 +646,7 @@ class Eva(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index bac23962..2221ed82 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -1251,7 +1251,7 @@ class FastVit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -1296,7 +1296,7 @@ class FastVit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/hiera.py b/timm/models/hiera.py index e06d3545..2d57ad2c 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -669,7 +669,7 @@ class Hiera(nn.Module): self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = True, output_fmt: str = 'NCHW', @@ -722,7 +722,7 @@ class Hiera(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/levit.py b/timm/models/levit.py index 4e43006a..16186cae 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -638,7 +638,7 @@ class Levit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -687,7 
+687,7 @@ class Levit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 9c418510..e4375b34 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1256,7 +1256,7 @@ class MaxxVit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -1308,7 +1308,7 @@ class MaxxVit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 087a924c..25cde6a6 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -265,7 +265,7 @@ class MlpMixer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -318,7 +318,7 @@ class MlpMixer(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 8aa1acb0..4c55ca88 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -170,7 +170,7 @@ class MobileNetV3(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ 
-225,7 +225,7 @@ class MobileNetV3(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, extra_blocks: bool = False, diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 7735b631..167ebb9e 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -893,7 +893,7 @@ class MultiScaleVit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 374ecaa0..1e741bec 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -520,7 +520,7 @@ class RegNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -567,7 +567,7 @@ class RegNet(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 6893bc5b..3e2326a5 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -548,7 +548,7 @@ class ResNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ 
-595,7 +595,7 @@ class ResNet(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index a5800937..25f87b8f 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -611,7 +611,7 @@ class SwinTransformer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -660,7 +660,7 @@ class SwinTransformer(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 7bf91032..cc8f5a75 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -661,7 +661,7 @@ class SwinTransformerV2(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d5fcbadc..5d455ae2 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: 
Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -763,7 +763,7 @@ class SwinTransformerV2Cr(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/twins.py b/timm/models/twins.py index 1aea273d..62029b94 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -407,7 +407,7 @@ class Twins(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -461,7 +461,7 @@ class Twins(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index fb95aebc..08b8cdb4 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -673,7 +673,7 @@ class VisionTransformer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, @@ -737,7 +737,7 @@ class VisionTransformer(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index ed66068e..23419597 100644 --- a/timm/models/vision_transformer_relpos.py +++ 
b/timm/models/vision_transformer_relpos.py @@ -391,7 +391,7 @@ class VisionTransformerRelPos(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, @@ -455,7 +455,7 @@ class VisionTransformerRelPos(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index aeabc770..b3b74aa3 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -542,7 +542,7 @@ class VisionTransformerSAM(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -597,7 +597,7 @@ class VisionTransformerSAM(nn.Module): def prune_intermediate_layers( self, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/models/volo.py b/timm/models/volo.py index e1e4e0db..0d273180 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -711,7 +711,7 @@ class VOLO(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -772,7 +772,7 @@ class VOLO(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, 
): diff --git a/timm/models/xcit.py b/timm/models/xcit.py index a1e0f0c3..1e902ac2 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -442,7 +442,7 @@ class Xcit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Optional[Union[int, List[int], Tuple[int]]] = None, + indices: Optional[Union[int, List[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -502,7 +502,7 @@ class Xcit(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): From 8efdc38213ccaa0593c05b6631123edd9f32073c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Jul 2024 08:19:09 -0700 Subject: [PATCH 2/2] Fix #2242 add checks for out indices with intermediate getter mode --- timm/models/_features.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/timm/models/_features.py b/timm/models/_features.py index cdd7216a..14d174f5 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint -from timm.layers import Format +from timm.layers import Format, _assert __all__ = [ @@ -51,14 +51,15 @@ def feature_take_indices( indices = num_features # all features if None if isinstance(indices, int): - assert indices >= 0 # convert int -> last n indices + _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') take_indices = [num_features - indices + i for i in range(indices)] - elif isinstance(indices, tuple): - # duplicating this is silly, but needed for torchscript type resolution of n - take_indices = [num_features + idx if idx < 0 else idx for idx in indices] else: - take_indices = [num_features + idx if idx < 0 else idx for idx in indices] + take_indices: List[int] = [] + for i in indices: + idx = num_features + i if i < 
0 else i + _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') + take_indices.append(idx) if not torch.jit.is_scripting() and as_set: return set(take_indices), max(take_indices)