update some models

Ryan 2025-05-08 00:57:18 +08:00 committed by Ross Wightman
parent 99c25fa5c0
commit 2e9b2a76fb
7 changed files with 104 additions and 29 deletions

View File

@@ -452,29 +452,29 @@ class ConvNeXt(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if norm and feat_idx == last_idx:
intermediates.append(self.norm_pre(x))
else:
intermediates.append(x)
if intermediates_only:
return intermediates
x = self.norm_pre(x)
if feat_idx == last_idx:
x = self.norm_pre(x)
return x, intermediates
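The hunk above switches ConvNeXt's forward_intermediates from counting the stem as index 0 to counting stages only. A minimal sketch of what timm's feature_take_indices helper returns under that convention (assuming the helper's current behavior; the 4-stage count is illustrative):

from timm.models._features import feature_take_indices

# last-2 selection over 4 stages: absolute stage indices plus the deepest index to run
take_indices, max_index = feature_take_indices(4, 2)
print(take_indices, max_index)   # [2, 3] 3

# explicit stage ids are passed through (negatives count from the end)
take_indices, max_index = feature_take_indices(4, [0, -1])
print(take_indices, max_index)   # [0, 3] 3

Slicing self.stages[:max_index + 1] and enumerating from zero then visits exactly the indices that can appear in take_indices.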

View File

@@ -491,7 +491,7 @@ class FocalNet(nn.Module):
else:
stages = self.layers[:max_index + 1]
last_idx = len(self.layers)
last_idx = len(self.layers) - 1
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
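The FocalNet fix is the off-by-one on last_idx: enumerate() yields indices 0..len-1, so the final stage is len(self.layers) - 1. A toy illustration (the list contents are placeholders):

layers = ['stage0', 'stage1', 'stage2']
last_idx = len(layers) - 1
for feat_idx, layer in enumerate(layers):
    print(feat_idx, layer, feat_idx == last_idx)
# with last_idx = len(layers), the final-stage check would never fire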

View File

@@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
for i, stage in enumerate(self.stages):
last_idx = len(self.stages) - 1
for feat_idx, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
if i in take_indices:
if norm and i == (len(self.stages) - 1):
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm last intermediate
else:
x_inter = x
@@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module):
if intermediates_only:
return intermediates
x = self.norm(x)
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
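With the MultiScaleVit change, norm=True applies self.norm to the intermediate taken from the last stage, and the final output is only normed when the loop actually reached last_idx. A hedged usage sketch (the 'mvitv2_tiny' name is assumed to be registered in this timm build):

import torch
import timm

model = timm.create_model('mvitv2_tiny', pretrained=False)
x = torch.randn(1, 3, 224, 224)
# take only the last stage; with norm=True it comes back through self.norm
feats = model.forward_intermediates(x, indices=[-1], norm=True, intermediates_only=True)
print([f.shape for f in feats])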

View File

@@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
from typing import Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block
@@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module):
if self.head_dist is not None:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
# forward pass
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
last_idx = len(self.transformers) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.transformers
else:
stages = self.transformers[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x, cls_tokens = stage((x, cls_tokens))
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
cls_tokens = self.norm(cls_tokens)
return cls_tokens, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
self.transformers = self.transformers[:max_index + 1] # truncate transformer stages
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x + self.pos_embed)
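The new PoolingVisionTransformer methods mirror the forward_intermediates API used elsewhere in timm. A hedged usage sketch (the 'pit_s_224' model name is assumed to be registered in this build; shapes depend on the config):

import torch
import timm

model = timm.create_model('pit_s_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# NCHW feature maps from the last two transformer stages
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])

# drop stages past the deepest requested intermediate and remove the head
kept = model.prune_intermediate_layers(indices=2, prune_head=True)
print(kept)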

View File

@@ -302,20 +302,20 @@ class RDNet(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
take_indices = [stage_ends[i] for i in take_indices]
max_index = stage_ends[max_index]
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
last_idx = len(self.dense_stages)
last_idx = len(self.dense_stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
dense_stages = self.dense_stages
else:
dense_stages = self.dense_stages[:max_index]
for stage in dense_stages:
feat_idx += 1
dense_stages = self.dense_stages[:max_index + 1]
for feat_idx, stage in enumerate(dense_stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
@@ -340,8 +340,10 @@
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
max_index = stage_ends[max_index]
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
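RDNet has more dense_stages than feature levels, so the hunks above first read the owning stage index out of each feature_info entry and only then map user-facing indices onto dense_stages. A rough sketch of that remapping (the module names are illustrative, not RDNet's actual layout):

from timm.models._features import feature_take_indices

feature_info = [{'module': 'dense_stages.0'}, {'module': 'dense_stages.2'},
                {'module': 'dense_stages.5'}, {'module': 'dense_stages.6'}]
stage_ends = [int(info['module'].split('.')[-1]) for info in feature_info]   # [0, 2, 5, 6]

take_indices, max_index = feature_take_indices(len(stage_ends), 2)   # [2, 3], 3
take_indices = [stage_ends[i] for i in take_indices]                 # [5, 6]
max_index = stage_ends[max_index]                                    # 6 -> slice dense_stages[:7]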

View File

@@ -571,9 +571,13 @@ class ResNetV2(nn.Module):
# forward pass
feat_idx = 0
x = self.stem(x)
H, W = x.shape[-2:]
for stem in self.stem:
x = stem(x)
if x.shape[-2:] == (H // 2, W // 2):
x_down = x
if feat_idx in take_indices:
intermediates.append(x)
intermediates.append(x_down)
last_idx = len(self.stages)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
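The ResNetV2 hunk stops appending the stem's final output (which may already be at stride 4) and instead walks the stem modules to capture the last activation that is still at half the input resolution. A standalone sketch of that pattern with a made-up stem, not timm's:

import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride=2, padding=1),    # H/2 x W/2
    nn.Conv2d(32, 64, 3, stride=1, padding=1),   # still H/2 x W/2
    nn.MaxPool2d(3, stride=2, padding=1),        # H/4 x W/4
)
x = torch.randn(1, 3, 224, 224)
H, W = x.shape[-2:]
x_down = None
for layer in stem:
    x = layer(x)
    if x.shape[-2:] == (H // 2, W // 2):
        x_down = x                               # keeps the last stride-2 activation
print(x.shape, x_down.shape)                     # [1, 64, 56, 56] and [1, 64, 112, 112]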

View File

@@ -494,7 +494,8 @@ class Xcit(nn.Module):
# NOTE not supporting return of class tokens
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
x = blk(x)
x = self.norm(x)
return x, intermediates