mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'main' into tnt
This commit is contained in:
commit
69b1fbcdc1
@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
|
||||
|
||||
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
|
||||
|
||||
[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
|
||||
[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
|
||||
|
||||
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
|
||||
|
||||
|
@ -54,6 +54,9 @@ FEAT_INTER_FILTERS = [
|
||||
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
||||
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
|
||||
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
|
||||
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
|
||||
'davit', 'rdnet', 'convnext', 'pit'
|
||||
]
|
||||
|
||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||
@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
||||
spatial_axis = get_spatial_dim(output_fmt)
|
||||
import math
|
||||
|
||||
inpt = torch.randn((batch_size, *input_size))
|
||||
output, intermediates = model.forward_intermediates(
|
||||
torch.randn((batch_size, *input_size)),
|
||||
inpt,
|
||||
output_fmt=output_fmt,
|
||||
)
|
||||
assert len(expected_channels) == len(intermediates)
|
||||
@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
||||
assert o.shape[0] == batch_size
|
||||
assert not torch.isnan(o).any()
|
||||
|
||||
output2 = model.forward_features(inpt)
|
||||
assert torch.allclose(output, output2)
|
||||
|
||||
|
||||
def _create_fx_model(model, train=False):
|
||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||
|
@ -144,6 +144,7 @@ def create_dataset(
|
||||
use_train = split in _TRAIN_SYNONYM
|
||||
ds = QMNIST(train=use_train, **torch_kwargs)
|
||||
elif name == 'imagenet':
|
||||
torch_kwargs.pop('download')
|
||||
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
|
||||
if split in _EVAL_SYNONYM:
|
||||
split = 'val'
|
||||
|
@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
feat_idx = 0 # stem is index 0
|
||||
x = self.stem(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index]
|
||||
for stage in stages:
|
||||
feat_idx += 1
|
||||
stages = self.stages[:max_index + 1]
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
|
||||
intermediates.append(x)
|
||||
if norm and feat_idx == last_idx:
|
||||
intermediates.append(self.norm_pre(x))
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm_pre(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm_pre(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
||||
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm_pre = nn.Identity()
|
||||
if prune_head:
|
||||
|
@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the MIT license
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
|
||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
@ -636,6 +637,72 @@ class DaVit(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
intermediates.append(x_inter)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm_pre(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm_pre = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
|
||||
NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
intermediates.append(x_inter)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm_pre(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm_pre = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
|
||||
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
@ -625,6 +626,73 @@ class EfficientFormerV2(nn.Module):
|
||||
def set_distilled_training(self, enable=True):
|
||||
self.distilled_training = enable
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if feat_idx == last_idx:
|
||||
x_inter = self.norm(x) if norm else x
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
|
@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
|
||||
"""
|
||||
|
||||
__all__ = ['EfficientVit', 'EfficientVitLarge']
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
@ -17,6 +17,7 @@ import torch.nn.functional as F
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -754,6 +755,63 @@ class EfficientVit(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
|
||||
__all__ = ['EfficientVitMsra']
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -17,6 +17,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module):
|
||||
self.head = NormLinear(
|
||||
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -18,7 +18,7 @@ This impl is/has:
|
||||
# Written by Jianwei Yang (jianwyan@microsoft.com)
|
||||
# --------------------------------------------------------
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -26,6 +26,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import named_apply, checkpoint
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
@ -458,6 +459,72 @@ class FocalNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.layers), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.layers
|
||||
else:
|
||||
stages = self.layers[:max_index + 1]
|
||||
|
||||
last_idx = len(self.layers) - 1
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm(x) # applying final norm to last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
intermediates.append(x_inter)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.layers), indices)
|
||||
self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.layers(x)
|
||||
|
@ -30,6 +30,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
||||
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply, checkpoint
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -397,7 +398,7 @@ class GlobalContextVit(nn.Module):
|
||||
act_layer = get_act_layer(act_layer)
|
||||
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
||||
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
||||
|
||||
self.feature_info = []
|
||||
img_size = to_2tuple(img_size)
|
||||
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
||||
self.global_pool = global_pool
|
||||
@ -441,6 +442,7 @@ class GlobalContextVit(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
norm_layer_cl=norm_layer_cl,
|
||||
))
|
||||
self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
# Classifier head
|
||||
@ -494,6 +496,62 @@ class GlobalContextVit(nn.Module):
|
||||
global_pool = self.head.global_pool.pool_type
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
@ -509,9 +567,11 @@ class GlobalContextVit(nn.Module):
|
||||
|
||||
|
||||
def _create_gcvit(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
|
||||
model = build_model_with_cfg(
|
||||
GlobalContextVit, variant, pretrained,
|
||||
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blo
|
||||
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
|
||||
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -15,6 +15,7 @@ import torch.nn.functional as F
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from ._manipulate import checkpoint_seq
|
||||
|
||||
@ -508,6 +509,62 @@ class HighPerfGpuNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, 'avg')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
return self.stages(x)
|
||||
|
@ -4,7 +4,7 @@ Original implementation & weights from: https://github.com/sail-sg/inceptionnext
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -12,6 +12,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -349,6 +350,62 @@ class MetaNeXt(nn.Module):
|
||||
def no_weight_decay(self):
|
||||
return set()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, 'avg')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
|
@ -3,6 +3,7 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License)
|
||||
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,6 +11,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import create_classifier, ConvNormAct
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
__all__ = ['InceptionV4']
|
||||
@ -285,6 +287,66 @@ class InceptionV4(nn.Module):
|
||||
self.global_pool, self.last_linear = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
take_indices = [stage_ends[i] for i in take_indices]
|
||||
max_index = stage_ends[max_index]
|
||||
|
||||
# forward pass
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.features
|
||||
else:
|
||||
stages = self.features[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
max_index = stage_ends[max_index]
|
||||
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
return self.features(x)
|
||||
|
||||
|
@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
|
||||
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -14,6 +14,7 @@ from torch import nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -417,6 +418,67 @@ class MambaOut(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
|
||||
channel_first = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if channel_first:
|
||||
# reshape to BCHW output format
|
||||
intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
|
@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
|
||||
use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
|
||||
final = nn.Identity()
|
||||
self.head.fc = final
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_head(self, x: Tensor, pre_logits: bool = False):
|
||||
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
||||
x = self.head.global_pool(x)
|
||||
|
@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
||||
for i, stage in enumerate(self.stages):
|
||||
last_idx = len(self.stages) - 1
|
||||
for feat_idx, stage in enumerate(self.stages):
|
||||
x, feat_size = stage(x, feat_size)
|
||||
if i in take_indices:
|
||||
if norm and i == (len(self.stages) - 1):
|
||||
if feat_idx in take_indices:
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm(x) # applying final norm last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -19,6 +19,7 @@ import collections.abc
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
|
||||
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq, named_apply
|
||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||
@ -420,6 +422,73 @@ class Nest(nn.Module):
|
||||
self.global_pool, self.head = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
last_idx = len(self.num_blocks) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.levels
|
||||
else:
|
||||
stages = self.levels[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
# Layer norm done over channel dim only (to NHWC and back)
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||
self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.levels(x)
|
||||
|
@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
|
||||
"""
|
||||
# Copyright (c) ByteDance Inc. All rights reserved.
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
|
||||
from timm.layers import ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
@ -560,6 +561,72 @@ class NextViT(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if feat_idx == last_idx:
|
||||
x_inter = self.norm(x) if norm else x
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -22,6 +22,7 @@ from torch import nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from .vision_transformer import Block
|
||||
|
||||
@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
|
||||
if self.head_dist is not None:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
|
||||
last_idx = len(self.transformers) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.transformers
|
||||
else:
|
||||
stages = self.transformers[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x, cls_tokens = stage((x, cls_tokens))
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
cls_tokens = self.norm(cls_tokens)
|
||||
|
||||
return cls_tokens, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
@ -314,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
|
||||
feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -25,6 +25,7 @@ import torch.nn.functional as F
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.stages(x)
|
||||
|
@ -302,29 +302,33 @@ class RDNet(nn.Module):
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
take_indices = [stage_ends[i] for i in take_indices]
|
||||
max_index = stage_ends[max_index]
|
||||
|
||||
# forward pass
|
||||
feat_idx = 0 # stem is index 0
|
||||
x = self.stem(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
last_idx = len(self.dense_stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
dense_stages = self.dense_stages
|
||||
else:
|
||||
dense_stages = self.dense_stages[:max_index]
|
||||
for stage in dense_stages:
|
||||
feat_idx += 1
|
||||
dense_stages = self.dense_stages[:max_index + 1]
|
||||
for feat_idx, stage in enumerate(dense_stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
|
||||
intermediates.append(x)
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
intermediates.append(x_inter)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm_pre(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm_pre(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
@ -336,8 +340,10 @@ class RDNet(nn.Module):
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
||||
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
max_index = stage_ends[max_index]
|
||||
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm_pre = nn.Identity()
|
||||
if prune_head:
|
||||
@ -355,6 +361,7 @@ class RDNet(nn.Module):
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.dense_stages(x)
|
||||
x = self.norm_pre(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
|
@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost
|
||||
"""
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -294,6 +295,72 @@ class RepGhostNet(nn.Module):
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
take_indices = [stage_ends[i]+1 for i in take_indices]
|
||||
max_index = stage_ends[max_index]
|
||||
|
||||
# forward pass
|
||||
feat_idx = 0
|
||||
x = self.conv_stem(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.blocks
|
||||
else:
|
||||
stages = self.blocks[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages, start=1):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
max_index = stage_ends[max_index]
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
|
@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
|
||||
|
||||
Adapted from official impl at https://github.com/jameslahm/RepViT
|
||||
"""
|
||||
|
||||
__all__ = ['RepVit']
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -24,9 +22,12 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
__all__ = ['RepVit']
|
||||
|
||||
|
||||
class ConvNorm(nn.Sequential):
|
||||
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
@ -333,6 +334,62 @@ class RepVit(nn.Module):
|
||||
def set_distilled_training(self, enable=True):
|
||||
self.head.distilled_training = enable
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig
|
||||
|
||||
from collections import OrderedDict # pylint: disable=g-importing-member
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
|
||||
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
@ -543,6 +544,79 @@ class ResNetV2(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(5, indices)
|
||||
|
||||
# forward pass
|
||||
feat_idx = 0
|
||||
H, W = x.shape[-2:]
|
||||
for stem in self.stem:
|
||||
x = stem(x)
|
||||
if x.shape[-2:] == (H //2, W //2):
|
||||
x_down = x
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x_down)
|
||||
last_idx = len(self.stages)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index]
|
||||
|
||||
for feat_idx, stage in enumerate(stages, start=1):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
if feat_idx == last_idx:
|
||||
x_inter = self.norm(x) if norm else x
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(5, indices)
|
||||
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman
|
||||
|
||||
from functools import partial
|
||||
from math import ceil
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._efficientnet_builder import efficientnet_init_weights
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
@ -234,6 +235,67 @@ class RexNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
take_indices = [stage_ends[i] for i in take_indices]
|
||||
max_index = stage_ends[max_index]
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.features
|
||||
else:
|
||||
stages = self.features[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
max_index = stage_ends[max_index]
|
||||
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -10,7 +10,7 @@ __all__ = ['TinyVit']
|
||||
|
||||
import itertools
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
|
||||
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -536,6 +537,62 @@ class TinyVit(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
@ -7,13 +7,14 @@ Original model: https://github.com/mrT23/TResNet
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||
|
||||
@ -228,6 +229,65 @@ class TResNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
stage_ends = [1, 2, 3, 4, 5]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
take_indices = [stage_ends[i] for i in take_indices]
|
||||
max_index = stage_ends[max_index]
|
||||
# forward pass
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.body
|
||||
else:
|
||||
stages = self.body[:max_index + 1]
|
||||
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
stage_ends = [1, 2, 3, 4, 5]
|
||||
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||
max_index = stage_ends[max_index]
|
||||
self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = self.body.s2d(x)
|
||||
|
@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
|
||||
create_attn, create_norm_act_layer
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
@ -264,6 +265,67 @@ class VovNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to compatible intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(5, indices)
|
||||
|
||||
# forward pass
|
||||
feat_idx = 0
|
||||
x = self.stem[:-1](x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
x = self.stem[-1](x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
stages = self.stages[:max_index]
|
||||
|
||||
for feat_idx, stage in enumerate(stages, start=1):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(5, indices)
|
||||
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||
if prune_head:
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
return self.stages(x)
|
||||
|
@ -494,7 +494,8 @@ class Xcit(nn.Module):
|
||||
# NOTE not supporting return of class tokens
|
||||
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
||||
for blk in self.cls_attn_blocks:
|
||||
x = blk(x)
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
Loading…
x
Reference in New Issue
Block a user