Fix #2242 add checks for out indices with intermediate getter mode

fix_out_indices_order
Ross Wightman 2024-07-23 08:19:09 -07:00
parent d2240745d3
commit 8efdc38213
1 changed files with 7 additions and 6 deletions

View File

@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import Format
from timm.layers import Format, _assert
__all__ = [
@ -51,14 +51,15 @@ def feature_take_indices(
indices = num_features # all features if None
if isinstance(indices, int):
assert indices >= 0
# convert int -> last n indices
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
take_indices = [num_features - indices + i for i in range(indices)]
elif isinstance(indices, tuple):
# duplicating this is silly, but needed for torchscript type resolution of n
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
else:
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
take_indices: List[int] = []
for i in indices:
idx = num_features + i if i < 0 else i
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
take_indices.append(idx)
if not torch.jit.is_scripting() and as_set:
return set(take_indices), max(take_indices)