Add forward_intermediates() support for mambaout, metaformer, nest, nextvit, pvt_v2

Ryan 2025-05-05 05:30:57 +08:00 committed by Ross Wightman
parent 2d3155908c
commit 6b6beffa6b
5 changed files with 305 additions and 4 deletions
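
Usage-wise, each of these models gains the timm forward_intermediates() / prune_intermediate_layers() feature-extraction API added in the hunks below. A rough usage sketch, not part of this diff (model name and input size are illustrative, and it assumes a timm build that includes this commit):

import torch
import timm

model = timm.create_model('mambaout_base', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Final features plus the last two stage outputs (an int index means "last n stages").
final, intermediates = model.forward_intermediates(x, indices=2)
print([t.shape for t in intermediates])  # NCHW feature maps by default

# Intermediates only, e.g. to feed a dense-prediction head.
feats = model.forward_intermediates(x, indices=[0, 2], intermediates_only=True)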

timm/models/mambaout.py

@@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from collections import OrderedDict
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -14,6 +14,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -417,6 +418,68 @@ class MambaOut(nn.Module):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
        channel_first = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if channel_first:
            # reshape to BCHW output format
            intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
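
The index handling in forward_intermediates() / prune_intermediate_layers() above comes from feature_take_indices() in timm.models._features. As an illustration of the selection semantics only (a simplified sketch, not the actual timm helper):

from typing import List, Optional, Tuple, Union

def take_indices_sketch(num_stages: int, indices: Optional[Union[int, List[int]]]) -> Tuple[List[int], int]:
    # None -> all stages; int n -> last n stages; sequence -> explicit (possibly negative) indices.
    if indices is None:
        take = list(range(num_stages))
    elif isinstance(indices, int):
        take = list(range(num_stages - indices, num_stages))
    else:
        take = [i if i >= 0 else num_stages + i for i in indices]
    return take, max(take)

print(take_indices_sketch(4, None))     # ([0, 1, 2, 3], 3)
print(take_indices_sketch(4, 2))        # ([2, 3], 3)
print(take_indices_sketch(4, [0, -1]))  # ([0, 3], 3)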

timm/models/metaformer.py

@@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
from collections import OrderedDict
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
    use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
            final = nn.Identity()
        self.head.fc = final

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_head(self, x: Tensor, pre_logits: bool = False):
        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
        x = self.head.global_pool(x)

timm/models/nest.py

@@ -19,6 +19,7 @@ import collections.abc
import logging
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -420,6 +422,67 @@ class Nest(nn.Module):
        self.global_pool, self.head = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.levels), indices)

        # forward pass
        x = self.patch_embed(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.levels
        else:
            stages = self.levels[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        # Layer norm done over channel dim only (to NHWC and back)
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.levels), indices)
        self.levels = self.levels[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.levels(x)

timm/models/nextvit.py

@@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -560,6 +561,66 @@ class NextViT(nn.Module):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
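
The prune_intermediate_layers() additions pair with forward_intermediates() for building truncated feature backbones. A hedged sketch (model name illustrative, assumes a timm build with this commit):

import torch
import timm

model = timm.create_model('nextvit_small', pretrained=False).eval()

# Keep only what the first two stage outputs need: later stages and the classifier head are dropped.
kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=[0, 1], intermediates_only=True)
print(kept, [f.shape for f in feats])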

timm/models/pvt_v2.py

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs
@@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
        self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.patch_embed(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.stages(x)