Merge pull request #2239 from huggingface/fix_out_indices_order
Fix issue where feature out_indices out of order after wrapping with FeatureGetterNetpull/2243/head
commit
f3c11dc3a5
|
@ -11,13 +11,13 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.layers import Format
|
||||
from timm.layers import Format, _assert
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -26,44 +26,45 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
def _take_indices(
|
||||
num_blocks: int,
|
||||
n: Optional[Union[int, List[int], Tuple[int]]],
|
||||
) -> Tuple[Set[int], int]:
|
||||
if isinstance(n, int):
|
||||
assert n >= 0
|
||||
take_indices = {x for x in range(num_blocks - n, num_blocks)}
|
||||
else:
|
||||
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
|
||||
return take_indices, max(take_indices)
|
||||
|
||||
|
||||
def _take_indices_jit(
|
||||
num_blocks: int,
|
||||
n: Union[int, List[int], Tuple[int]],
|
||||
) -> Tuple[List[int], int]:
|
||||
if isinstance(n, int):
|
||||
assert n >= 0
|
||||
take_indices = [num_blocks - n + i for i in range(n)]
|
||||
elif isinstance(n, tuple):
|
||||
# splitting this up is silly, but needed for torchscript type resolution of n
|
||||
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
|
||||
else:
|
||||
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
|
||||
return take_indices, max(take_indices)
|
||||
|
||||
|
||||
def feature_take_indices(
|
||||
num_blocks: int,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
num_features: int,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
as_set: bool = False,
|
||||
) -> Tuple[List[int], int]:
|
||||
""" Determine the absolute feature indices to 'take' from.
|
||||
|
||||
Note: This function can be called in forwar() so must be torchscript compatible,
|
||||
which requires some incomplete typing and workaround hacks.
|
||||
|
||||
Args:
|
||||
num_features: total number of features to select from
|
||||
indices: indices to select,
|
||||
None -> select all
|
||||
int -> select last n
|
||||
list/tuple of int -> return specified (-ve indices specify from end)
|
||||
as_set: return as a set
|
||||
|
||||
Returns:
|
||||
List (or set) of absolute (from beginning) indices, Maximum index
|
||||
"""
|
||||
if indices is None:
|
||||
indices = num_blocks # all blocks if None
|
||||
if torch.jit.is_scripting():
|
||||
return _take_indices_jit(num_blocks, indices)
|
||||
indices = num_features # all features if None
|
||||
|
||||
if isinstance(indices, int):
|
||||
# convert int -> last n indices
|
||||
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
|
||||
take_indices = [num_features - indices + i for i in range(indices)]
|
||||
else:
|
||||
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
|
||||
return _take_indices(num_blocks, indices)
|
||||
take_indices: List[int] = []
|
||||
for i in indices:
|
||||
idx = num_features + i if i < 0 else i
|
||||
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
|
||||
take_indices.append(idx)
|
||||
|
||||
if not torch.jit.is_scripting() and as_set:
|
||||
return set(take_indices), max(take_indices)
|
||||
|
||||
return take_indices, max(take_indices)
|
||||
|
||||
|
||||
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
|
||||
|
@ -464,7 +465,6 @@ class FeatureGetterNet(nn.ModuleDict):
|
|||
out_indices,
|
||||
prune_norm=not norm,
|
||||
)
|
||||
out_indices = list(out_indices)
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.model = model
|
||||
self.out_indices = out_indices
|
||||
|
|
|
@ -404,7 +404,7 @@ class Beit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
|
@ -470,7 +470,7 @@ class Beit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -1343,7 +1343,7 @@ class ByobNet(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -1401,7 +1401,7 @@ class ByobNet(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -341,7 +341,7 @@ class Cait(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -398,7 +398,7 @@ class Cait(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -412,7 +412,7 @@ class ConvNeXt(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -460,7 +460,7 @@ class ConvNeXt(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -463,7 +463,7 @@ class EfficientFormer(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -516,7 +516,7 @@ class EfficientFormer(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -165,7 +165,7 @@ class EfficientNet(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -221,7 +221,7 @@ class EfficientNet(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
extra_blocks: bool = False,
|
||||
|
|
|
@ -589,7 +589,7 @@ class Eva(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
|
@ -646,7 +646,7 @@ class Eva(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -1251,7 +1251,7 @@ class FastVit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -1296,7 +1296,7 @@ class FastVit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -669,7 +669,7 @@ class Hiera(nn.Module):
|
|||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -722,7 +722,7 @@ class Hiera(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -638,7 +638,7 @@ class Levit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -687,7 +687,7 @@ class Levit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -1256,7 +1256,7 @@ class MaxxVit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -1308,7 +1308,7 @@ class MaxxVit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -265,7 +265,7 @@ class MlpMixer(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -318,7 +318,7 @@ class MlpMixer(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -170,7 +170,7 @@ class MobileNetV3(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -225,7 +225,7 @@ class MobileNetV3(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
extra_blocks: bool = False,
|
||||
|
|
|
@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -893,7 +893,7 @@ class MultiScaleVit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -520,7 +520,7 @@ class RegNet(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -567,7 +567,7 @@ class RegNet(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -548,7 +548,7 @@ class ResNet(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -595,7 +595,7 @@ class ResNet(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -611,7 +611,7 @@ class SwinTransformer(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -660,7 +660,7 @@ class SwinTransformer(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -661,7 +661,7 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -763,7 +763,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -407,7 +407,7 @@ class Twins(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -461,7 +461,7 @@ class Twins(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -673,7 +673,7 @@ class VisionTransformer(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
|
@ -737,7 +737,7 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -391,7 +391,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
|
@ -455,7 +455,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -542,7 +542,7 @@ class VisionTransformerSAM(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -597,7 +597,7 @@ class VisionTransformerSAM(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -711,7 +711,7 @@ class VOLO(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -772,7 +772,7 @@ class VOLO(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
|
@ -442,7 +442,7 @@ class Xcit(nn.Module):
|
|||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
|
@ -502,7 +502,7 @@ class Xcit(nn.Module):
|
|||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int], Tuple[int]] = 1,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue