mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix #2242 add checks for out indices with intermediate getter mode
This commit is contained in:
parent
d2240745d3
commit
8efdc38213
@ -17,7 +17,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from timm.layers import Format
|
from timm.layers import Format, _assert
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -51,14 +51,15 @@ def feature_take_indices(
|
|||||||
indices = num_features # all features if None
|
indices = num_features # all features if None
|
||||||
|
|
||||||
if isinstance(indices, int):
|
if isinstance(indices, int):
|
||||||
assert indices >= 0
|
|
||||||
# convert int -> last n indices
|
# convert int -> last n indices
|
||||||
|
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
|
||||||
take_indices = [num_features - indices + i for i in range(indices)]
|
take_indices = [num_features - indices + i for i in range(indices)]
|
||||||
elif isinstance(indices, tuple):
|
|
||||||
# duplicating this is silly, but needed for torchscript type resolution of n
|
|
||||||
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
|
|
||||||
else:
|
else:
|
||||||
take_indices = [num_features + idx if idx < 0 else idx for idx in indices]
|
take_indices: List[int] = []
|
||||||
|
for i in indices:
|
||||||
|
idx = num_features + i if i < 0 else i
|
||||||
|
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
|
||||||
|
take_indices.append(idx)
|
||||||
|
|
||||||
if not torch.jit.is_scripting() and as_set:
|
if not torch.jit.is_scripting() and as_set:
|
||||||
return set(take_indices), max(take_indices)
|
return set(take_indices), max(take_indices)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user