From f8be741f0f478d83920b212ea9cff10ece941c70 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 04:21:12 +0800 Subject: [PATCH] support rexnet, resnetv2, repvit and repghostnet --- timm/models/repghost.py | 69 ++++++++++++++++++++++++++++++++++++++++- timm/models/repvit.py | 63 +++++++++++++++++++++++++++++++++++-- timm/models/resnetv2.py | 67 ++++++++++++++++++++++++++++++++++++++- timm/models/rexnet.py | 64 +++++++++++++++++++++++++++++++++++++- 4 files changed, 257 insertions(+), 6 deletions(-) diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 4b802d79..77fc35d5 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost """ import copy from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, Linear, make_divisible from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -294,6 +295,72 @@ class RepGhostNet(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i]+1 for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + feat_idx = 0 + x = self.conv_stem(x) + if feat_idx in take_indices: + intermediates.append(x) + x = self.bn1(x) + x = self.act1(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.blocks + else: + stages = self.blocks[:max_index + 1] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 7dcb2cd9..ddcfed55 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective` Adapted from official impl at https://github.com/jameslahm/RepViT """ - -__all__ = ['RepVit'] -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,9 +22,12 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs +__all__ = ['RepVit'] + class ConvNorm(nn.Sequential): def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): @@ -333,6 +334,62 @@ class RepVit(nn.Module): def set_distilled_training(self, enable=True): self.head.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index d7d2905b..1bac794c 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig from collections import OrderedDict # pylint: disable=g-importing-member from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \ DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -543,6 +544,70 @@ class ResNetV2(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 9971728c..dd3cb4f3 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman from functools import partial from math import ceil -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule from ._builder import build_model_with_cfg from ._efficientnet_builder import efficientnet_init_weights +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -234,6 +235,67 @@ class RexNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.features + else: + stages = self.features[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting():