Merge branch 'main' into tnt

Ryan 2025-05-11 22:45:48 +08:00
commit 69b1fbcdc1
31 changed files with 1573 additions and 59 deletions

View File

@@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
-[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
+[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

View File

@@ -54,6 +54,9 @@ FEAT_INTER_FILTERS = [
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit'
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
spatial_axis = get_spatial_dim(output_fmt)
import math
inpt = torch.randn((batch_size, *input_size))
output, intermediates = model.forward_intermediates(
-torch.randn((batch_size, *input_size)),
+inpt,
output_fmt=output_fmt,
)
assert len(expected_channels) == len(intermediates)
@@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
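The added inpt / forward_features lines above tighten the test: the final (non-intermediate) output of forward_intermediates() must now match forward_features() on the same input. A minimal sketch of the same check outside the test harness, assuming timm is installed and using 'convnext_tiny' purely as an example model name:

import timm
import torch

model = timm.create_model('convnext_tiny').eval()  # example model; any FEAT_INTER_FILTERS model should work
inpt = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output, intermediates = model.forward_intermediates(inpt, output_fmt='NCHW')
    output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)  # same final features either way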

View File

@@ -144,6 +144,7 @@ def create_dataset(
use_train = split in _TRAIN_SYNONYM
ds = QMNIST(train=use_train, **torch_kwargs)
elif name == 'imagenet':
torch_kwargs.pop('download')
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
if split in _EVAL_SYNONYM:
split = 'val'
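The popped 'download' key matters because torchvision's ImageNet class, unlike MNIST/QMNIST, does not take a download argument. A hedged sketch of the resulting call path, with a placeholder root path:

from timm.data import create_dataset

# 'torch/imagenet' routes to torchvision.datasets.ImageNet; the factory now drops the
# download flag before forwarding kwargs, so the unsupported flag never reaches the constructor.
ds = create_dataset('torch/imagenet', root='/data/imagenet', split='validation')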

View File

@@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
-take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
-feat_idx = 0 # stem is index 0
x = self.stem(x)
-if feat_idx in take_indices:
-intermediates.append(x)
+last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
-stages = self.stages[:max_index]
+stages = self.stages[:max_index + 1]
-for stage in stages:
+for feat_idx, stage in enumerate(stages):
-feat_idx += 1
x = stage(x)
if feat_idx in take_indices:
-# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
-intermediates.append(x)
+if norm and feat_idx == last_idx:
+intermediates.append(self.norm_pre(x))
+else:
+intermediates.append(x)
if intermediates_only:
return intermediates
-x = self.norm_pre(x)
+if feat_idx == last_idx:
+x = self.norm_pre(x)
return x, intermediates
@@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
):
""" Prune layers not required for specified intermediates.
"""
-take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+take_indices, max_index = feature_take_indices(len(self.stages), indices)
-self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
+self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
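With the two hunks above, ConvNeXt intermediate indices now count stages only (0..len(stages)-1) and the stem output is no longer emitted as index 0. A hedged sketch of the new indexing, using 'convnext_tiny' as an example model name:

import timm
import torch

model = timm.create_model('convnext_tiny')
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=(0, 1, 2, 3),      # four stage outputs; the stem is no longer index 0
    intermediates_only=True,
)
print([f.shape for f in feats])  # strides 4, 8, 16, 32 for a 224x224 input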

View File

@@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ class DaVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
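For the models gaining forward_intermediates() in this merge (DaViT here, and the analogous additions below), typical usage looks like the following hedged sketch; 'davit_tiny' is just an illustrative model name:

import timm
import torch

model = timm.create_model('davit_tiny').eval()
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=2,                # an int means "take the last two stage outputs"
    intermediates_only=True,  # skip the final feature/classifier return value
)
print([f.shape for f in feats])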

View File

@@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
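prune_intermediate_layers() pairs with the method above: it truncates the stage stack (and optionally the norm / classifier head) to just what the requested intermediates need. A hedged sketch, with 'edgenext_small' used as an example name:

import timm
import torch

model = timm.create_model('edgenext_small')
kept = model.prune_intermediate_layers(indices=(0, 1), prune_norm=True, prune_head=True)
# only the first two stages remain; later stages, norm_pre and the classifier are dropped
feats = model.forward_intermediates(torch.randn(1, 3, 288, 288), indices=kept, intermediates_only=True)
print(kept, [f.shape for f in feats])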

View File

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
"""
import math
from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -625,6 +626,73 @@ class EfficientFormerV2(nn.Module):
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)

View File

@@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
"""
__all__ = ['EfficientVit', 'EfficientVitLarge']
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
from functools import partial
import torch
@@ -17,6 +17,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -754,6 +755,63 @@ class EfficientVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -17,6 +17,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module):
self.head = NormLinear(
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@@ -18,7 +18,7 @@ This impl is/has:
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
from functools import partial
-from typing import Callable, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,6 +26,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model
@@ -458,6 +459,72 @@ class FocalNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.layers), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.layers
else:
stages = self.layers[:max_index + 1]
last_idx = len(self.layers) - 1
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.layers), indices)
self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.layers(x)

View File

@@ -30,6 +30,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint
from ._registry import register_model, generate_default_cfgs
@@ -397,7 +398,7 @@ class GlobalContextVit(nn.Module):
act_layer = get_act_layer(act_layer)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.feature_info = []
img_size = to_2tuple(img_size)
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
self.global_pool = global_pool
@@ -441,6 +442,7 @@ class GlobalContextVit(nn.Module):
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
))
self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
# Classifier head
@@ -494,6 +496,62 @@ class GlobalContextVit(nn.Module):
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
x = self.stages(x)
@@ -509,9 +567,11 @@ class GlobalContextVit(nn.Module):
def _create_gcvit(variant, pretrained=False, **kwargs):
-if kwargs.get('features_only', None):
-raise RuntimeError('features_only not implemented for Vision Transformer models.')
-model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
+model = build_model_with_cfg(
+GlobalContextVit, variant, pretrained,
+feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
+**kwargs
+)
return model
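With feature_cfg wired up above (and feature_info populated in __init__), GCViT variants can now be built as feature backbones instead of raising the old features_only RuntimeError. A hedged sketch, with 'gcvit_tiny' as an example name:

import timm
import torch

backbone = timm.create_model('gcvit_tiny', features_only=True, out_indices=(0, 1, 2, 3))
feats = backbone(torch.randn(1, 3, 224, 224))
print(backbone.feature_info.channels())   # per-stage channel counts
print([f.shape for f in feats])           # one NCHW map per requested stage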

View File

@@ -6,7 +6,7 @@ The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blo
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -15,6 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
from ._manipulate import checkpoint_seq
@@ -508,6 +509,62 @@ class HighPerfGpuNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, 'avg')
return take_indices
def forward_features(self, x):
x = self.stem(x)
return self.stages(x)

View File

@@ -4,7 +4,7 @@ Original implementation & weights from: https://github.com/sail-sg/inceptionnext
"""
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -12,6 +12,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -349,6 +350,62 @@ class MetaNeXt(nn.Module):
def no_weight_decay(self):
return set()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, 'avg')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)

View File

@@ -3,6 +3,7 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License)
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
"""
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -10,6 +11,7 @@ import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_classifier, ConvNormAct
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
__all__ = ['InceptionV4']
@@ -285,6 +287,66 @@ class InceptionV4(nn.Module):
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.features
else:
stages = self.features[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
return self.features(x)
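InceptionV4 keeps every block in a single nn.Sequential, so the code above translates stage-level indices into block indices via the 'module' strings in feature_info. A hedged sketch of that mapping with illustrative values:

# Values below are illustrative placeholders, not copied from the model definition.
feature_info = [
    dict(num_chs=64, reduction=2, module='features.2'),
    dict(num_chs=160, reduction=4, module='features.3'),
    dict(num_chs=384, reduction=8, module='features.9'),
    dict(num_chs=1024, reduction=16, module='features.17'),
    dict(num_chs=1536, reduction=32, module='features.21'),
]
stage_ends = [int(info['module'].split('.')[-1]) for info in feature_info]
print(stage_ends)  # e.g. [2, 3, 9, 17, 21]; indices=1 would then take the last block's output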

View File

@@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from collections import OrderedDict
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -14,6 +14,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -417,6 +418,67 @@ class MambaOut(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
channel_first = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if channel_first:
# reshape to BCHW output format
intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
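Unlike the other additions in this merge, MambaOut runs channels-last internally, so its forward_intermediates() also accepts output_fmt='NHWC' and otherwise permutes intermediates to NCHW. A hedged sketch, using 'mambaout_tiny' as an example name:

import timm
import torch

model = timm.create_model('mambaout_tiny')
x = torch.randn(1, 3, 224, 224)
nchw = model.forward_intermediates(x, intermediates_only=True)                     # default NCHW layout
nhwc = model.forward_intermediates(x, output_fmt='NHWC', intermediates_only=True)  # native channels-last
print(nchw[0].shape, nhwc[0].shape)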

View File

@@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
if intermediates_only:
return intermediates
-x = self.norm(x)
+if feat_idx == last_idx:
+x = self.norm(x)
return x, intermediates

View File

@@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
from collections import OrderedDict
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
final = nn.Identity()
self.head.fc = final
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_head(self, x: Tensor, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)

View File

@@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
-for i, stage in enumerate(self.stages):
+last_idx = len(self.stages) - 1
+for feat_idx, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
-if i in take_indices:
+if feat_idx in take_indices:
-if norm and i == (len(self.stages) - 1):
+if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm last intermediate
else:
x_inter = x
@@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
if intermediates_only:
return intermediates
-x = self.norm(x)
+if feat_idx == last_idx:
+x = self.norm(x)
return x, intermediates

View File

@@ -19,6 +19,7 @@ import collections.abc
import logging
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -420,6 +422,73 @@ class Nest(nn.Module):
self.global_pool, self.head = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.levels), indices)
# forward pass
x = self.patch_embed(x)
last_idx = len(self.num_blocks) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.levels
else:
stages = self.levels[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
# Layer norm done over channel dim only (to NHWC and back)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.levels), indices)
self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.levels(x)
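Nest's final LayerNorm is applied over the channel dimension only, which is why the new code above permutes NCHW to NHWC around the norm. A small standalone sketch of that round trip:

import torch
import torch.nn as nn

norm = nn.LayerNorm(512)                              # normalizes the trailing (channel) dim
x = torch.randn(2, 512, 14, 14)                       # NCHW feature map
x = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # to NHWC, normalize, back to NCHW
print(x.shape)  # torch.Size([2, 512, 14, 14])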

View File

@@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -560,6 +561,72 @@ class NextViT(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
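For reference, a minimal usage sketch (not part of this commit) of the NextViT API added above; the `nextvit_small` variant name and 224x224 input size are assumptions:

import torch
import timm

# Hypothetical usage sketch: collect NCHW intermediates from every stage plus the final features.
model = timm.create_model('nextvit_small', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, intermediates = model.forward_intermediates(x)  # indices=None -> all stages
print(final.shape)                      # final feature map, with self.norm applied
for i, feat in enumerate(intermediates):
    print(i, feat.shape)                # one NCHW tensor per stage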

View File

@@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block
@@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
if self.head_dist is not None:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
# forward pass
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
last_idx = len(self.transformers) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.transformers
else:
stages = self.transformers[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x, cls_tokens = stage((x, cls_tokens))
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
cls_tokens = self.norm(cls_tokens)
return cls_tokens, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
@@ -314,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
-feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
+feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
**kwargs,
)
return model
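A small illustrative sketch of the PiT changes; `pit_s_224` is an assumed variant name, and the spatial NCHW intermediates come back alongside the (normed) class tokens:

import torch
import timm

# Hypothetical sketch: select specific PiT transformer stages by index.
model = timm.create_model('pit_s_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
cls_tokens, feats = model.forward_intermediates(x, indices=[0, 2])  # stages 0 and 2 only
print(cls_tokens.shape)      # final output here is the normed class token sequence
for f in feats:
    print(f.shape)           # spatial NCHW intermediates from the selected stages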

View File

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs
@@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stages(x)
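As a rough sketch of the `intermediates_only` path added for PVT-V2 (the `pvt_v2_b0` model name is an assumption):

import torch
import timm

# Hypothetical sketch: skip the final/classifier path entirely and only collect stage outputs.
model = timm.create_model('pvt_v2_b0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=3, intermediates_only=True)  # last 3 stages
for f in feats:
    print(f.shape)  # NCHW feature maps; the method returns early, no final output is built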

View File

@@ -302,29 +302,33 @@ class RDNet(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
-take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
+stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
+take_indices, max_index = feature_take_indices(len(stage_ends), indices)
+take_indices = [stage_ends[i] for i in take_indices]
+max_index = stage_ends[max_index]
# forward pass
-feat_idx = 0  # stem is index 0
x = self.stem(x)
-if feat_idx in take_indices:
-intermediates.append(x)
+last_idx = len(self.dense_stages) - 1
if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
dense_stages = self.dense_stages
else:
-dense_stages = self.dense_stages[:max_index]
+dense_stages = self.dense_stages[:max_index + 1]
-for stage in dense_stages:
-feat_idx += 1
+for feat_idx, stage in enumerate(dense_stages):
x = stage(x)
if feat_idx in take_indices:
-# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
-intermediates.append(x)
+if norm and feat_idx == last_idx:
+x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+else:
+x_inter = x
+intermediates.append(x_inter)
if intermediates_only:
return intermediates
-x = self.norm_pre(x)
+if feat_idx == last_idx:
+x = self.norm_pre(x)
return x, intermediates
@@ -336,8 +340,10 @@ class RDNet(nn.Module):
):
""" Prune layers not required for specified intermediates.
"""
-take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
-self.dense_stages = self.dense_stages[:max_index]  # truncate blocks w/ stem as idx 0
+stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
+take_indices, max_index = feature_take_indices(len(stage_ends), indices)
+max_index = stage_ends[max_index]
+self.dense_stages = self.dense_stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
@@ -355,6 +361,7 @@ class RDNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.dense_stages(x)
+x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
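An informal sketch of how the reworked RDNet indexing might behave; the `rdnet_tiny` variant name is an assumption:

import torch
import timm

# Hypothetical sketch: RDNet now maps feature_info entries onto dense_stages indices, and with
# norm=True the final norm_pre is also applied to the last returned intermediate.
model = timm.create_model('rdnet_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x, norm=True)
print([tuple(f.shape) for f in feats])
print(final.shape)  # forward_features now runs norm_pre as well, matching this output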

View File

@@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost
"""
import copy
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
from ._builder import build_model_with_cfg
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -294,6 +295,72 @@ class RepGhostNet(nn.Module):
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i]+1 for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
feat_idx = 0
x = self.conv_stem(x)
if feat_idx in take_indices:
intermediates.append(x)
x = self.bn1(x)
x = self.act1(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.blocks
else:
stages = self.blocks[:max_index + 1]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
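A hedged sketch of the RepGhostNet stem-as-index-0 convention introduced above (`repghostnet_100` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: for RepGhostNet the conv stem output is feature index 0, so indices=[0]
# returns a single early feature map taken right after conv_stem.
model = timm.create_model('repghostnet_100', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x, indices=[0])
print(feats[0].shape)  # stem features (captured before bn1/act1 in this implementation)
print(final.shape)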

View File

@@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
Adapted from official impl at https://github.com/jameslahm/RepViT
"""
-__all__ = ['RepVit']
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -24,9 +22,12 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
+__all__ = ['RepVit']
class ConvNorm(nn.Sequential):
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
@@ -333,6 +334,62 @@ class RepVit(nn.Module):
def set_distilled_training(self, enable=True):
self.head.distilled_training = enable
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
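A speculative sketch of the `stop_early` path for RepVit; the `repvit_m1` variant name is an assumption:

import torch
import timm

# Hypothetical sketch: with stop_early=True (outside torchscript) RepVit stops iterating once
# the deepest requested stage has been produced.
model = timm.create_model('repvit_m1', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=[0, 1], stop_early=True, intermediates_only=True)
print(len(feats))  # 2, later stages were never run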

View File

@@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig
from collections import OrderedDict  # pylint: disable=g-importing-member
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -543,6 +544,79 @@ class ResNetV2(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
H, W = x.shape[-2:]
for stem in self.stem:
x = stem(x)
if x.shape[-2:] == (H //2, W //2):
x_down = x
if feat_idx in take_indices:
intermediates.append(x_down)
last_idx = len(self.stages)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
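A small sketch, under the assumption that `resnetv2_50` is a valid variant name, of the `norm=True` behaviour added for ResNetV2:

import torch
import timm

# Hypothetical sketch: with norm=True the final pre-activation norm is folded into the last
# intermediate, so it should line up with the returned final features.
model = timm.create_model('resnetv2_50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x, norm=True)
print(torch.allclose(feats[-1], final))  # expected True in eval mode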

View File

@@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman
from functools import partial
from math import ceil
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
from ._builder import build_model_with_cfg
from ._efficientnet_builder import efficientnet_init_weights
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -234,6 +235,67 @@ class RexNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.features
else:
stages = self.features[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
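An illustrative sketch of `prune_intermediate_layers` for RexNet (`rexnet_100` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: prune everything not needed for the first two feature stages.
model = timm.create_model('rexnet_100', pretrained=False)
kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
print(kept)  # feature indices that remain available after truncating self.features
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=[0, 1], intermediates_only=True)
print([tuple(f.shape) for f in feats])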

View File

@@ -10,7 +10,7 @@ __all__ = ['TinyVit']
import itertools
from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -536,6 +537,62 @@ class TinyVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
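A brief sketch for TinyVit; note the implementation above only accepts `output_fmt='NCHW'` (`tiny_vit_5m_224` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: request all stage outputs explicitly in NCHW format.
model = timm.create_model('tiny_vit_5m_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, output_fmt='NCHW', intermediates_only=True)
print([tuple(f.shape) for f in feats])  # any other output_fmt trips the assert above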

View File

@@ -7,13 +7,14 @@ Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -228,6 +229,65 @@ class TResNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
stage_ends = [1, 2, 3, 4, 5]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.body
else:
stages = self.body[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
stage_ends = [1, 2, 3, 4, 5]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = self.body.s2d(x)
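A rough sketch of the TResNet indexing above, which appears to map external indices 0..4 onto body children 1..5, skipping the SpaceToDepth module at position 0 (`tresnet_m` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: take the last two feature stages by passing an int for indices.
model = timm.create_model('tresnet_m', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x, indices=2)  # last two stages
print([tuple(f.shape) for f in feats])
print(final.shape)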

View File

@@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
Hacked together by / Copyright 2020 Ross Wightman
"""
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
create_attn, create_norm_act_layer
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -264,6 +265,67 @@ class VovNet(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
x = self.stem[:-1](x)
if feat_idx in take_indices:
intermediates.append(x)
x = self.stem[-1](x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for feat_idx, stage in enumerate(stages, start=1):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
return self.stages(x)
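A hedged sketch of the VovNet stem handling above, where feature index 0 is taken from `self.stem[:-1]` before the final stem conv (`ese_vovnet19b_dw` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: grab the early stem feature and the last OSA stage output.
model = timm.create_model('ese_vovnet19b_dw', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x, indices=[0, 4])
print(feats[0].shape)  # partial-stem feature (taken before the last stem conv)
print(feats[1].shape)  # last stage output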

View File

@@ -494,7 +494,8 @@ class Xcit(nn.Module):
# NOTE not supporting return of class tokens
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
x = self.norm(x)
return x, intermediates
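Finally, a tentative sketch of the Xcit fix above; in eval mode the returned features should now match `forward_features` (`xcit_nano_12_p16_224` is an assumed variant name):

import torch
import timm

# Hypothetical sketch: the class-attention output now has the final norm applied before return.
model = timm.create_model('xcit_nano_12_p16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, feats = model.forward_intermediates(x)
print(final.shape)                                   # normed token sequence (cls + patch tokens)
print(torch.allclose(final, model.forward_features(x)))  # expected True after this change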