From 12def0d1180aa8d7b20e2c1a42a144145fd7340d Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 00:24:57 +0800 Subject: [PATCH] support efficientvit, edgenext, davit --- timm/models/davit.py | 69 +++++++++++++++++- timm/models/edgenext.py | 69 +++++++++++++++++- timm/models/efficientvit_mit.py | 117 ++++++++++++++++++++++++++++++- timm/models/efficientvit_msra.py | 60 +++++++++++++++- 4 files changed, 311 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 65009888..f538ecca 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # All rights reserved. # This source code is licensed under the MIT license from functools import partial -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -636,6 +637,72 @@ class DaVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm_pre(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm_pre = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index d768b1dc..e21be971 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math from functools import partial -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \ NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm_pre(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm_pre = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 34be806b..27872310 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit """ __all__ = ['EfficientVit', 'EfficientVitLarge'] -from typing import List, Optional +from typing import List, Optional, Tuple, Union from functools import partial import torch @@ -17,6 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -754,6 +755,63 @@ class EfficientVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 7e5c09a4..91caaa5a 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic __all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -17,6 +17,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module): self.head = NormLinear( self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting():