More forward_intermediates() & features_only work

* forward_intermediates() added to beit, deit, eva, mvitv2, twins, vit, vit_sam
* add a features_only arg to forward_intermediates() so it can return just the intermediate features
* fix #2060
* fix #1374
* fix #657
Ross Wightman 2024-04-09 21:29:16 -07:00
parent 5fdc0b4e93
commit 679daef76a
8 changed files with 512 additions and 101 deletions
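With these changes, features_only=True no longer raises for the ViT family and instead routes through the new FeatureGetterNet wrapper. A rough usage sketch (the model name, out_indices, and shapes are illustrative, not taken from the diff):

    import torch
    import timm

    # sketch: select the last three blocks as NCHW feature maps; negative indices are allowed
    model = timm.create_model('vit_base_patch16_224', pretrained=False, features_only=True, out_indices=(-3, -2, -1))
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)  # e.g. torch.Size([1, 768, 14, 14]) for each selected block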


@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
import torch.nn as nn
@ -20,7 +20,39 @@ from torch.utils.checkpoint import checkpoint
from timm.layers import Format
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
__all__ = [
'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
'feature_take_indices'
]
def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = [num_blocks - n + i for i in range(n)]
elif isinstance(n, tuple):
# splitting this up is silly, but needed for torchscript type resolution of n
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
else:
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
return take_indices, max(take_indices)
def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
if torch.jit.is_scripting():
return _take_indices_jit(n, num_blocks)
else:
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
return _take_indices(n, num_blocks)
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
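A quick sketch of the selection semantics implemented above (assuming the helper stays importable from timm.models._features, as in this diff):

    from timm.models._features import feature_take_indices

    # int n -> take the last n blocks
    indices, max_index = feature_take_indices(2, 12)
    # eager: ({10, 11}, 11); under torch.jit.script the same call returns ([10, 11], 11)

    # sequence n -> explicit indices, negatives wrap relative to num_blocks
    indices, max_index = feature_take_indices([0, -1], 12)
    # ({0, 11}, 11)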
@ -397,29 +429,38 @@ class FeatureGetterNet(nn.ModuleDict):
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
norm: bool = False,
prune: bool = True,
):
"""
Args:
model: Model to wrap.
out_indices: Indices of features to extract.
out_map: Remap feature names for dict output (WIP, not supported).
return_dict: Return features as dictionary instead of list (WIP, not supported).
norm: Apply final model norm to all output features (if possible).
output_fmt: Output format of features ('NCHW' or 'NLC', model dependent).
prune: Prune the wrapped model down to the layers needed for the selected out_indices.
"""
super().__init__()
self.model = model
if prune and hasattr(model, 'prune_intermediate_layers'):
model.prune_intermediate_layers(
out_indices,
prune_norm=not norm,
)
self.feature_info = _get_feature_info(model, out_indices)
self.model = model
self.out_indices = out_indices
self.out_map = out_map
self.return_dict = return_dict
self.output_fmt = output_fmt
self.norm = norm
def forward(self, *args, **kwargs):
"""
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
"""
out = self.model.get_intermediate_layers(
*args,
def forward(self, x):
features = self.model.forward_intermediates(
x,
n=self.out_indices,
reshape=True,
**kwargs,
norm=self.norm,
output_fmt=self.output_fmt,
features_only=True,
)
return out
return features
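For reference, a minimal sketch of wrapping a model in the getter directly; the model name is illustrative and constructor details not visible in this hunk (e.g. the out_indices default) are assumptions:

    import torch
    import timm
    from timm.models._features import FeatureGetterNet

    # sketch: wrap an existing ViT; out_indices/model name are illustrative
    backbone = timm.create_model('vit_small_patch16_224', pretrained=False)
    print(len(backbone.blocks))           # 12
    wrapper = FeatureGetterNet(backbone, out_indices=(5,), norm=True)
    print(len(wrapper.model.blocks))      # 6 -- prune=True (default) truncated the unused trailing blocks
    feats = wrapper(torch.randn(1, 3, 224, 224))
    print(feats[0].shape)                 # NCHW map from block index 5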


@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# --------------------------------------------------------'
import math
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
__all__ = ['Beit']
@ -333,6 +333,8 @@ class Beit(nn.Module):
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@ -398,6 +400,93 @@ class Beit(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
n: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
n: Take last n blocks if n is an int, if it is a sequence, select by matching indices
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
features_only: Only return intermediate features
Returns:
    List of intermediate features when features_only is True, otherwise a tuple of (final features, list of intermediates)
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
num_blocks = len(self.blocks)
if n is None:
n = num_blocks
take_indices, max_index = feature_take_indices(n, num_blocks)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape == True => BCHW output format
patch_size = self.patch_embed.patch_size
H = int(math.ceil(height / patch_size[0]))
W = int(math.ceil(width / patch_size[1]))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if features_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(n, len(self.blocks))
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
@ -547,14 +636,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia
def _create_beit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for BEiT models.')
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Beit, variant, pretrained,
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
pretrained_filter_fn=_beit_checkpoint_filter_fn,
**kwargs)
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
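With the builder now wiring out_indices into a 'getter' feature_cfg, BEiT also gains a usable forward_intermediates(); a hedged sketch of calling it directly (model name and shapes illustrative):

    import torch
    import timm

    # sketch: direct use of the new method on a BEiT model
    model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        final, intermediates = model.forward_intermediates(x, n=[-2, -1])
        print(final.shape)                # token features after the final norm
        for t in intermediates:
            print(t.shape)                # NCHW maps from the last two blocks, prefix tokens stripped

        # features_only=True returns just the list and skips the final normed output
        feats = model.forward_intermediates(x, n=2, features_only=True)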


@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer):
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
model = build_model_with_cfg(
model_cls,
variant,
pretrained,
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model


@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision
import math
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
to_2tuple, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import generate_default_cfgs, register_model
__all__ = ['Eva']
@ -469,6 +469,8 @@ class Eva(nn.Module):
init_values=init_values,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@ -559,6 +561,85 @@ class Eva(nn.Module):
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
return x, rot_pos_embed
def forward_intermediates(
self,
x: torch.Tensor,
n: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
n: Take last n blocks if n is an int, if it is a sequence, select by matching indices
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
features_only: Only return intermediate features
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
num_blocks = len(self.blocks)
if n is None:
n = num_blocks
take_indices, max_index = feature_take_indices(n, num_blocks)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, rope=rot_pos_embed)
if i in take_indices:
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape == True => BCHW output format
patch_size = self.patch_embed.patch_size
H = int(math.ceil(height / patch_size[0]))
W = int(math.ceil(width / patch_size[1]))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if features_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(n, len(self.blocks))
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
@ -663,13 +744,13 @@ def checkpoint_filter_fn(
def _create_eva(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Eva models.')
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Eva, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
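prune_intermediate_layers() lets a caller drop everything above the deepest requested block; a sketch against an EVA-02 model (model name illustrative, 12-block depth assumed):

    import torch
    import timm

    # sketch: keep only what block index 6 needs
    model = timm.create_model('eva02_tiny_patch14_224', pretrained=False)
    print(len(model.blocks))              # 12
    model.prune_intermediate_layers(n=[6], prune_norm=True, prune_head=True)
    print(len(model.blocks))              # 7 -- blocks above index 6 are gone
    with torch.no_grad():
        feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), n=[6], features_only=True)
    print(feats[0].shape)                 # NCHW map from block index 6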


@ -26,6 +26,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module):
num_stages = len(cfg.embed_dim)
feat_size = patch_dims
curr_stride = max(cfg.patch_stride)
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
self.stages = nn.ModuleList()
self.feature_info = []
for i in range(num_stages):
if cfg.expand_attn:
dim_out = cfg.embed_dim[i]
@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module):
norm_layer=norm_layer,
drop_path=dpr[i],
)
curr_stride *= max(cfg.stride_q[i])
self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
embed_dim = dim_out
feat_size = stage.feat_size
self.stages.append(stage)
@ -829,6 +834,51 @@ class MultiScaleVit(nn.Module):
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_intermediates(
self,
x: torch.Tensor,
n: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
num_stages = len(self.stages) # block list is two-tiered, first tier == stage
if n is None:
n = num_stages
take_indices, max_index = feature_take_indices(n, num_stages)
# FIXME slice block/pos_block if < max
# forward pass
x, feat_size = self.patch_embed(x)
B = x.shape[0]
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
for i, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
if i in take_indices:
if norm and i == (len(self.stages) - 1):
x_inter = self.norm(x)  # apply final norm to the last intermediate only
else:
x_inter = x
if reshape:
x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
intermediates.append(x_inter)
if features_only:
return intermediates
x = self.norm(x)
return x, intermediates
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
B, N, C = x.shape
@ -862,6 +912,18 @@ class MultiScaleVit(nn.Module):
def checkpoint_filter_fn(state_dict, model):
if 'stages.0.blocks.0.norm1.weight' in state_dict:
# native checkpoint, look for rel_pos interpolations
for k in state_dict.keys():
if 'rel_pos' in k:
rel_pos = state_dict[k]
dest_rel_pos_shape = model.state_dict()[k].shape
if rel_pos.shape[0] != dest_rel_pos_shape[0]:
rel_pos_resized = torch.nn.functional.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=dest_rel_pos_shape[0],
mode="linear",
)
state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
return state_dict
import re
@ -892,16 +954,6 @@ def checkpoint_filter_fn(state_dict, model):
k = k.replace('head.projection', 'head.fc')
out_dict[k] = v
# for k, v in state_dict.items():
# if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# # To resize pos embedding when using model at different size from pretrained weights
# v = resize_pos_embed(
# v,
# model.pos_embed,
# 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
# model.patch_embed.grid_size
# )
return out_dict
@ -948,16 +1000,14 @@ model_cfgs = dict(
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 4)
return build_model_with_cfg(
MultiScaleVit,
variant,
pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
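For MViT-V2 the intermediates are taken per stage rather than per block (four stages in the stock configs), so out_indices selects stages; a sketch, with the variant name and shapes illustrative:

    import torch
    import timm

    # sketch: stage-level intermediates from a hierarchical model
    model = timm.create_model('mvitv2_tiny', pretrained=False, features_only=True, out_indices=(1, 2, 3))
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)   # NCHW; channels roughly double and resolution halves per stage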


@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li
# --------------------------------------------------------
import math
from functools import partial
from typing import Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -22,6 +22,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Attention
@ -324,6 +325,7 @@ class Twins(nn.Module):
patch_size = 2
self.blocks = nn.ModuleList()
self.feature_info = []
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
for k in range(len(depths)):
@ -339,6 +341,7 @@ class Twins(nn.Module):
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
)
self.blocks.append(_block)
self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))]
cur += depths[k]
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
@ -401,6 +404,53 @@ class Twins(nn.Module):
if m.bias is not None:
m.bias.data.zero_()
def forward_intermediates(
self,
x: torch.Tensor,
n: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.'
intermediates = []
num_stages = len(self.blocks) # block list is two-tiered, first tier == stage
if n is None:
n = num_stages
take_indices, max_index = feature_take_indices(n, num_stages)
# FIXME slice block/pos_block if < max
# forward pass
B, _, height, width = x.shape
for i, (embed, drop, blocks, pos_blk) in enumerate(zip(
self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if i < len(self.depths) - 1:
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
if i in take_indices:
intermediates.append(x)
else:
if i in take_indices:
# only last feature can be normed
x_feat = self.norm(x) if norm else x
intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous())
if features_only:
return intermediates
x = self.norm(x)
return x, intermediates
def forward_features(self, x):
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
@ -429,10 +479,12 @@ class Twins(nn.Module):
def _create_twins(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
out_indices = kwargs.pop('out_indices', 4)
model = build_model_with_cfg(
Twins, variant, pretrained,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
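The feature_info entries added above expose per-stage channel counts and reductions to downstream code (e.g. FPN-style necks); a sketch of inspecting them, with the variant name and channel values illustrative:

    import timm

    # sketch: metadata comes from the new self.feature_info entries
    model = timm.create_model('twins_pcpvt_small', pretrained=False, features_only=True)
    print(model.feature_info.channels())    # per-stage embed dims, e.g. [64, 128, 320, 512]
    print(model.feature_info.reduction())   # [4, 8, 16, 32]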


@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm,
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -473,7 +474,6 @@ class VisionTransformer(nn.Module):
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.feature_info = []
embed_args = {}
if dynamic_img_size:
@ -631,58 +631,111 @@ class VisionTransformer(nn.Module):
return self.pos_drop(x)
def _intermediate_layers(
def forward_intermediates(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
last_index_to_take = max(take_indices)
n: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
n: Take last n blocks if n is an int, if it is a sequence, select by matching indices
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
features_only: Only return intermediate features
Returns:
    List of intermediate features when features_only is True, otherwise a tuple of (final features, list of intermediates)
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
num_blocks = len(self.blocks)
if n is None:
n = num_blocks
take_indices, max_index = feature_take_indices(n, num_blocks)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
return outputs
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape == True => BCHW output format
patch_size = self.patch_embed.patch_size
H = int(math.ceil(height / patch_size[0]))
W = int(math.ceil(width / patch_size[1]))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if features_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(n, len(self.blocks))
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
if self.attn_pool is not None:
self.attn_pool = None
self.fc_norm = nn.Identity()
self.head = nn.Identity()
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
n: Union[int, List[int], Tuple[int]] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
) -> List[torch.Tensor]:
""" Intermediate layer accessor inspired by DINO / DINOv2 interface.
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
"""
# take last n blocks if n is an int, if in is a sequence, select by matching indices
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
patch_size = self.patch_embed.patch_size
batch, _, height, width = x.size()
outputs = [
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
.permute(0, 3, 1, 2)
.contiguous()
for out in outputs
]
if return_prefix_tokens:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
return self.forward_intermediates(
x, n,
return_prefix_tokens=return_prefix_tokens,
norm=norm,
output_fmt='NCHW' if reshape else 'NLC',
features_only=True,
)
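get_intermediate_layers() is kept as a thin backwards-compatibility shim over forward_intermediates(); a sketch showing the two paths agree (model name illustrative):

    import torch
    import timm

    # sketch: old DINO/DINOv2-style accessor vs the new API
    model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        old_style = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
        new_style = model.forward_intermediates(x, n=4, norm=True, features_only=True)
    assert all(torch.equal(a, b) for a, b in zip(old_style, new_style))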
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@ -2485,7 +2538,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo
def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-S/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
model = _create_vision_transformer(
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2495,7 +2548,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-B/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2505,7 +2558,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo
def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-L/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
model = _create_vision_transformer(
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2521,7 +2574,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
)
model = _create_vision_transformer(
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))


@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in:
"""
import logging
from functools import partial
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
from ._features_fx import register_notrace_function
# model_registry will add each entrypoint fn to this
__all__ = ['VisionTransformerSAM']
@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module):
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = partial(
PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
norm_layer: Optional[Callable] = nn.LayerNorm,
act_layer: Optional[Callable] = nn.GELU,
block_fn: Callable = Block,
@ -469,6 +469,8 @@ class VisionTransformerSAM(nn.Module):
rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
if neck_chans:
self.neck = nn.Sequential(
@ -536,6 +538,52 @@ class VisionTransformerSAM(nn.Module):
def reset_classifier(self, num_classes=0, global_pool=None):
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
n: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
features_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
intermediates = []
num_blocks = len(self.blocks)
if n is None:
n = num_blocks
take_indices, max_index = feature_take_indices(n, num_blocks)
# forward pass, collect intermediates
x = self.patch_embed(x)
if self.pos_embed is not None:
# dynamically resize abs pos embedding if needed
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
x = self.pos_drop(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
# make output BCHW
if norm:
# norm is intertwined with neck convs so apply both, changes the dim
# FIXME only apply to final? Need experiments
intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
else:
intermediates.append(x.permute(0, 3, 1, 2))
if features_only:
return intermediates
x = self.neck(x.permute(0, 3, 1, 2))
return x, intermediates
def forward_features(self, x):
x = self.patch_embed(x)
if self.pos_embed is not None:
@ -618,15 +666,13 @@ default_cfgs = generate_default_cfgs({
def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError(
'features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
return build_model_with_cfg(
VisionTransformerSAM,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
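For the SAM ViT, norm=True routes each intermediate through the neck convs, so the channel dimension matches the neck output rather than the transformer width; a sketch, with the model name and shapes illustrative:

    import torch
    import timm

    # sketch: compare raw block output vs neck-projected output for the last block
    model = timm.create_model('samvit_base_patch16', pretrained=False).eval()
    x = torch.randn(1, 3, 1024, 1024)
    with torch.no_grad():
        raw = model.forward_intermediates(x, n=1, features_only=True)
        necked = model.forward_intermediates(x, n=1, norm=True, features_only=True)
    print(raw[0].shape, necked[0].shape)   # e.g. (1, 768, 64, 64) vs (1, 256, 64, 64)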