Merge branch 'main' into tnt

Ryan 2025-05-11 22:45:48 +08:00
commit 69b1fbcdc1
31 changed files with 1573 additions and 59 deletions

View File

@@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
[Getting Started with PyTorch Image Models (timm): A Practitioner's Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[Getting Started with PyTorch Image Models (timm): A Practitioner's Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

View File

@@ -54,6 +54,9 @@ FEAT_INTER_FILTERS = [
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit'
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
spatial_axis = get_spatial_dim(output_fmt)
import math
inpt = torch.randn((batch_size, *input_size))
output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
inpt,
output_fmt=output_fmt,
)
assert len(expected_channels) == len(intermediates)
@@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
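For context on what the strengthened test asserts: forward_intermediates() must hand back the same final feature map that forward_features() produces. A minimal sketch of that contract (the model name and input size here are illustrative assumptions, not part of this diff):

import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False)  # any forward_intermediates-capable model
inpt = torch.randn(2, 3, 224, 224)

# final feature map plus the per-stage intermediates
output, intermediates = model.forward_intermediates(inpt, output_fmt='NCHW')

# the final map must agree with a plain feature forward, as the new assert above checks
output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)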

View File

@@ -144,6 +144,7 @@ def create_dataset(
use_train = split in _TRAIN_SYNONYM
ds = QMNIST(train=use_train, **torch_kwargs)
elif name == 'imagenet':
torch_kwargs.pop('download')
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
if split in _EVAL_SYNONYM:
split = 'val'
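The popped key matters because torchvision's ImageNet dataset takes no usable download argument (the archives must be obtained externally), so forwarding it would error. A rough sketch of what this avoids (the root path is a placeholder):

from torchvision.datasets import ImageNet

torch_kwargs = dict(root='path/to/imagenet', download=False)  # 'download' arrives via the common kwargs
torch_kwargs.pop('download')  # drop it; ImageNet(...) would reject the unsupported argument
ds = ImageNet(split='val', **torch_kwargs)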

View File

@@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if norm and feat_idx == last_idx:
intermediates.append(self.norm_pre(x))
else:
intermediates.append(x)
if intermediates_only:
return intermediates
x = self.norm_pre(x)
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
@@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
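The index-space change above is the crux of this hunk: the stem is no longer counted, so indices refer to the stages directly and truncation needs max_index + 1. A simplified sketch of the feature_take_indices() contract this relies on (the real helper in timm/models/_features.py also validates bounds and supports torchscript):

from typing import List, Optional, Tuple, Union

def feature_take_indices_sketch(num_stages: int, indices: Optional[Union[int, List[int]]]) -> Tuple[List[int], int]:
    # None -> all stages; int n -> the last n stages; list -> as given, negatives wrapping
    if indices is None:
        take = list(range(num_stages))
    elif isinstance(indices, int):
        take = list(range(num_stages - indices, num_stages))
    else:
        take = [i + num_stages if i < 0 else i for i in indices]
    return take, max(take)

# with len(self.stages) == 4: indices=2 -> stages 2 and 3, so stages[:max_index + 1] keeps all four
assert feature_take_indices_sketch(4, 2) == ([2, 3], 3)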

View File

@@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ class DaVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
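As a usage note, the two methods added here pair naturally: prune first, then request only the surviving intermediates. A hedged example (the model name is an assumption):

import timm
import torch

model = timm.create_model('davit_tiny', pretrained=False)
# keep only the layers needed for the last two stage outputs, drop the classifier
take_indices = model.prune_intermediate_layers(indices=2, prune_head=True)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=take_indices,
    intermediates_only=True,
)
print([f.shape for f in feats])  # two NCHW feature maps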

View File

@@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)

View File

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
"""
import math
from functools import partial
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -625,6 +626,73 @@ class EfficientFormerV2(nn.Module):
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)

View File

@@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
"""
__all__ = ['EfficientVit', 'EfficientVitLarge']
from typing import List, Optional
from typing import List, Optional, Tuple, Union
from functools import partial
import torch
@@ -17,6 +17,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -754,6 +755,63 @@ class EfficientVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
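One behavioral detail shared by all of these implementations: stop_early only takes effect outside torchscript, where the stage sequence can be sliced. A sketch (model name assumed):

import timm
import torch

model = timm.create_model('efficientvit_b0', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# eager mode: only stages 0..1 execute, since stage 1 is the last requested intermediate
feats = model.forward_intermediates(x, indices=[0, 1], stop_early=True, intermediates_only=True)

# under torch.jit.script(...) the slice is skipped and every stage runs; the returned features are identical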

View File

@@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -17,6 +17,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module):
self.head = NormLinear(
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@@ -18,7 +18,7 @@ This impl is/has:
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
from functools import partial
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,6 +26,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model
@@ -458,6 +459,72 @@ class FocalNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.layers), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.layers
else:
stages = self.layers[:max_index + 1]
last_idx = len(self.layers) - 1
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.layers), indices)
self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.layers(x)

View File

@@ -30,6 +30,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint
from ._registry import register_model, generate_default_cfgs
@@ -397,7 +398,7 @@ class GlobalContextVit(nn.Module):
act_layer = get_act_layer(act_layer)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.feature_info = []
img_size = to_2tuple(img_size)
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
self.global_pool = global_pool
@@ -441,6 +442,7 @@ norm_layer=norm_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
))
self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
# Classifier head
@@ -494,6 +496,62 @@
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
x = self.stages(x)
@@ -509,9 +567,11 @@ class GlobalContextVit(nn.Module):
def _create_gcvit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
model = build_model_with_cfg(
GlobalContextVit, variant, pretrained,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs
)
return model
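With out_indices and flatten_sequential supplied through feature_cfg, GCViT plugs into the standard features_only wrapper. A short usage sketch (assuming the features_only guard shown in this hunk is dropped as part of the change):

import timm
import torch

model = timm.create_model('gcvit_tiny', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # four NCHW maps at reductions 4, 8, 16, 32, per the new feature_info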

View File

@@ -6,7 +6,7 @@ The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blo
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -15,6 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
from ._manipulate import checkpoint_seq
@@ -508,6 +509,62 @@ class HighPerfGpuNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, 'avg')
return take_indices
def forward_features(self, x):
x = self.stem(x)
return self.stages(x)

View File

@@ -4,7 +4,7 @@ Original implementation & weights from: https://github.com/sail-sg/inceptionnext
"""
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -12,6 +12,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -349,6 +350,62 @@ class MetaNeXt(nn.Module):
def no_weight_decay(self):
return set()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, 'avg')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)

View File

@@ -3,6 +3,7 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License)
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
"""
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -10,6 +11,7 @@ import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_classifier, ConvNormAct
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
__all__ = ['InceptionV4']
@@ -285,6 +287,66 @@ class InceptionV4(nn.Module):
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.features
else:
stages = self.features[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
return self.features(x)
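InceptionV4 keeps every block in one flat self.features sequential, so take indices are remapped through feature_info module names rather than used directly. A worked sketch of that mapping (the stage-end values below are illustrative, not read from this diff):

# feature_info modules like 'features.2', 'features.9', ... yield flat block indices
stage_ends = [2, 3, 9, 17, 21]                   # illustrative flat indices of the 5 feature taps
take_indices = [stage_ends[i] for i in (3, 4)]   # request the last two features -> blocks 17 and 21
max_index = stage_ends[4]                        # 21 -> self.features[:22] suffices with stop_early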

View File

@@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from collections import OrderedDict
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -14,6 +14,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -417,6 +418,67 @@ class MambaOut(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
channel_first = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if channel_first:
# reshape to BCHW output format
intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
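MambaOut stages run channels-last internally, which is why this variant also accepts output_fmt='NHWC' and converts for NCHW callers; the conversion is a plain permute:

import torch

y = torch.randn(1, 7, 7, 576)                # stage output in NHWC, MambaOut's internal layout
y_nchw = y.permute(0, 3, 1, 2).contiguous()  # -> (1, 576, 7, 7) for NCHW consumers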

View File

@@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
if intermediates_only:
return intermediates
x = self.norm(x)
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates

View File

@@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
from collections import OrderedDict
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
final = nn.Identity()
self.head.fc = final
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_head(self, x: Tensor, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)

View File

@@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
for i, stage in enumerate(self.stages):
last_idx = len(self.stages) - 1
for feat_idx, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
if i in take_indices:
if norm and i == (len(self.stages) - 1):
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm to last intermediate
else:
x_inter = x
@@ -887,7 +888,8 @@
if intermediates_only:
return intermediates
x = self.norm(x)
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
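The guard introduced here fixes an interaction with stop_early: previously self.norm ran unconditionally, so a forward that stopped at an early stage would normalize a mid-level feature map. With the feat_idx == last_idx check, an early-stopped call returns the raw output of the stage it stopped at. A sketch (model name assumed):

import timm
import torch

model = timm.create_model('mvitv2_tiny', pretrained=False)
# stops after stage 1; feat_idx (1) != last_idx (3), so the returned final map skips self.norm
out, feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), indices=[1], stop_early=True)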

View File

@@ -19,6 +19,7 @@ import collections.abc
import logging
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -420,6 +422,73 @@ class Nest(nn.Module):
self.global_pool, self.head = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.levels), indices)
# forward pass
x = self.patch_embed(x)
last_idx = len(self.levels) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.levels
else:
stages = self.levels[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
# Layer norm done over channel dim only (to NHWC and back)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.levels), indices)
self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.levels(x)
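Nest's self.norm is a plain LayerNorm over the channel dim, hence the permute round-trip on NCHW maps. A self-contained equivalent of that pattern:

import torch
import torch.nn as nn

norm = nn.LayerNorm(512)         # normalizes the trailing dimension
x = torch.randn(1, 512, 14, 14)  # NCHW feature map
x = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # to NHWC, norm over C, back to NCHW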

View File

@@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -560,6 +561,72 @@ class NextViT(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
from typing import Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block
@@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
if self.head_dist is not None:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
# forward pass
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
last_idx = len(self.transformers) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.transformers
else:
stages = self.transformers[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x, cls_tokens = stage((x, cls_tokens))
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
cls_tokens = self.norm(cls_tokens)
return cls_tokens, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
@@ -314,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
**kwargs,
)
return model
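Dropping no_rewrite=True presumably lets the hook-based feature wrapper truncate and rewrite modules it does not need, bringing PiT in line with the other backbones. Usage stays the familiar (a sketch; the model name is an assumption):

import timm
import torch

model = timm.create_model('pit_s_224', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])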

View File

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs
@@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stages(x)

View File

@@ -302,29 +302,33 @@ class RDNet(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
last_idx = len(self.dense_stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
dense_stages = self.dense_stages
else:
dense_stages = self.dense_stages[:max_index]
for stage in dense_stages:
feat_idx += 1
dense_stages = self.dense_stages[:max_index + 1]
for feat_idx, stage in enumerate(dense_stages):
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
x = self.norm_pre(x)
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
@@ -336,8 +340,10 @@
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
@@ -355,6 +361,7 @@ class RDNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.dense_stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):

View File

@@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost
"""
import copy
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
from ._builder import build_model_with_cfg
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -294,6 +295,72 @@ class RepGhostNet(nn.Module):
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i]+1 for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
feat_idx = 0
x = self.conv_stem(x)
if feat_idx in take_indices:
intermediates.append(x)
x = self.bn1(x)
x = self.act1(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.blocks
else:
stages = self.blocks[:max_index + 1]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
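RepGhostNet counts its stem as feature index 0 while self.blocks starts at flat index 0, hence the [-1] + ... table and the +1 shift above. A worked sketch with illustrative stage ends:

stage_ends = [-1, 0, 2, 4, 6, 9]                  # -1 stands in for the stem (values illustrative)
take = [1, 3, 5]                                  # requested feature_info indices
take_indices = [stage_ends[i] + 1 for i in take]  # -> [1, 5, 10], matching enumerate(..., start=1)
max_index = stage_ends[5]                         # 9 -> self.blocks[:10] is enough with stop_early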

View File

@@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
Adapted from official impl at https://github.com/jameslahm/RepViT
"""
__all__ = ['RepVit']
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -24,9 +22,12 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
__all__ = ['RepVit']
class ConvNorm(nn.Sequential):
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
@ -333,6 +334,62 @@ class RepVit(nn.Module):
def set_distilled_training(self, enable=True):
self.head.distilled_training = enable
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1]  # truncate stages (stem kept separately)
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
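
A hedged sketch of pruning to a pure feature extractor (the model name
'repvit_m0_9' is an assumption for illustration):

import timm
import torch

model = timm.create_model('repvit_m0_9', pretrained=False).eval()
# Drop the head and any stages deeper than the last two intermediates.
kept = model.prune_intermediate_layers(indices=2, prune_head=True)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=2, intermediates_only=True)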

View File

@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig
from collections import OrderedDict # pylint: disable=g-importing-member
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -543,6 +544,79 @@ class ResNetV2(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
H, W = x.shape[-2:]
for stem in self.stem:
x = stem(x)
if x.shape[-2:] == (H // 2, W // 2):
x_down = x
if feat_idx in take_indices:
intermediates.append(x_down)
last_idx = len(self.stages)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
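
A consistency sketch for the norm handling above (model name 'resnetv2_50'
assumed for illustration): with norm=True and all indices requested, the
deepest intermediate should equal the final features that are returned.

import timm
import torch

model = timm.create_model('resnetv2_50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, inters = model.forward_intermediates(x, norm=True)
assert torch.allclose(final, inters[-1])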

View File

@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman
from functools import partial
from math import ceil
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
from ._builder import build_model_with_cfg
from ._efficientnet_builder import efficientnet_init_weights
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@ -234,6 +235,67 @@ class RexNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
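# Worked example (hypothetical feature_info): with taps at 'features.2',
# 'features.5', 'features.11', 'features.16', stage_ends=[2, 5, 11, 16];
# indices=2 selects the last two, so take_indices=[11, 16], max_index=16,
# and only self.features[:17] need to run when stop_early is set.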
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.features
else:
stages = self.features[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.features = self.features[:max_index + 1]  # truncate features (stem kept separately)
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@ -10,7 +10,7 @@ __all__ = ['TinyVit']
import itertools
from functools import partial
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath, \
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@ -536,6 +537,62 @@ class TinyVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1]  # truncate stages (patch embed kept separately)
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
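
A short usage sketch of stop_early (model name 'tiny_vit_5m_224' assumed
for illustration); stages past the deepest requested index are skipped
entirely outside of torchscript:

import timm
import torch

model = timm.create_model('tiny_vit_5m_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(
    x, indices=[0, 1], stop_early=True, intermediates_only=True)
print([f.shape for f in feats])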

View File

@ -7,13 +7,14 @@ Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
from functools import partial
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@ -228,6 +229,65 @@ class TResNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [1, 2, 3, 4, 5]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
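# Mapping note: self.body is ordered (s2d, conv1, layer1..layer4), so the
# feature taps sit at body indices 1..5, hence the fixed stage_ends above.
# E.g. indices=1 gives take_indices=[5] and max_index=5, i.e. the full body.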
# forward pass
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.body
else:
stages = self.body[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [1, 2, 3, 4, 5]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = self.body.s2d(x)

View File

@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import List, Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
create_attn, create_norm_act_layer
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@ -264,6 +265,67 @@ class VovNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors if intermediates_only is True,
    otherwise a tuple of (final features, list of intermediates).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
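# The reduction-2 activation (everything but the last stem conv) is tapped
# as feat_idx 0; the final stem conv then produces the full-stride stem output.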
x = self.stem[:-1](x)
if feat_idx in take_indices:
intermediates.append(x)
x = self.stem[-1](x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
return self.stages(x)

View File

@ -494,7 +494,8 @@ class Xcit(nn.Module):
# NOTE not supporting return of class tokens
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
x = self.norm(x)
return x, intermediates
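
A final consistency sketch (model name 'xcit_nano_12_p16_224' assumed for
illustration): the final tensor returned alongside the intermediates should
match forward_features(), since self.norm is applied on both paths.

import timm
import torch

model = timm.create_model('xcit_nano_12_p16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, _ = model.forward_intermediates(x)
assert torch.allclose(final, model.forward_features(x))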