mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More forward_intermediates() & features_only work
* forward_intermediates() added to beit, deit, eva, mvitv2, twins, vit, vit_sam * add features_only to forward intermediates to allow just intermediate features * fix #2060 * fix #1374 * fix #657
This commit is contained in:
parent
5fdc0b4e93
commit
679daef76a
@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -20,7 +20,39 @@ from torch.utils.checkpoint import checkpoint
|
||||
from timm.layers import Format
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
__all__ = [
|
||||
'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
|
||||
'feature_take_indices'
|
||||
]
|
||||
|
||||
|
||||
def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]:
|
||||
if isinstance(n, int):
|
||||
assert n >= 0
|
||||
take_indices = {x for x in range(num_blocks - n, num_blocks)}
|
||||
else:
|
||||
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
|
||||
return take_indices, max(take_indices)
|
||||
|
||||
|
||||
def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
|
||||
if isinstance(n, int):
|
||||
assert n >= 0
|
||||
take_indices = [num_blocks - n + i for i in range(n)]
|
||||
elif isinstance(n, tuple):
|
||||
# splitting this up is silly, but needed for torchscript type resolution of n
|
||||
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
|
||||
else:
|
||||
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
|
||||
return take_indices, max(take_indices)
|
||||
|
||||
|
||||
def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
|
||||
if torch.jit.is_scripting():
|
||||
return _take_indices_jit(n, num_blocks)
|
||||
else:
|
||||
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
|
||||
return _take_indices(n, num_blocks)
|
||||
|
||||
|
||||
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
|
||||
@ -397,29 +429,38 @@ class FeatureGetterNet(nn.ModuleDict):
|
||||
out_map: Optional[Sequence[Union[int, str]]] = None,
|
||||
return_dict: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
norm: bool = False,
|
||||
prune: bool = True,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
model: Model to wrap.
|
||||
out_indices: Indices of features to extract.
|
||||
out_map: Remap feature names for dict output (WIP, not supported).
|
||||
return_dict: Return features as dictionary instead of list (WIP, not supported).
|
||||
norm: Apply final model norm to all output features (if possible).
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
if prune and hasattr(model, 'prune_intermediate_layers'):
|
||||
model.prune_intermediate_layers(
|
||||
out_indices,
|
||||
prune_norm=not norm,
|
||||
)
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.model = model
|
||||
self.out_indices = out_indices
|
||||
self.out_map = out_map
|
||||
self.return_dict = return_dict
|
||||
self.output_fmt = output_fmt
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
def get_intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1,
|
||||
reshape: bool = False,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
"""
|
||||
out = self.model.get_intermediate_layers(
|
||||
*args,
|
||||
def forward(self, x):
|
||||
features = self.model.forward_intermediates(
|
||||
x,
|
||||
n=self.out_indices,
|
||||
reshape=True,
|
||||
**kwargs,
|
||||
norm=self.norm,
|
||||
output_fmt=self.output_fmt,
|
||||
features_only=True,
|
||||
)
|
||||
return out
|
||||
return features
|
||||
|
@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
|
||||
# --------------------------------------------------------'
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel
|
||||
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
|
||||
__all__ = ['Beit']
|
||||
|
||||
@ -333,6 +333,8 @@ class Beit(nn.Module):
|
||||
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.feature_info = [
|
||||
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
|
||||
|
||||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
@ -398,6 +400,93 @@ class Beit(nn.Module):
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
n: Take last n blocks if n is an int, if in is a sequence, select by matching indices
|
||||
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
features_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
num_blocks = len(self.blocks)
|
||||
if n is None:
|
||||
n = num_blocks
|
||||
take_indices, max_index = feature_take_indices(n, num_blocks)
|
||||
|
||||
# forward pass
|
||||
B, _, height, width = x.shape
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
|
||||
if i in take_indices:
|
||||
# normalize intermediates with final norm layer if enabled
|
||||
intermediates.append(self.norm(x) if norm else x)
|
||||
|
||||
# process intermediates
|
||||
if self.num_prefix_tokens:
|
||||
# split prefix (e.g. class, distill) and spatial feature tokens
|
||||
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
||||
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
||||
if reshape:
|
||||
# reshape == True => BCHW output format
|
||||
patch_size = self.patch_embed.patch_size
|
||||
H = int(math.ceil(height / patch_size[0]))
|
||||
W = int(math.ceil(width / patch_size[1]))
|
||||
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
if not torch.jit.is_scripting() and return_prefix_tokens:
|
||||
# return_prefix not support in torchscript due to poor type handling
|
||||
intermediates = list(zip(intermediates, prefix_tokens))
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
n: Union[int, List[int], Tuple[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(n, len(self.blocks))
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.fc_norm = nn.Identity()
|
||||
self.head = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
@ -547,14 +636,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia
|
||||
|
||||
|
||||
def _create_beit(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for BEiT models.')
|
||||
|
||||
out_indices = kwargs.pop('out_indices', 3)
|
||||
model = build_model_with_cfg(
|
||||
Beit, variant, pretrained,
|
||||
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
|
||||
pretrained_filter_fn=_beit_checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
|
||||
|
||||
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
out_indices = kwargs.pop('out_indices', 3)
|
||||
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
|
||||
model = build_model_with_cfg(
|
||||
model_cls,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
|
||||
"""
|
||||
# EVA models Copyright (c) 2022 BAAI-Vision
|
||||
# EVA02 models Copyright (c) 2023 BAAI-Vision
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
|
||||
to_2tuple, use_fused_attn
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['Eva']
|
||||
@ -469,6 +469,8 @@ class Eva(nn.Module):
|
||||
init_values=init_values,
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.feature_info = [
|
||||
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
|
||||
|
||||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
@ -559,6 +561,85 @@ class Eva(nn.Module):
|
||||
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
|
||||
return x, rot_pos_embed
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
Args:
|
||||
x: Input image tensor
|
||||
n: Take last n blocks if n is an int, if in is a sequence, select by matching indices
|
||||
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
features_only: Only return intermediate features
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
num_blocks = len(self.blocks)
|
||||
if n is None:
|
||||
n = num_blocks
|
||||
take_indices, max_index = feature_take_indices(n, num_blocks)
|
||||
|
||||
# forward pass
|
||||
B, _, height, width = x.shape
|
||||
x = self.patch_embed(x)
|
||||
x, rot_pos_embed = self._pos_embed(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x, rope=rot_pos_embed)
|
||||
if i in take_indices:
|
||||
intermediates.append(self.norm(x) if norm else x)
|
||||
|
||||
# process intermediates
|
||||
if self.num_prefix_tokens:
|
||||
# split prefix (e.g. class, distill) and spatial feature tokens
|
||||
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
||||
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
||||
if reshape:
|
||||
# reshape == True => BCHW output format
|
||||
patch_size = self.patch_embed.patch_size
|
||||
H = int(math.ceil(height / patch_size[0]))
|
||||
W = int(math.ceil(width / patch_size[1]))
|
||||
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
if not torch.jit.is_scripting() and return_prefix_tokens:
|
||||
# return_prefix not support in torchscript due to poor type handling
|
||||
intermediates = list(zip(intermediates, prefix_tokens))
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
n: Union[int, List[int], Tuple[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(n, len(self.blocks))
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.fc_norm = nn.Identity()
|
||||
self.head = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x, rot_pos_embed = self._pos_embed(x)
|
||||
@ -663,13 +744,13 @@ def checkpoint_filter_fn(
|
||||
|
||||
|
||||
def _create_eva(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Eva models.')
|
||||
|
||||
out_indices = kwargs.pop('out_indices', 3)
|
||||
model = build_model_with_cfg(
|
||||
Eva, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ from torch import nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
|
||||
|
||||
@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module):
|
||||
|
||||
num_stages = len(cfg.embed_dim)
|
||||
feat_size = patch_dims
|
||||
curr_stride = max(cfg.patch_stride)
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
||||
self.stages = nn.ModuleList()
|
||||
self.feature_info = []
|
||||
for i in range(num_stages):
|
||||
if cfg.expand_attn:
|
||||
dim_out = cfg.embed_dim[i]
|
||||
@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
drop_path=dpr[i],
|
||||
)
|
||||
curr_stride *= max(cfg.stride_q[i])
|
||||
self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
|
||||
embed_dim = dim_out
|
||||
feat_size = stage.feat_size
|
||||
self.stages.append(stage)
|
||||
@ -829,6 +834,51 @@ class MultiScaleVit(nn.Module):
|
||||
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
|
||||
]))
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, List[int], Tuple[int]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
num_stages = len(self.stages) # block list is two-tiered, first tier == stage
|
||||
if n is None:
|
||||
n = num_stages
|
||||
take_indices, max_index = feature_take_indices(n, num_stages)
|
||||
|
||||
# FIXME slice block/pos_block if < max
|
||||
|
||||
# forward pass
|
||||
x, feat_size = self.patch_embed(x)
|
||||
B = x.shape[0]
|
||||
if self.cls_token is not None:
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, feat_size = stage(x, feat_size)
|
||||
if i in take_indices:
|
||||
if norm and i == (len(self.stages) - 1):
|
||||
x_inter = self.norm(x) # applying final norm last intermediate
|
||||
else:
|
||||
x_inter = x
|
||||
if reshape:
|
||||
x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
|
||||
intermediates.append(x_inter)
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def forward_features(self, x):
|
||||
x, feat_size = self.patch_embed(x)
|
||||
B, N, C = x.shape
|
||||
@ -862,6 +912,18 @@ class MultiScaleVit(nn.Module):
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
if 'stages.0.blocks.0.norm1.weight' in state_dict:
|
||||
# native checkpoint, look for rel_pos interpolations
|
||||
for k in state_dict.keys():
|
||||
if 'rel_pos' in k:
|
||||
rel_pos = state_dict[k]
|
||||
dest_rel_pos_shape = model.state_dict()[k].shape
|
||||
if rel_pos.shape[0] != dest_rel_pos_shape[0]:
|
||||
rel_pos_resized = torch.nn.functional.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=dest_rel_pos_shape[0],
|
||||
mode="linear",
|
||||
)
|
||||
state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
|
||||
return state_dict
|
||||
|
||||
import re
|
||||
@ -892,16 +954,6 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
k = k.replace('head.projection', 'head.fc')
|
||||
out_dict[k] = v
|
||||
|
||||
# for k, v in state_dict.items():
|
||||
# if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
|
||||
# # To resize pos embedding when using model at different size from pretrained weights
|
||||
# v = resize_pos_embed(
|
||||
# v,
|
||||
# model.pos_embed,
|
||||
# 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
|
||||
# model.patch_embed.grid_size
|
||||
# )
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
@ -948,16 +1000,14 @@ model_cfgs = dict(
|
||||
|
||||
|
||||
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')
|
||||
|
||||
out_indices = kwargs.pop('out_indices', 4)
|
||||
return build_model_with_cfg(
|
||||
MultiScaleVit,
|
||||
variant,
|
||||
pretrained,
|
||||
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True),
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -22,6 +22,7 @@ import torch.nn.functional as F
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
from .vision_transformer import Attention
|
||||
@ -324,6 +325,7 @@ class Twins(nn.Module):
|
||||
patch_size = 2
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
self.feature_info = []
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
cur = 0
|
||||
for k in range(len(depths)):
|
||||
@ -339,6 +341,7 @@ class Twins(nn.Module):
|
||||
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
|
||||
)
|
||||
self.blocks.append(_block)
|
||||
self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))]
|
||||
cur += depths[k]
|
||||
|
||||
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
|
||||
@ -401,6 +404,53 @@ class Twins(nn.Module):
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, List[int], Tuple[int]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.'
|
||||
intermediates = []
|
||||
num_stages = len(self.blocks) # block list is two-tiered, first tier == stage
|
||||
if n is None:
|
||||
n = num_stages
|
||||
take_indices, max_index = feature_take_indices(n, num_stages)
|
||||
|
||||
# FIXME slice block/pos_block if < max
|
||||
|
||||
# forward pass
|
||||
B, _, height, width = x.shape
|
||||
for i, (embed, drop, blocks, pos_blk) in enumerate(zip(
|
||||
self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
|
||||
):
|
||||
x, size = embed(x)
|
||||
x = drop(x)
|
||||
for j, blk in enumerate(blocks):
|
||||
x = blk(x, size)
|
||||
if j == 0:
|
||||
x = pos_blk(x, size) # PEG here
|
||||
|
||||
if i < len(self.depths) - 1:
|
||||
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
|
||||
if i in take_indices:
|
||||
intermediates.append(x)
|
||||
else:
|
||||
if i in take_indices:
|
||||
# only last feature can be normed
|
||||
x_feat = self.norm(x) if norm else x
|
||||
intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous())
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
for i, (embed, drop, blocks, pos_blk) in enumerate(
|
||||
@ -429,10 +479,12 @@ class Twins(nn.Module):
|
||||
|
||||
|
||||
def _create_twins(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
|
||||
out_indices = kwargs.pop('out_indices', 4)
|
||||
model = build_model_with_cfg(
|
||||
Twins, variant, pretrained,
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm,
|
||||
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||
get_act_layer, get_norm_layer, LayerType
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
@ -473,7 +474,6 @@ class VisionTransformer(nn.Module):
|
||||
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
|
||||
self.dynamic_img_size = dynamic_img_size
|
||||
self.grad_checkpointing = False
|
||||
self.feature_info = []
|
||||
|
||||
embed_args = {}
|
||||
if dynamic_img_size:
|
||||
@ -631,58 +631,111 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
return self.pos_drop(x)
|
||||
|
||||
def _intermediate_layers(
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1,
|
||||
) -> List[torch.Tensor]:
|
||||
outputs, num_blocks = [], len(self.blocks)
|
||||
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
|
||||
last_index_to_take = max(take_indices)
|
||||
n: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
n: Take last n blocks if n is an int, if in is a sequence, select by matching indices
|
||||
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
features_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
num_blocks = len(self.blocks)
|
||||
if n is None:
|
||||
n = num_blocks
|
||||
take_indices, max_index = feature_take_indices(n, num_blocks)
|
||||
|
||||
# forward pass
|
||||
B, _, height, width = x.shape
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
outputs.append(x)
|
||||
# normalize intermediates with final norm layer if enabled
|
||||
intermediates.append(self.norm(x) if norm else x)
|
||||
|
||||
return outputs
|
||||
# process intermediates
|
||||
if self.num_prefix_tokens:
|
||||
# split prefix (e.g. class, distill) and spatial feature tokens
|
||||
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
||||
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
||||
if reshape:
|
||||
# reshape == True => BCHW output format
|
||||
patch_size = self.patch_embed.patch_size
|
||||
H = int(math.ceil(height / patch_size[0]))
|
||||
W = int(math.ceil(width / patch_size[1]))
|
||||
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
if not torch.jit.is_scripting() and return_prefix_tokens:
|
||||
# return_prefix not support in torchscript due to poor type handling
|
||||
intermediates = list(zip(intermediates, prefix_tokens))
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
n: Union[int, List[int], Tuple[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(n, len(self.blocks))
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
if self.attn_pool is not None:
|
||||
self.attn_pool = None
|
||||
self.fc_norm = nn.Identity()
|
||||
self.head = nn.Identity()
|
||||
|
||||
def get_intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1,
|
||||
n: Union[int, List[int], Tuple[int]] = 1,
|
||||
reshape: bool = False,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
||||
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
|
||||
Inspired by DINO / DINOv2 interface
|
||||
) -> List[torch.Tensor]:
|
||||
""" Intermediate layer accessor inspired by DINO / DINOv2 interface.
|
||||
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
|
||||
"""
|
||||
# take last n blocks if n is an int, if in is a sequence, select by matching indices
|
||||
outputs = self._intermediate_layers(x, n)
|
||||
if norm:
|
||||
outputs = [self.norm(out) for out in outputs]
|
||||
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
|
||||
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
|
||||
|
||||
if reshape:
|
||||
patch_size = self.patch_embed.patch_size
|
||||
batch, _, height, width = x.size()
|
||||
outputs = [
|
||||
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
for out in outputs
|
||||
]
|
||||
|
||||
if return_prefix_tokens:
|
||||
return tuple(zip(outputs, prefix_tokens))
|
||||
return tuple(outputs)
|
||||
return self.forward_intermediates(
|
||||
x, n,
|
||||
return_prefix_tokens=return_prefix_tokens,
|
||||
norm=norm,
|
||||
output_fmt='NCHW' if reshape else 'NLC',
|
||||
features_only=True,
|
||||
)
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
@ -2485,7 +2538,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo
|
||||
def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT-S/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
|
||||
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
|
||||
model = _create_vision_transformer(
|
||||
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
@ -2495,7 +2548,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
|
||||
def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT-B/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
|
||||
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
|
||||
model = _create_vision_transformer(
|
||||
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
@ -2505,7 +2558,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo
|
||||
def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT-L/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
|
||||
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
|
||||
model = _create_vision_transformer(
|
||||
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
@ -2521,7 +2574,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
|
||||
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
|
||||
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
|
||||
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in:
|
||||
"""
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
|
||||
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
|
||||
from torch.jit import Final
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
|
||||
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from ._features_fx import register_notrace_function
|
||||
|
||||
# model_registry will add each entrypoint fn to this
|
||||
__all__ = ['VisionTransformerSAM']
|
||||
@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module):
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
weight_init: str = '',
|
||||
embed_layer: Callable = partial(
|
||||
PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
|
||||
embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
|
||||
norm_layer: Optional[Callable] = nn.LayerNorm,
|
||||
act_layer: Optional[Callable] = nn.GELU,
|
||||
block_fn: Callable = Block,
|
||||
@ -469,6 +469,8 @@ class VisionTransformerSAM(nn.Module):
|
||||
rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.feature_info = [
|
||||
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
|
||||
|
||||
if neck_chans:
|
||||
self.neck = nn.Sequential(
|
||||
@ -536,6 +538,52 @@ class VisionTransformerSAM(nn.Module):
|
||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, List[int], Tuple[int]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
features_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
|
||||
intermediates = []
|
||||
num_blocks = len(self.blocks)
|
||||
if n is None:
|
||||
n = num_blocks
|
||||
take_indices, max_index = feature_take_indices(n, num_blocks)
|
||||
|
||||
# forward pass, collect intermediates
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
# dynamically resize abs pos embedding if needed
|
||||
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
|
||||
x = self.pos_drop(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
# make output BCHW
|
||||
if norm:
|
||||
# norm is intertwined with neck convs so apply both, changes the dim
|
||||
# FIXME only apply to final? Need experiments
|
||||
intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
|
||||
else:
|
||||
intermediates.append(x.permute(0, 3, 1, 2))
|
||||
|
||||
if features_only:
|
||||
return intermediates
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
@ -618,15 +666,13 @@ default_cfgs = generate_default_cfgs({
|
||||
|
||||
|
||||
def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError(
|
||||
'features_only not implemented for Vision Transformer models.')
|
||||
|
||||
out_indices = kwargs.pop('out_indices', 3)
|
||||
return build_model_with_cfg(
|
||||
VisionTransformerSAM,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user