support tiny_vit

This commit is contained in:
Ryan 2025-05-05 01:35:02 +08:00 committed by Ross Wightman
parent 8befebd93c
commit 5e8cc616d4
2 changed files with 58 additions and 3 deletions

View File

@ -10,7 +10,7 @@ __all__ = ['TinyVit']
import itertools import itertools
from functools import partial from functools import partial
from typing import Dict, Optional from typing import Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
@ -536,6 +537,62 @@ class TinyVit(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool) self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
    self,
    x: torch.Tensor,
    indices: Optional[Union[int, List[int]]] = None,
    norm: bool = False,
    stop_early: bool = False,
    output_fmt: str = 'NCHW',
    intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the patch embedding and stages, collecting selected stage outputs.

    Args:
        x: Input image tensor
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        norm: Part of the shared signature; not referenced by this method
        stop_early: Only run stages up to the last requested one (ignored when scripting)
        output_fmt: Shape of intermediate feature outputs, only 'NCHW' is accepted
        intermediates_only: Only return intermediate features

    Returns:
        The list of collected stage outputs when ``intermediates_only`` is set,
        otherwise a ``(x, intermediates)`` tuple where ``x`` is the output of the
        last stage that was executed.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    x = self.patch_embed(x)

    # Slicing the stage container is only done outside of torchscript.
    if stop_early and not torch.jit.is_scripting():
        active_stages = self.stages[:max_index + 1]
    else:
        active_stages = self.stages

    collected = []
    for stage_idx, stage_module in enumerate(active_stages):
        x = stage_module(x)
        if stage_idx in take_indices:
            collected.append(x)

    if not intermediates_only:
        return x, collected
    return collected
def prune_intermediate_layers(
    self,
    indices: Union[int, List[int]] = 1,
    prune_norm: bool = False,
    prune_head: bool = True,
):
    """ Drop the stages that are not needed to produce the requested intermediates.

    Args:
        indices: Selection of stage outputs to keep, resolved by ``feature_take_indices``
        prune_norm: Part of the shared signature; not referenced by this method
        prune_head: Reset the classifier via ``reset_classifier(0, '')`` when True

    Returns:
        The resolved indices of the stages whose outputs were requested.
    """
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    # Keep every stage up to and including the deepest requested one.
    num_kept = max_index + 1
    self.stages = self.stages[:num_kept]

    if prune_head:
        self.reset_classifier(0, '')

    return take_indices
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():

View File

@ -253,7 +253,6 @@ class TResNet(nn.Module):
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = [] intermediates = []
take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) take_indices, max_index = feature_take_indices(len(self.body) - 1, indices)
print(take_indices, max_index)
# forward pass # forward pass
x = self.body[0](x) # s2d x = self.body[0](x) # s2d
@ -261,7 +260,6 @@ class TResNet(nn.Module):
stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]] stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]]
else: else:
stages = self.body[1:max_index + 2] stages = self.body[1:max_index + 2]
print(len(stages))
for feat_idx, stage in enumerate(stages): for feat_idx, stage in enumerate(stages):
x = stage(x) x = stage(x)