mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'main' into tnt
This commit is contained in:
commit
69b1fbcdc1
@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
|
|||||||
|
|
||||||
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
|
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
|
||||||
|
|
||||||
[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
|
[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
|
||||||
|
|
||||||
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
|
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
|
||||||
|
|
||||||
|
@ -54,6 +54,9 @@ FEAT_INTER_FILTERS = [
|
|||||||
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
||||||
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
||||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
|
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
|
||||||
|
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
|
||||||
|
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
|
||||||
|
'davit', 'rdnet', 'convnext', 'pit'
|
||||||
]
|
]
|
||||||
|
|
||||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||||
@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||||||
spatial_axis = get_spatial_dim(output_fmt)
|
spatial_axis = get_spatial_dim(output_fmt)
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
inpt = torch.randn((batch_size, *input_size))
|
||||||
output, intermediates = model.forward_intermediates(
|
output, intermediates = model.forward_intermediates(
|
||||||
torch.randn((batch_size, *input_size)),
|
inpt,
|
||||||
output_fmt=output_fmt,
|
output_fmt=output_fmt,
|
||||||
)
|
)
|
||||||
assert len(expected_channels) == len(intermediates)
|
assert len(expected_channels) == len(intermediates)
|
||||||
@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||||||
assert o.shape[0] == batch_size
|
assert o.shape[0] == batch_size
|
||||||
assert not torch.isnan(o).any()
|
assert not torch.isnan(o).any()
|
||||||
|
|
||||||
|
output2 = model.forward_features(inpt)
|
||||||
|
assert torch.allclose(output, output2)
|
||||||
|
|
||||||
|
|
||||||
def _create_fx_model(model, train=False):
|
def _create_fx_model(model, train=False):
|
||||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||||
|
@ -144,6 +144,7 @@ def create_dataset(
|
|||||||
use_train = split in _TRAIN_SYNONYM
|
use_train = split in _TRAIN_SYNONYM
|
||||||
ds = QMNIST(train=use_train, **torch_kwargs)
|
ds = QMNIST(train=use_train, **torch_kwargs)
|
||||||
elif name == 'imagenet':
|
elif name == 'imagenet':
|
||||||
|
torch_kwargs.pop('download')
|
||||||
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
|
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
|
||||||
if split in _EVAL_SYNONYM:
|
if split in _EVAL_SYNONYM:
|
||||||
split = 'val'
|
split = 'val'
|
||||||
|
@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
intermediates = []
|
intermediates = []
|
||||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
feat_idx = 0 # stem is index 0
|
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if feat_idx in take_indices:
|
|
||||||
intermediates.append(x)
|
|
||||||
|
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
stages = self.stages
|
stages = self.stages
|
||||||
else:
|
else:
|
||||||
stages = self.stages[:max_index]
|
stages = self.stages[:max_index + 1]
|
||||||
for stage in stages:
|
for feat_idx, stage in enumerate(stages):
|
||||||
feat_idx += 1
|
|
||||||
x = stage(x)
|
x = stage(x)
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
|
if norm and feat_idx == last_idx:
|
||||||
intermediates.append(x)
|
intermediates.append(self.norm_pre(x))
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm_pre(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm_pre(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
|
|||||||
):
|
):
|
||||||
""" Prune layers not required for specified intermediates.
|
""" Prune layers not required for specified intermediates.
|
||||||
"""
|
"""
|
||||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
if prune_norm:
|
if prune_norm:
|
||||||
self.norm_pre = nn.Identity()
|
self.norm_pre = nn.Identity()
|
||||||
if prune_head:
|
if prune_head:
|
||||||
|
@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
# This source code is licensed under the MIT license
|
# This source code is licensed under the MIT license
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
|
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
|
||||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
@ -636,6 +637,72 @@ class DaVit(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if norm and feat_idx == last_idx:
|
||||||
|
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||||
|
else:
|
||||||
|
x_inter = x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm_pre(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm_pre = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
|
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
|
||||||
NormMlpClassifierHead, ClassifierHead
|
NormMlpClassifierHead, ClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_module
|
from ._features_fx import register_notrace_module
|
||||||
from ._manipulate import named_apply, checkpoint_seq
|
from ._manipulate import named_apply, checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if norm and feat_idx == last_idx:
|
||||||
|
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||||
|
else:
|
||||||
|
x_inter = x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm_pre(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm_pre = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
|
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
|
||||||
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
|
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
|
|
||||||
@ -625,6 +626,73 @@ class EfficientFormerV2(nn.Module):
|
|||||||
def set_distilled_training(self, enable=True):
|
def set_distilled_training(self, enable=True):
|
||||||
self.distilled_training = enable
|
self.distilled_training = enable
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x_inter = self.norm(x) if norm else x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ['EfficientVit', 'EfficientVitLarge']
|
__all__ = ['EfficientVit', 'EfficientVitLarge']
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -17,6 +17,7 @@ import torch.nn.functional as F
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
|
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_module
|
from ._features_fx import register_notrace_module
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
@ -754,6 +755,63 @@ class EfficientVit(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
|
|||||||
__all__ = ['EfficientVitMsra']
|
__all__ = ['EfficientVitMsra']
|
||||||
import itertools
|
import itertools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -17,6 +17,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
|
from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module):
|
|||||||
self.head = NormLinear(
|
self.head = NormLinear(
|
||||||
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -18,7 +18,7 @@ This impl is/has:
|
|||||||
# Written by Jianwei Yang (jianwyan@microsoft.com)
|
# Written by Jianwei Yang (jianwyan@microsoft.com)
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -26,6 +26,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
|
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import named_apply, checkpoint
|
from ._manipulate import named_apply, checkpoint
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
|
|
||||||
@ -458,6 +459,72 @@ class FocalNet(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.layers), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.layers
|
||||||
|
else:
|
||||||
|
stages = self.layers[:max_index + 1]
|
||||||
|
|
||||||
|
last_idx = len(self.layers) - 1
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if norm and feat_idx == last_idx:
|
||||||
|
x_inter = self.norm(x) # applying final norm to last intermediate
|
||||||
|
else:
|
||||||
|
x_inter = x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.layers), indices)
|
||||||
|
self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
|
@ -30,6 +30,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
||||||
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import named_apply, checkpoint
|
from ._manipulate import named_apply, checkpoint
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
@ -397,7 +398,7 @@ class GlobalContextVit(nn.Module):
|
|||||||
act_layer = get_act_layer(act_layer)
|
act_layer = get_act_layer(act_layer)
|
||||||
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
||||||
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
||||||
|
self.feature_info = []
|
||||||
img_size = to_2tuple(img_size)
|
img_size = to_2tuple(img_size)
|
||||||
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
@ -441,6 +442,7 @@ class GlobalContextVit(nn.Module):
|
|||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
norm_layer_cl=norm_layer_cl,
|
norm_layer_cl=norm_layer_cl,
|
||||||
))
|
))
|
||||||
|
self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')]
|
||||||
self.stages = nn.Sequential(*stages)
|
self.stages = nn.Sequential(*stages)
|
||||||
|
|
||||||
# Classifier head
|
# Classifier head
|
||||||
@ -494,6 +496,62 @@ class GlobalContextVit(nn.Module):
|
|||||||
global_pool = self.head.global_pool.pool_type
|
global_pool = self.head.global_pool.pool_type
|
||||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
@ -509,9 +567,11 @@ class GlobalContextVit(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _create_gcvit(variant, pretrained=False, **kwargs):
|
def _create_gcvit(variant, pretrained=False, **kwargs):
|
||||||
if kwargs.get('features_only', None):
|
model = build_model_with_cfg(
|
||||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
GlobalContextVit, variant, pretrained,
|
||||||
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
|
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blo
|
|||||||
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
|
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
|
||||||
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
|
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
|
||||||
"""
|
"""
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -15,6 +15,7 @@ import torch.nn.functional as F
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
|
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
|
|
||||||
@ -508,6 +509,62 @@ class HighPerfGpuNet(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, 'avg')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
return self.stages(x)
|
return self.stages(x)
|
||||||
|
@ -4,7 +4,7 @@ Original implementation & weights from: https://github.com/sail-sg/inceptionnext
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -12,6 +12,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
|
from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -349,6 +350,62 @@ class MetaNeXt(nn.Module):
|
|||||||
def no_weight_decay(self):
|
def no_weight_decay(self):
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, 'avg')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -3,6 +3,7 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License)
|
|||||||
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
||||||
"""
|
"""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -10,6 +11,7 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import create_classifier, ConvNormAct
|
from timm.layers import create_classifier, ConvNormAct
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
__all__ = ['InceptionV4']
|
__all__ = ['InceptionV4']
|
||||||
@ -285,6 +287,66 @@ class InceptionV4(nn.Module):
|
|||||||
self.global_pool, self.last_linear = create_classifier(
|
self.global_pool, self.last_linear = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i] for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.features
|
||||||
|
else:
|
||||||
|
stages = self.features[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
return self.features(x)
|
return self.features(x)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
|
|||||||
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
||||||
"""
|
"""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -14,6 +14,7 @@ from torch import nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -417,6 +418,67 @@ class MambaOut(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
|
||||||
|
channel_first = output_fmt == 'NCHW'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if channel_first:
|
||||||
|
# reshape to BCHW output format
|
||||||
|
intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
|
|||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
|
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
|
||||||
use_fused_attn
|
use_fused_attn
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
|
|
||||||
@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
|
|||||||
final = nn.Identity()
|
final = nn.Identity()
|
||||||
self.head.fc = final
|
self.head.fc = final
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_head(self, x: Tensor, pre_logits: bool = False):
|
def forward_head(self, x: Tensor, pre_logits: bool = False):
|
||||||
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
||||||
x = self.head.global_pool(x)
|
x = self.head.global_pool(x)
|
||||||
|
@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
|
|||||||
if self.pos_embed is not None:
|
if self.pos_embed is not None:
|
||||||
x = x + self.pos_embed
|
x = x + self.pos_embed
|
||||||
|
|
||||||
for i, stage in enumerate(self.stages):
|
last_idx = len(self.stages) - 1
|
||||||
|
for feat_idx, stage in enumerate(self.stages):
|
||||||
x, feat_size = stage(x, feat_size)
|
x, feat_size = stage(x, feat_size)
|
||||||
if i in take_indices:
|
if feat_idx in take_indices:
|
||||||
if norm and i == (len(self.stages) - 1):
|
if norm and feat_idx == last_idx:
|
||||||
x_inter = self.norm(x) # applying final norm last intermediate
|
x_inter = self.norm(x) # applying final norm last intermediate
|
||||||
else:
|
else:
|
||||||
x_inter = x
|
x_inter = x
|
||||||
@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
|
|||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ import collections.abc
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
|
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
|
||||||
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
|
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import checkpoint_seq, named_apply
|
from ._manipulate import checkpoint_seq, named_apply
|
||||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||||
@ -420,6 +422,73 @@ class Nest(nn.Module):
|
|||||||
self.global_pool, self.head = create_classifier(
|
self.global_pool, self.head = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
last_idx = len(self.num_blocks) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.levels
|
||||||
|
else:
|
||||||
|
stages = self.levels[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if norm and feat_idx == last_idx:
|
||||||
|
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
# Layer norm done over channel dim only (to NHWC and back)
|
||||||
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||||
|
self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.levels(x)
|
x = self.levels(x)
|
||||||
|
@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
|
|||||||
"""
|
"""
|
||||||
# Copyright (c) ByteDance Inc. All rights reserved.
|
# Copyright (c) ByteDance Inc. All rights reserved.
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
|
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
|
||||||
from timm.layers import ClassifierHead
|
from timm.layers import ClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
@ -560,6 +561,72 @@ class NextViT(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
last_idx = len(self.stages) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x_inter = self.norm(x) if norm else x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -22,6 +22,7 @@ from torch import nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, to_2tuple
|
from timm.layers import trunc_normal_, to_2tuple
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
from .vision_transformer import Block
|
from .vision_transformer import Block
|
||||||
|
|
||||||
@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
|
|||||||
if self.head_dist is not None:
|
if self.head_dist is not None:
|
||||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
x = self.pos_drop(x + self.pos_embed)
|
||||||
|
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||||
|
|
||||||
|
last_idx = len(self.transformers) - 1
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.transformers
|
||||||
|
else:
|
||||||
|
stages = self.transformers[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x, cls_tokens = stage((x, cls_tokens))
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
cls_tokens = self.norm(cls_tokens)
|
||||||
|
|
||||||
|
return cls_tokens, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||||
|
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.pos_drop(x + self.pos_embed)
|
x = self.pos_drop(x + self.pos_embed)
|
||||||
@ -314,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
|
|||||||
variant,
|
variant,
|
||||||
pretrained,
|
pretrained,
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||||||
feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
|
feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -25,6 +25,7 @@ import torch.nn.functional as F
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
|
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint
|
from ._manipulate import checkpoint
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -302,29 +302,33 @@ class RDNet(nn.Module):
|
|||||||
"""
|
"""
|
||||||
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
intermediates = []
|
intermediates = []
|
||||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i] for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
feat_idx = 0 # stem is index 0
|
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if feat_idx in take_indices:
|
|
||||||
intermediates.append(x)
|
|
||||||
|
|
||||||
|
last_idx = len(self.dense_stages) - 1
|
||||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
dense_stages = self.dense_stages
|
dense_stages = self.dense_stages
|
||||||
else:
|
else:
|
||||||
dense_stages = self.dense_stages[:max_index]
|
dense_stages = self.dense_stages[:max_index + 1]
|
||||||
for stage in dense_stages:
|
for feat_idx, stage in enumerate(dense_stages):
|
||||||
feat_idx += 1
|
|
||||||
x = stage(x)
|
x = stage(x)
|
||||||
if feat_idx in take_indices:
|
if feat_idx in take_indices:
|
||||||
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
|
if norm and feat_idx == last_idx:
|
||||||
intermediates.append(x)
|
x_inter = self.norm_pre(x) # applying final norm to last intermediate
|
||||||
|
else:
|
||||||
|
x_inter = x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
|
||||||
if intermediates_only:
|
if intermediates_only:
|
||||||
return intermediates
|
return intermediates
|
||||||
|
|
||||||
x = self.norm_pre(x)
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm_pre(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
|
||||||
@ -336,8 +340,10 @@ class RDNet(nn.Module):
|
|||||||
):
|
):
|
||||||
""" Prune layers not required for specified intermediates.
|
""" Prune layers not required for specified intermediates.
|
||||||
"""
|
"""
|
||||||
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
if prune_norm:
|
if prune_norm:
|
||||||
self.norm_pre = nn.Identity()
|
self.norm_pre = nn.Identity()
|
||||||
if prune_head:
|
if prune_head:
|
||||||
@ -355,6 +361,7 @@ class RDNet(nn.Module):
|
|||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.dense_stages(x)
|
x = self.dense_stages(x)
|
||||||
|
x = self.norm_pre(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_head(self, x, pre_logits: bool = False):
|
def forward_head(self, x, pre_logits: bool = False):
|
||||||
|
@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost
|
|||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
|
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
|
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -294,6 +295,72 @@ class RepGhostNet(nn.Module):
|
|||||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||||
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
|
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i]+1 for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
feat_idx = 0
|
||||||
|
x = self.conv_stem(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.act1(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.blocks
|
||||||
|
else:
|
||||||
|
stages = self.blocks[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages, start=1):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
|
@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
|
|||||||
|
|
||||||
Adapted from official impl at https://github.com/jameslahm/RepViT
|
Adapted from official impl at https://github.com/jameslahm/RepViT
|
||||||
"""
|
"""
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
__all__ = ['RepVit']
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -24,9 +22,12 @@ import torch.nn as nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
|
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
|
__all__ = ['RepVit']
|
||||||
|
|
||||||
|
|
||||||
class ConvNorm(nn.Sequential):
|
class ConvNorm(nn.Sequential):
|
||||||
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||||
@ -333,6 +334,62 @@ class RepVit(nn.Module):
|
|||||||
def set_distilled_training(self, enable=True):
|
def set_distilled_training(self, enable=True):
|
||||||
self.head.distilled_training = enable
|
self.head.distilled_training = enable
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig
|
|||||||
|
|
||||||
from collections import OrderedDict # pylint: disable=g-importing-member
|
from collections import OrderedDict # pylint: disable=g-importing-member
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|||||||
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
|
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
|
||||||
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
|
DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
|
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
|
||||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
|
|
||||||
@ -543,6 +544,79 @@ class ResNetV2(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
feat_idx = 0
|
||||||
|
H, W = x.shape[-2:]
|
||||||
|
for stem in self.stem:
|
||||||
|
x = stem(x)
|
||||||
|
if x.shape[-2:] == (H //2, W //2):
|
||||||
|
x_down = x
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x_down)
|
||||||
|
last_idx = len(self.stages)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages, start=1):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x_inter = self.norm(x) if norm else x
|
||||||
|
intermediates.append(x_inter)
|
||||||
|
else:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
if feat_idx == last_idx:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
|
from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._efficientnet_builder import efficientnet_init_weights
|
from ._efficientnet_builder import efficientnet_init_weights
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
|
|
||||||
@ -234,6 +235,67 @@ class RexNet(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i] for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.features
|
||||||
|
else:
|
||||||
|
stages = self.features[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -10,7 +10,7 @@ __all__ = ['TinyVit']
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
|
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
|
||||||
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
|
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_module
|
from ._features_fx import register_notrace_module
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
@ -536,6 +537,62 @@ class TinyVit(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -7,13 +7,14 @@ Original model: https://github.com/mrT23/TResNet
|
|||||||
"""
|
"""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
|
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||||
|
|
||||||
@ -228,6 +229,65 @@ class TResNet(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
stage_ends = [1, 2, 3, 4, 5]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
take_indices = [stage_ends[i] for i in take_indices]
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
# forward pass
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.body
|
||||||
|
else:
|
||||||
|
stages = self.body[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
stage_ends = [1, 2, 3, 4, 5]
|
||||||
|
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
|
||||||
|
max_index = stage_ends[max_index]
|
||||||
|
self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
x = self.body.s2d(x)
|
x = self.body.s2d(x)
|
||||||
|
@ -11,7 +11,7 @@ for some reference, rewrote most of the code.
|
|||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
|
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
|
||||||
create_attn, create_norm_act_layer
|
create_attn, create_norm_act_layer
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -264,6 +265,67 @@ class VovNet(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
feat_idx = 0
|
||||||
|
x = self.stem[:-1](x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
x = self.stem[-1](x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages, start=1):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(5, indices)
|
||||||
|
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
return self.stages(x)
|
return self.stages(x)
|
||||||
|
@ -494,7 +494,8 @@ class Xcit(nn.Module):
|
|||||||
# NOTE not supporting return of class tokens
|
# NOTE not supporting return of class tokens
|
||||||
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
|
||||||
for blk in self.cls_attn_blocks:
|
for blk in self.cls_attn_blocks:
|
||||||
x = blk(x)
|
x = blk(x)
|
||||||
|
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
return x, intermediates
|
return x, intermediates
|
||||||
|
Loading…
x
Reference in New Issue
Block a user