Merge pull request #2136 from huggingface/vit_features_only

Exploring vit features_only via new forward_intermediates() API, inspired by #2131
Ross Wightman 2024-04-11 08:38:20 -07:00 committed by GitHub
commit d6b95520f1
23 changed files with 864 additions and 127 deletions
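
For context, a minimal usage sketch of the new API (model name, indices, and input size are illustrative, assuming a timm build that includes this PR):

import torch, timm

# features_only now routes ViT-style models through the new FeatureGetterNet wrapper
model = timm.create_model('vit_base_patch16_224', features_only=True, out_indices=(-3, -2, -1))
model.eval()
feats = model(torch.randn(1, 3, 224, 224))  # list of three NCHW feature maps
print([f.shape for f in feats], model.feature_info.reduction())

# or call forward_intermediates() directly on a standard classification model
model = timm.create_model('vit_base_patch16_224').eval()
final, intermediates = model.forward_intermediates(torch.randn(1, 3, 224, 224), indices=3)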

View File

@ -47,11 +47,16 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(False)
# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
]
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -351,7 +356,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
@ -359,7 +364,7 @@ def test_model_forward_features(model_name, batch_size):
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
assert len(expected_channels) >= 3 # all models here should have at least 3 default feat levels
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
@ -380,6 +385,72 @@ def test_model_forward_features(model_name, batch_size):
assert not torch.isnan(o).any()
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
print(model.feature_info.out_indices)
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
print(o.shape)
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False)
model.eval()
feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
expected_channels = feature_info.channels()
expected_reduction = feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math
output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
)
assert len(expected_channels) == len(intermediates)
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, intermediates):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
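# Worked example of the spatial-size bound above (illustrative numbers): for a 224x224
# input, a feature level with reduction r=16 gives math.ceil(224 / 16) = 14, so its H and W
# must each be <= 15; r=32 gives ceil(224 / 32) = 7, i.e. a bound of 8. The +1 slack
# presumably leaves room for models whose padding rounds an odd input up by one feature row/col.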
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output

View File

@ -9,6 +9,7 @@ Based on code in:
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import Callable, List, Optional, Tuple, Union
import torch
@ -65,6 +66,21 @@ class PatchEmbed(nn.Module):
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get grid (feature) size for given image size taking account of dynamic padding.
NOTE: must be torchscript compatible so using fixed tuple indexing
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
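# Illustrative values, assuming patch_size=(16, 16): dynamic_feat_size((224, 224)) -> (14, 14)
# by floor division, while with dynamic_img_pad=True a 225x225 input rounds up to (15, 15);
# feat_ratio() -> 16, which the ViT models use as the per-block reduction in feature_info.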
def forward(self, x):
B, C, H, W = x.shape
if self.img_size is not None:
@ -127,13 +143,13 @@ class PatchEmbedWithSize(PatchEmbed):
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
x = self.proj(x)
grid_size = x.shape[-2:]
feat_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x, grid_size
return x, feat_size
def resample_patch_embed(

View File

@ -7,7 +7,7 @@ from typing import Optional, Dict, Callable, Any, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@ -428,8 +428,12 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'dict':
feature_cls = FeatureDictNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
elif feature_cls == 'getter':
feature_cls = FeatureGetterNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
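# For example, the transformer entrypoints below opt into the new wrapper by passing
#   feature_cfg=dict(out_indices=out_indices, feature_cls='getter')
# to build_model_with_cfg; 'hook', 'dict', 'fx' and 'getter' all resolve here, while
# FeatureListNet remains the default when no feature_cls is specified.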

View File

@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
import torch.nn as nn
@ -20,12 +20,70 @@ from torch.utils.checkpoint import checkpoint
from timm.layers import Format
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
__all__ = [
'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
'feature_take_indices'
]
def _take_indices(
num_blocks: int,
n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def _take_indices_jit(
num_blocks: int,
n: Union[int, List[int], Tuple[int]],
) -> Tuple[List[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = [num_blocks - n + i for i in range(n)]
elif isinstance(n, tuple):
# splitting this up is silly, but needed for torchscript type resolution of n
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
else:
take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
return take_indices, max(take_indices)
def feature_take_indices(
num_blocks: int,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
) -> Tuple[List[int], int]:
if indices is None:
indices = num_blocks # all blocks if None
if torch.jit.is_scripting():
return _take_indices_jit(num_blocks, indices)
else:
# NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
return _take_indices(num_blocks, indices)
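# Index resolution examples (illustrative), with num_blocks=12:
#   feature_take_indices(12, 3)        -> last three blocks {9, 10, 11}, max_index 11
#   feature_take_indices(12, (-2, -1)) -> {10, 11}, max_index 11
#   feature_take_indices(12, None)     -> all twelve blocks
# (the eager path returns a Set[int], the scripted path a List[int], as noted above)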
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
if isinstance(x, int):
# if indices is an int, take last N features
return tuple(range(-x, 0))
return tuple(x)
OutIndicesT = Union[int, Tuple[int, ...]]
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
def __init__(
self,
feature_info: List[Dict],
out_indices: OutIndicesT,
):
out_indices = _out_indices_as_tuple(out_indices)
prev_reduction = 1
for i, fi in enumerate(feature_info):
# sanity check the mandatory fields, there may be additional fields depending on the model
@ -37,14 +95,15 @@ class FeatureInfo:
self.out_indices = out_indices
self.info = feature_info
def from_other(self, out_indices: Tuple[int]):
def from_other(self, out_indices: OutIndicesT):
out_indices = _out_indices_as_tuple(out_indices)
return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tupple, return value for each module index (ignoring output indices)
if idx is a list/tuple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
@ -53,7 +112,7 @@ class FeatureInfo:
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
@ -66,17 +125,17 @@ class FeatureInfo:
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
def channels(self, idx: Optional[Union[int, List[int]]] = None):
""" feature channels accessor
"""
return self.get('num_chs', idx)
def reduction(self, idx=None):
def reduction(self, idx: Optional[Union[int, List[int]]] = None):
""" feature reduction (output stride) accessor
"""
return self.get('reduction', idx)
def module_name(self, idx=None):
def module_name(self, idx: Optional[Union[int, List[int]]] = None):
""" feature module name accessor
"""
return self.get('module', idx)
@ -146,7 +205,7 @@ def _module_list(module, flatten_sequential=False):
return ml
def _get_feature_info(net, out_indices):
def _get_feature_info(net, out_indices: OutIndicesT):
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
@ -182,7 +241,7 @@ class FeatureDictNet(nn.ModuleDict):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
output_fmt: str = 'NCHW',
feature_concat: bool = False,
@ -257,7 +316,7 @@ class FeatureListNet(FeatureDictNet):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
@ -298,8 +357,8 @@ class FeatureHookNet(nn.ModuleDict):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
@ -366,3 +425,55 @@ class FeatureHookNet(nn.ModuleDict):
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.return_dict else list(out.values())
class FeatureGetterNet(nn.ModuleDict):
""" FeatureGetterNet
Wrap models with a feature getter method, like 'get_intermediate_layers'
"""
def __init__(
self,
model: nn.Module,
out_indices: OutIndicesT = 4,
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
norm: bool = False,
prune: bool = True,
):
"""
Args:
model: Model to wrap.
out_indices: Indices of features to extract.
out_map: Remap feature names for dict output (WIP, not supported).
return_dict: Return features as dictionary instead of list (WIP, not supported).
norm: Apply final model norm to all output features (if possible).
"""
super().__init__()
if prune and hasattr(model, 'prune_intermediate_layers'):
# replace out_indices after they've been normalized, -ve indices will be invalid after prune
out_indices = model.prune_intermediate_layers(
out_indices,
prune_norm=not norm,
)
out_indices = list(out_indices)
self.feature_info = _get_feature_info(model, out_indices)
self.model = model
self.out_indices = out_indices
self.out_map = out_map
self.return_dict = return_dict
self.output_fmt = output_fmt
self.norm = norm
def forward(self, x):
features = self.model.forward_intermediates(
x,
indices=self.out_indices,
norm=self.norm,
output_fmt=self.output_fmt,
intermediates_only=True,
)
return features
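# Minimal wrapper sketch (model name and indices illustrative): any model exposing
# forward_intermediates() can be wrapped directly,
#   getter = FeatureGetterNet(timm.create_model('vit_small_patch16_224'), out_indices=(-2, -1))
#   feats = getter(torch.randn(1, 3, 224, 224))  # list of two NCHW tensors
# which is the same path create_model(..., features_only=True) now takes for these models.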

View File

@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type
from typing import Callable, Dict, List, Optional, Union, Tuple, Type
import torch
from torch import nn
@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
class FeatureGraphNet(nn.Module):
""" A FX Graph based feature extractor that works with the model feature_info metadata
"""
def __init__(self, model, out_indices, out_map=None):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...],
out_map: Optional[Dict] = None,
):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module):
return_nodes: node names to return features from (dict or list)
squeeze_out: if only one output, and output in list format, flatten to single tensor
"""
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
def __init__(
self,
model: nn.Module,
return_nodes: Union[Dict[str, str], List[str]],
squeeze_out: bool = True,
):
super().__init__()
self.squeeze_out = squeeze_out
self.graph_module = create_feature_extractor(model, return_nodes)

View File

@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# --------------------------------------------------------'
import math
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
__all__ = ['Beit']
@ -302,6 +302,7 @@ class Beit(nn.Module):
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@ -333,6 +334,8 @@ class Beit(nn.Module):
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@ -398,6 +401,89 @@ class Beit(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if an int, select by matching indices if a sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
@ -547,14 +633,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia
def _create_beit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for BEiT models.')
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Beit, variant, pretrained,
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
pretrained_filter_fn=_beit_checkpoint_filter_fn,
**kwargs)
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model

View File

@ -1251,7 +1251,7 @@ class ByobNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -710,7 +710,7 @@ class CspNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer):
def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
model = build_model_with_cfg(
model_cls,
variant,
pretrained,
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model

View File

@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision
import math
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
to_2tuple, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import generate_default_cfgs, register_model
__all__ = ['Eva']
@ -424,6 +424,7 @@ class Eva(nn.Module):
**embed_args,
)
num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
@ -469,6 +470,8 @@ class Eva(nn.Module):
init_values=init_values,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@ -559,6 +562,81 @@ class Eva(nn.Module):
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
return x, rot_pos_embed
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if an int, select by matching indices if a sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, rope=rot_pos_embed)
if i in take_indices:
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
@ -663,13 +741,13 @@ def checkpoint_filter_fn(
def _create_eva(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Eva models.')
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Eva, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model

View File

@ -464,7 +464,7 @@ class FocalNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -501,7 +501,7 @@ class GlobalContextVit(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)

View File

@ -1258,7 +1258,7 @@ class MaxxVit(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -26,6 +26,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module):
num_stages = len(cfg.embed_dim)
feat_size = patch_dims
curr_stride = max(cfg.patch_stride)
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
self.stages = nn.ModuleList()
self.feature_info = []
for i in range(num_stages):
if cfg.expand_attn:
dim_out = cfg.embed_dim[i]
@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module):
norm_layer=norm_layer,
drop_path=dpr[i],
)
curr_stride *= max(cfg.stride_q[i])
self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
embed_dim = dim_out
feat_size = stage.feat_size
self.stages.append(stage)
@ -829,6 +834,63 @@ class MultiScaleVit(nn.Module):
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# FIXME slice block/pos_block if < max
# forward pass
x, feat_size = self.patch_embed(x)
B = x.shape[0]
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
for i, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
if i in take_indices:
if norm and i == (len(self.stages) - 1):
x_inter = self.norm(x) # applying final norm to last intermediate
else:
x_inter = x
if reshape:
if self.cls_token is not None:
# possible to allow return of class tokens, TBD
x_inter = x_inter[:, 1:]
x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
intermediates.append(x_inter)
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
B, N, C = x.shape
@ -862,6 +924,18 @@ class MultiScaleVit(nn.Module):
def checkpoint_filter_fn(state_dict, model):
if 'stages.0.blocks.0.norm1.weight' in state_dict:
# native checkpoint, look for rel_pos interpolations
for k in state_dict.keys():
if 'rel_pos' in k:
rel_pos = state_dict[k]
dest_rel_pos_shape = model.state_dict()[k].shape
if rel_pos.shape[0] != dest_rel_pos_shape[0]:
rel_pos_resized = torch.nn.functional.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=dest_rel_pos_shape[0],
mode="linear",
)
state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
return state_dict
import re
@ -892,16 +966,6 @@ def checkpoint_filter_fn(state_dict, model):
k = k.replace('head.projection', 'head.fc')
out_dict[k] = v
# for k, v in state_dict.items():
# if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# # To resize pos embedding when using model at different size from pretrained weights
# v = resize_pos_embed(
# v,
# model.pos_embed,
# 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
# model.patch_embed.grid_size
# )
return out_dict
@ -948,16 +1012,14 @@ model_cfgs = dict(
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 4)
return build_model_with_cfg(
MultiScaleVit,
variant,
pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

View File

@ -525,7 +525,7 @@ class RegNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -469,7 +469,7 @@ class ResNetV2(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li
# --------------------------------------------------------
import math
from functools import partial
from typing import Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -22,6 +22,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Attention
@ -324,6 +325,7 @@ class Twins(nn.Module):
patch_size = 2
self.blocks = nn.ModuleList()
self.feature_info = []
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
for k in range(len(depths)):
@ -339,6 +341,7 @@ class Twins(nn.Module):
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
)
self.blocks.append(_block)
self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))]
cur += depths[k]
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
@ -401,6 +404,61 @@ class Twins(nn.Module):
if m.bias is not None:
m.bias.data.zero_()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# FIXME slice block/pos_block if < max
# forward pass
B, _, height, width = x.shape
for i, (embed, drop, blocks, pos_blk) in enumerate(zip(
self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if i < len(self.depths) - 1:
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
if i in take_indices:
intermediates.append(x)
else:
if i in take_indices:
# only last feature can be normed
x_feat = self.norm(x) if norm else x
intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous())
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def forward_features(self, x):
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
@ -429,10 +487,12 @@ class Twins(nn.Module):
def _create_twins(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
out_indices = kwargs.pop('out_indices', 4)
model = build_model_with_cfg(
Twins, variant, pretrained,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model

View File

@ -27,7 +27,7 @@ import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
try:
from typing import Literal
except ImportError:
@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm,
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -488,6 +489,7 @@ class VisionTransformer(nn.Module):
**embed_args,
)
num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
@ -520,6 +522,8 @@ class VisionTransformer(nn.Module):
mlp_layer=mlp_layer,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
@ -628,58 +632,107 @@ class VisionTransformer(nn.Module):
return self.pos_drop(x)
def _intermediate_layers(
def forward_intermediates(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
last_index_to_take = max(take_indices)
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
return outputs
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
if self.attn_pool is not None:
self.attn_pool = None
self.fc_norm = nn.Identity()
self.head = nn.Identity()
return take_indices
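# Usage sketch (values illustrative), for a 12-block ViT:
#   final, feats = model.forward_intermediates(x, indices=(2, 5, 8, 11), return_prefix_tokens=True)
# returns the normed token sequence plus four (spatial NCHW, prefix token) pairs, while
#   model.prune_intermediate_layers((2, 5, 8), prune_head=True)
# truncates the model to blocks 0..8 and replaces the classifier head with nn.Identity().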
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
n: Union[int, List[int], Tuple[int]] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
) -> List[torch.Tensor]:
""" Intermediate layer accessor inspired by DINO / DINOv2 interface.
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
"""
# take last n blocks if n is an int, select by matching indices if it is a sequence
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
patch_size = self.patch_embed.patch_size
batch, _, height, width = x.size()
outputs = [
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
.permute(0, 3, 1, 2)
.contiguous()
for out in outputs
]
if return_prefix_tokens:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
return self.forward_intermediates(
x, n,
return_prefix_tokens=return_prefix_tokens,
norm=norm,
output_fmt='NCHW' if reshape else 'NLC',
intermediates_only=True,
)
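# Backwards-compatible sketch of the DINO / DINOv2 style accessor (values illustrative):
#   feats = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
# now simply delegates to forward_intermediates(..., output_fmt='NCHW', intermediates_only=True)
# and returns the last four blocks' spatial features as NCHW tensors.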
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@ -1770,9 +1823,7 @@ default_cfgs = generate_default_cfgs(default_cfgs)
def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
@ -1791,6 +1842,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
pretrained,
pretrained_filter_fn=_filter_fn,
pretrained_strict=strict,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
@ -2483,7 +2535,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo
def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-S/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
model = _create_vision_transformer(
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2493,7 +2545,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-B/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2503,7 +2555,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo
def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-L/14 for DINOv2
"""
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
model = _create_vision_transformer(
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2519,7 +2571,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
)
model = _create_vision_transformer(
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))

View File

@ -13,8 +13,9 @@ They were moved here to keep file sizes sane.
Hacked together by / Copyright 2020, Ross Wightman
"""
import math
from functools import partial
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -41,6 +42,7 @@ class HybridEmbed(nn.Module):
img_size=224,
patch_size=1,
feature_size=None,
feature_ratio=None,
in_chans=3,
embed_dim=768,
bias=True,
@ -68,15 +70,20 @@ class HybridEmbed(nn.Module):
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
else:
feature_size = to_2tuple(feature_size)
feature_ratio = to_2tuple(feature_ratio or 16)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
if not dynamic_img_pad:
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.feature_size = feature_size
self.feature_ratio = feature_ratio
self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
if output_fmt is not None:
self.flatten = False
@ -90,6 +97,25 @@ class HybridEmbed(nn.Module):
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
self.feature_ratio[0] * self.patch_size[0],
self.feature_ratio[1] * self.patch_size[1]
)
if as_scalar:
return max(total_reduction)
else:
return total_reduction
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get feature grid size taking account dynamic padding and backbone network feat reduction
"""
feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
if self.dynamic_img_pad:
return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
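# Illustrative numbers: a backbone with output stride 32 (feature_size (7, 7) at img_size 224)
# gives feature_ratio (32, 32); with patch_size (1, 1) the total reduction feat_ratio() is 32
# and dynamic_feat_size((224, 224)) == (7, 7).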
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):

View File

@ -7,7 +7,7 @@ Hacked together by / Copyright 2022, Ross Wightman
import logging
import math
from functools import partial
from typing import Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Type, Union
try:
from typing import Literal
@ -22,6 +22,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit
@ -297,6 +298,7 @@ class VisionTransformerRelPos(nn.Module):
embed_dim=embed_dim,
)
feat_size = self.patch_embed.grid_size
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
if rel_pos_type.startswith('mlp'):
@ -332,6 +334,8 @@ class VisionTransformerRelPos(nn.Module):
act_layer=act_layer,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()
# Classifier Head
@ -384,6 +388,88 @@ class VisionTransformerRelPos(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.patch_embed(x)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, shared_rel_pos=shared_rel_pos)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
if reshape:
# reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if not torch.jit.is_scripting() and return_prefix_tokens:
# return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.cls_token is not None:
@ -412,10 +498,12 @@ class VisionTransformerRelPos(nn.Module):
def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
VisionTransformerRelPos, variant, pretrained,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model

View File

@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in:
"""
import logging
from functools import partial
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
from ._features_fx import register_notrace_function
# model_registry will add each entrypoint fn to this
__all__ = ['VisionTransformerSAM']
@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module):
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = partial(
PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
norm_layer: Optional[Callable] = nn.LayerNorm,
act_layer: Optional[Callable] = nn.GELU,
block_fn: Callable = Block,
@ -408,6 +408,8 @@ class VisionTransformerSAM(nn.Module):
bias=not pre_norm, # disable bias if pre-norm is used
)
grid_size = self.patch_embed.grid_size
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
@ -469,6 +471,8 @@ class VisionTransformerSAM(nn.Module):
rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
if neck_chans:
self.neck = nn.Sequential(
@ -536,6 +540,78 @@ class VisionTransformerSAM(nn.Module):
def reset_classifier(self, num_classes=0, global_pool=None):
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass, collect intermediates
x = self.patch_embed(x)
if self.pos_embed is not None:
# dynamically resize abs pos embedding if needed
x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
x = self.pos_drop(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
# make output BCHW
if norm:
# norm is intertwined with neck convs so apply both, changes the dim
# FIXME only apply to final? Need experiments
intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
else:
intermediates.append(x.permute(0, 3, 1, 2))
if intermediates_only:
return intermediates
x = self.neck(x.permute(0, 3, 1, 2))
return x, intermediates
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = None,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
# neck is being treated as equivalent to final norm here
self.neck = nn.Identity()
if prune_head:
self.head = nn.Identity()
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
if self.pos_embed is not None:
@ -618,15 +694,13 @@ default_cfgs = generate_default_cfgs({
def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError(
'features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
return build_model_with_cfg(
VisionTransformerSAM,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

View File

@ -259,7 +259,7 @@ class VovNet(nn.Module):
return self.stages(x)
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -286,7 +286,7 @@ class XceptionAligned(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)