Fix #2242 add checks for out indices with intermediate getter mode
parent
d2240745d3
commit
8efdc38213
|
@ -17,7 +17,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.layers import Format
|
||||
from timm.layers import Format, _assert
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -51,14 +51,15 @@ def feature_take_indices(
|
|||
indices = num_features # all features if None
|
||||
|
||||
if isinstance(indices, int):
|
||||
assert indices >= 0
|
||||
# convert int -> last n indices
|
||||
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
|
||||
take_indices = [num_features - indices + i for i in range(indices)]
|
||||
elif isinstance(indices, tuple):
|
||||
# duplicating this is silly, but needed for torchscript type resolution of n
|
||||
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
|
||||
else:
|
||||
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
|
||||
take_indices: List[int] = []
|
||||
for i in indices:
|
||||
idx = num_features + i if i < 0 else i
|
||||
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
|
||||
take_indices.append(idx)
|
||||
|
||||
if not torch.jit.is_scripting() and as_set:
|
||||
return set(take_indices), max(take_indices)
|
||||
|
|
Loading…
Reference in New Issue