From 2e9b2a76fb1a797511bb1fefaeaef8a224012e1a Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Thu, 8 May 2025 00:57:18 +0800
Subject: [PATCH] update some model

---
 timm/models/convnext.py | 20 ++++++------
 timm/models/focalnet.py |  2 +-
 timm/models/mvitv2.py   | 10 +++---
 timm/models/pit.py      | 68 ++++++++++++++++++++++++++++++++++++++++-
 timm/models/rdnet.py    | 22 +++++++------
 timm/models/resnetv2.py |  8 +++--
 timm/models/xcit.py     |  3 +-
 7 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 47f2bf87..2f445118 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
 
         # forward pass
-        feat_idx = 0  # stem is index 0
         x = self.stem(x)
-        if feat_idx in take_indices:
-            intermediates.append(x)
 
+        last_idx = len(self.stages) - 1
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.stages
         else:
-            stages = self.stages[:max_index]
-        for stage in stages:
-            feat_idx += 1
+            stages = self.stages[:max_index + 1]
+        for feat_idx, stage in enumerate(stages):
             x = stage(x)
             if feat_idx in take_indices:
-                # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
-                intermediates.append(x)
+                if norm and feat_idx == last_idx:
+                    intermediates.append(self.norm_pre(x))
+                else:
+                    intermediates.append(x)
 
         if intermediates_only:
             return intermediates
 
-        x = self.norm_pre(x)
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
 
         return x, intermediates
 
diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py
index 51ab4d08..ec7cd1cf 100644
--- a/timm/models/focalnet.py
+++ b/timm/models/focalnet.py
@@ -491,7 +491,7 @@ class FocalNet(nn.Module):
         else:
             stages = self.layers[:max_index + 1]
 
-        last_idx = len(self.layers)
+        last_idx = len(self.layers) - 1
         for feat_idx, stage in enumerate(stages):
             x = stage(x)
             if feat_idx in take_indices:
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index f790fd0d..c048a072 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
         if self.pos_embed is not None:
             x = x + self.pos_embed
 
-        for i, stage in enumerate(self.stages):
+        last_idx = len(self.stages) - 1
+        for feat_idx, stage in enumerate(self.stages):
             x, feat_size = stage(x, feat_size)
-            if i in take_indices:
-                if norm and i == (len(self.stages) - 1):
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
                     x_inter = self.norm(x)  # applying final norm last intermediate
                 else:
                     x_inter = x
@@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
         if intermediates_only:
             return intermediates
 
-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)
 
         return x, intermediates
 
diff --git a/timm/models/pit.py b/timm/models/pit.py
index 3a1090b8..109cfaf8 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
 import math
 import re
 from functools import partial
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -22,6 +22,7 @@ from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block
 
@@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
         if self.head_dist is not None:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self.pos_drop(x + self.pos_embed)
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+
+        last_idx = len(self.transformers) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.transformers
+        else:
+            stages = self.transformers[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x, cls_tokens = stage((x, cls_tokens))
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            cls_tokens = self.norm(cls_tokens)
+
+        return cls_tokens, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+        self.transformers = self.transformers[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = self.pos_drop(x + self.pos_embed)
diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py
index 246030af..a3a205ff 100644
--- a/timm/models/rdnet.py
+++ b/timm/models/rdnet.py
@@ -302,20 +302,20 @@ class RDNet(nn.Module):
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
+        stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
+        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
+        take_indices = [stage_ends[i] for i in take_indices]
+        max_index = stage_ends[max_index]
 
         # forward pass
-        feat_idx = 0
         x = self.stem(x)
-        if feat_idx in take_indices:
-            intermediates.append(x)
-        last_idx = len(self.dense_stages)
+
+        last_idx = len(self.dense_stages) - 1
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             dense_stages = self.dense_stages
         else:
-            dense_stages = self.dense_stages[:max_index]
-        for stage in dense_stages:
-            feat_idx += 1
+            dense_stages = self.dense_stages[:max_index + 1]
+        for feat_idx, stage in enumerate(dense_stages):
             x = stage(x)
             if feat_idx in take_indices:
                 if norm and feat_idx == last_idx:
@@ -340,8 +340,10 @@ class RDNet(nn.Module):
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
-        self.dense_stages = self.dense_stages[:max_index]  # truncate blocks w/ stem as idx 0
+        stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
+        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
+        max_index = stage_ends[max_index]
+        self.dense_stages = self.dense_stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_norm:
             self.norm_pre = nn.Identity()
         if prune_head:
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 1cc3b864..5cc164ae 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -571,9 +571,13 @@ class ResNetV2(nn.Module):
 
         # forward pass
         feat_idx = 0
-        x = self.stem(x)
+        H, W = x.shape[-2:]
+        for stem in self.stem:
+            x = stem(x)
+            if x.shape[-2:] == (H // 2, W // 2):
+                x_down = x
         if feat_idx in take_indices:
-            intermediates.append(x)
+            intermediates.append(x_down)
         last_idx = len(self.stages)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.stages
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index e6cf87b7..250749f1 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -494,7 +494,8 @@ class Xcit(nn.Module):
         # NOTE not supporting return of class tokens
         x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
         for blk in self.cls_attn_blocks:
-            x = blk(x)
+            x = blk(x)
+        x = self.norm(x)
 
         return x, intermediates
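
A minimal usage sketch of the forward_intermediates() API that these changes align across the touched models; the model name 'pit_s_224', the input shape, and the printed output are illustrative assumptions, not part of the patch:

    import timm
    import torch

    # 'pit_s_224' is an assumed example; any model touched by this patch
    # (convnext, focalnet, mvitv2, pit, rdnet, resnetv2, xcit) exposes the same method.
    model = timm.create_model('pit_s_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)

    # indices=2 -> take the last two stages; intermediates_only=True returns only the
    # list of NCHW feature maps, skipping the final normed output.
    feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
    for f in feats:
        print(f.shape)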