support efficientvit, edgenext, davit

commit 12def0d118 (parent 9aedecbb5f)
Authored by Ryan on 2025-05-06 00:24:57 +08:00; committed by Ross Wightman
4 changed files with 311 additions and 4 deletions

timm/models/davit.py

@@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # All rights reserved.
 # This source code is licensed under the MIT license
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ class DaVit(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
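
The method bodies above follow the same pattern in all four files, so a single usage sketch covers the basics; it assumes the `davit_tiny` variant registered in timm, with illustrative input shapes:

import torch
import timm

# Build any DaViT variant; forward_intermediates is a method on the model class.
model = timm.create_model('davit_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# indices=2 -> take the last two stage outputs, returned alongside final features.
final, intermediates = model.forward_intermediates(x, indices=2)
for feat in intermediates:
    print(feat.shape)  # NCHW feature maps, one per selected stage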

timm/models/edgenext.py

@@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
     NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
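
The EdgeNeXt version is identical, so a sketch of the early-exit path instead, assuming the `edgenext_small` variant; with stop_early=True, stages past the deepest requested index are never run:

import torch
import timm

model = timm.create_model('edgenext_small', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# intermediates_only=True returns just the list, with no final-feature pass-through.
feats = model.forward_intermediates(
    x, indices=[0, 1], intermediates_only=True, stop_early=True)
print([f.shape for f in feats])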

timm/models/efficientvit_mit.py

@@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
 """
 __all__ = ['EfficientVit', 'EfficientVitLarge']
 
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 from functools import partial
 
 import torch
@@ -17,6 +17,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -754,6 +755,63 @@ class EfficientVit(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
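
All of these methods lean on the feature_take_indices helper to normalize the indices argument; a sketch of its behavior as implied by the docstrings above (int takes the last n, None takes all, a sequence selects exact stages, negatives counting from the end):

from timm.models._features import feature_take_indices

# Returns (take_indices, max_index) for a model with 4 stages.
print(feature_take_indices(4, 2))        # ([2, 3], 3): last two of four stages
print(feature_take_indices(4, None))     # ([0, 1, 2, 3], 3): all stages
print(feature_take_indices(4, [0, -1]))  # ([0, 3], 3): explicit selection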
@@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
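
A sketch of trimming a model down to a feature extractor with the new prune method, assuming the `efficientvit_b0` variant; stages beyond the deepest requested intermediate and the classifier head are dropped:

import torch
import timm

model = timm.create_model('efficientvit_b0', pretrained=False).eval()
kept = model.prune_intermediate_layers(indices=2)  # keep stages for the last two outputs
print('kept stage indices:', kept)

# The pruned model still serves intermediates for the retained stages.
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])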

timm/models/efficientvit_msra.py

@@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
 __all__ = ['EfficientVitMsra']
 import itertools
 from collections import OrderedDict
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -17,6 +17,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module):
         self.head = NormLinear(
             self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.patch_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
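
forward_intermediates is the hook that timm's feature-extraction wrapper builds on, so after this change these models should also be usable as plain feature backbones; a sketch assuming the `efficientvit_m0` variant:

import torch
import timm

# out_indices selects which stage outputs the wrapper exposes.
model = timm.create_model(
    'efficientvit_m0', pretrained=False, features_only=True, out_indices=(0, 1, 2))
x = torch.randn(1, 3, 224, 224)
print([f.shape for f in model(x)])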