Merge pull request #2136 from huggingface/vit_features_only

Exploring vit features_only via new forward_intermediates() API, inspired by #2131

commit d6b95520f1

tests/test_models.py

@@ -47,11 +47,16 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(False)
 
+# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
+FEAT_INTER_FILTERS = [
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+]
+
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
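For orientation (not part of the diff): with FEAT_INTER_FILTERS in place, `features_only` now works for these model families via the new getter wrapper. A minimal sketch, model name and printed shapes illustrative:

```python
import torch
import timm

# 'vit_*' matches FEAT_INTER_FILTERS, so features_only routes through FeatureGetterNet
model = timm.create_model('vit_base_patch16_224', pretrained=False, features_only=True)
model.eval()
feats = model(torch.randn(1, 3, 224, 224))
for f, info in zip(feats, model.feature_info.get_dicts()):
    print(f.shape, info)  # e.g. torch.Size([1, 768, 14, 14]) plus num_chs/reduction metadata
```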
@@ -351,7 +356,7 @@ if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
 
 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""

@@ -359,7 +364,7 @@ def test_model_forward_features(model_name, batch_size):
     model.eval()
     expected_channels = model.feature_info.channels()
     expected_reduction = model.feature_info.reduction()
-    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    assert len(expected_channels) >= 3  # all models here should have at least 3 default feat levels
 
     input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
     if max(input_size) > MAX_FFEAT_SIZE:
@@ -380,6 +385,72 @@ def test_model_forward_features(model_name, batch_size):
         assert not torch.isnan(o).any()
 
 
+@pytest.mark.features
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_intermediates_features(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False, features_only=True)
+    model.eval()
+    print(model.feature_info.out_indices)
+    expected_channels = model.feature_info.channels()
+    expected_reduction = model.feature_info.reduction()
+
+    input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
+    if max(input_size) > MAX_FFEAT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    feat_axis = get_channel_dim(output_fmt)
+    spatial_axis = get_spatial_dim(output_fmt)
+    import math
+
+    outputs = model(torch.randn((batch_size, *input_size)))
+    assert len(expected_channels) == len(outputs)
+    spatial_size = input_size[-2:]
+    for e, r, o in zip(expected_channels, expected_reduction, outputs):
+        print(o.shape)
+        assert e == o.shape[feat_axis]
+        assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
+        assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()
+
+
+@pytest.mark.features
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_intermediates(model_name, batch_size):
+    """Run a single forward pass with each model in feature extraction mode"""
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+    feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
+    expected_channels = feature_info.channels()
+    expected_reduction = feature_info.reduction()
+    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+
+    input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
+    if max(input_size) > MAX_FFEAT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    feat_axis = get_channel_dim(output_fmt)
+    spatial_axis = get_spatial_dim(output_fmt)
+    import math
+
+    output, intermediates = model.forward_intermediates(
+        torch.randn((batch_size, *input_size)),
+    )
+    assert len(expected_channels) == len(intermediates)
+    spatial_size = input_size[-2:]
+    for e, r, o in zip(expected_channels, expected_reduction, intermediates):
+        assert e == o.shape[feat_axis]
+        assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
+        assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
+        assert o.shape[0] == batch_size
+        assert not torch.isnan(o).any()
+
+
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
     # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output

timm/layers/patch_embed.py

@@ -9,6 +9,7 @@ Based on code in:
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch

@@ -65,6 +66,21 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        if as_scalar:
+            return max(self.patch_size)
+        else:
+            return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        """ Get grid (feature) size for given image size taking account of dynamic padding.
+        NOTE: must be torchscript compatible so using fixed tuple indexing
+        """
+        if self.dynamic_img_pad:
+            return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
+        else:
+            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
     def forward(self, x):
         B, C, H, W = x.shape
         if self.img_size is not None:
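A quick worked check of the dynamic_feat_size() arithmetic above (a sketch, not part of the diff): with patch_size=(16, 16) and a 255x255 input, dynamic padding rounds up while the plain path floors:

```python
import math

patch = (16, 16)
img = (255, 255)
# dynamic_img_pad=True: pad up to a multiple of the patch, then divide -> ceil(255 / 16) = 16
print(math.ceil(img[0] / patch[0]), math.ceil(img[1] / patch[1]))  # 16 16
# dynamic_img_pad=False: plain floor division -> 255 // 16 = 15
print(img[0] // patch[0], img[1] // patch[1])  # 15 15
```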
@@ -127,13 +143,13 @@ class PatchEmbedWithSize(PatchEmbed):
         _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
 
         x = self.proj(x)
-        grid_size = x.shape[-2:]
+        feat_size = x.shape[-2:]
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
         elif self.output_fmt != Format.NCHW:
             x = nchw_to(x, self.output_fmt)
         x = self.norm(x)
-        return x, grid_size
+        return x, feat_size
 
 
 def resample_patch_embed(

timm/models/_builder.py

@@ -7,7 +7,7 @@ from typing import Optional, Dict, Callable, Any, Tuple
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
 
-from timm.models._features import FeatureListNet, FeatureHookNet
+from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
 from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf

@@ -428,8 +428,12 @@ def build_model_with_cfg(
                 feature_cls = feature_cls.lower()
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'dict':
+                    feature_cls = FeatureDictNet
                 elif feature_cls == 'fx':
                     feature_cls = FeatureGraphNet
+                elif feature_cls == 'getter':
+                    feature_cls = FeatureGetterNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)

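The new 'getter' branch is reached through feature_cfg set by the model create functions (see the _create_* changes below). A hedged sketch of the resulting user-facing call, kwargs illustrative:

```python
import timm

# _create_vision_transformer pops out_indices and passes
# feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
# so build_model_with_cfg wraps the model in FeatureGetterNet
model = timm.create_model(
    'vit_small_patch16_224',
    features_only=True,
    out_indices=(2, 5, 8, 11),  # absolute block indices; an int n means "last n blocks"
)
print(model.feature_info.reduction())  # e.g. [16, 16, 16, 16] for a patch16 ViT
```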
timm/models/_features.py

@@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -20,12 +20,70 @@ from torch.utils.checkpoint import checkpoint
 from timm.layers import Format
 
 
-__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
+__all__ = [
+    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
+    'feature_take_indices'
+]
+
+
+def _take_indices(
+        num_blocks: int,
+        n: Optional[Union[int, List[int], Tuple[int]]],
+) -> Tuple[Set[int], int]:
+    if isinstance(n, int):
+        assert n >= 0
+        take_indices = {x for x in range(num_blocks - n, num_blocks)}
+    else:
+        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
+    return take_indices, max(take_indices)
+
+
+def _take_indices_jit(
+        num_blocks: int,
+        n: Union[int, List[int], Tuple[int]],
+) -> Tuple[List[int], int]:
+    if isinstance(n, int):
+        assert n >= 0
+        take_indices = [num_blocks - n + i for i in range(n)]
+    elif isinstance(n, tuple):
+        # splitting this up is silly, but needed for torchscript type resolution of n
+        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
+    else:
+        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
+    return take_indices, max(take_indices)
+
+
+def feature_take_indices(
+        num_blocks: int,
+        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+) -> Tuple[List[int], int]:
+    if indices is None:
+        indices = num_blocks  # all blocks if None
+    if torch.jit.is_scripting():
+        return _take_indices_jit(num_blocks, indices)
+    else:
+        # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
+        return _take_indices(num_blocks, indices)
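To make the index semantics concrete (not part of the diff), what feature_take_indices() resolves to on the eager path:

```python
from timm.models._features import feature_take_indices

# int -> the last n block indices (here: a 12-block model)
feature_take_indices(12, 3)           # -> ({9, 10, 11}, 11)
# sequence -> explicit indices, negatives count back from the end
feature_take_indices(12, (0, 5, -1))  # -> ({0, 5, 11}, 11)
# None -> every block
feature_take_indices(12, None)        # -> ({0, 1, ..., 11}, 11)
```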
+
+
+def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
+    if isinstance(x, int):
+        # if indices is an int, take last N features
+        return tuple(range(-x, 0))
+    return tuple(x)
+
+
+OutIndicesT = Union[int, Tuple[int, ...]]
 
 
 class FeatureInfo:
 
-    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+    def __init__(
+            self,
+            feature_info: List[Dict],
+            out_indices: OutIndicesT,
+    ):
+        out_indices = _out_indices_as_tuple(out_indices)
         prev_reduction = 1
         for i, fi in enumerate(feature_info):
             # sanity check the mandatory fields, there may be additional fields depending on the model
@@ -37,14 +95,15 @@ class FeatureInfo:
         self.out_indices = out_indices
         self.info = feature_info
 
-    def from_other(self, out_indices: Tuple[int]):
+    def from_other(self, out_indices: OutIndicesT):
+        out_indices = _out_indices_as_tuple(out_indices)
         return FeatureInfo(deepcopy(self.info), out_indices)
 
-    def get(self, key, idx=None):
+    def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
         """ Get value by key at specified index (indices)
         if idx == None, returns value for key at each output index
         if idx is an integer, return value for that feature module index (ignoring output indices)
-        if idx is a list/tupple, return value for each module index (ignoring output indices)
+        if idx is a list/tuple, return value for each module index (ignoring output indices)
         """
         if idx is None:
             return [self.info[i][key] for i in self.out_indices]

@@ -53,7 +112,7 @@ class FeatureInfo:
         else:
             return self.info[idx][key]
 
-    def get_dicts(self, keys=None, idx=None):
+    def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
         """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
         """
         if idx is None:

@@ -66,17 +125,17 @@ class FeatureInfo:
         else:
             return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
 
-    def channels(self, idx=None):
+    def channels(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature channels accessor
         """
         return self.get('num_chs', idx)
 
-    def reduction(self, idx=None):
+    def reduction(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature reduction (output stride) accessor
         """
         return self.get('reduction', idx)
 
-    def module_name(self, idx=None):
+    def module_name(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature module name accessor
         """
         return self.get('module', idx)
@@ -146,7 +205,7 @@ def _module_list(module, flatten_sequential=False):
     return ml
 
 
-def _get_feature_info(net, out_indices):
+def _get_feature_info(net, out_indices: OutIndicesT):
     feature_info = getattr(net, 'feature_info')
     if isinstance(feature_info, FeatureInfo):
         return feature_info.from_other(out_indices)

@@ -182,7 +241,7 @@ class FeatureDictNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             out_map: Sequence[Union[int, str]] = None,
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,

@@ -257,7 +316,7 @@ class FeatureListNet(FeatureDictNet):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
             flatten_sequential: bool = False,

@@ -298,8 +357,8 @@ class FeatureHookNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            out_map: Sequence[Union[int, str]] = None,
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
+            out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
             no_rewrite: bool = False,
@@ -366,3 +425,55 @@ class FeatureHookNet(nn.ModuleDict):
             x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.return_dict else list(out.values())
+
+
+class FeatureGetterNet(nn.ModuleDict):
+    """ FeatureGetterNet
+
+    Wrap models with a feature getter method, like 'get_intermediate_layers'
+
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: OutIndicesT = 4,
+            out_map: Optional[Sequence[Union[int, str]]] = None,
+            return_dict: bool = False,
+            output_fmt: str = 'NCHW',
+            norm: bool = False,
+            prune: bool = True,
+    ):
+        """
+
+        Args:
+            model: Model to wrap.
+            out_indices: Indices of features to extract.
+            out_map: Remap feature names for dict output (WIP, not supported).
+            return_dict: Return features as dictionary instead of list (WIP, not supported).
+            norm: Apply final model norm to all output features (if possible).
+        """
+        super().__init__()
+        if prune and hasattr(model, 'prune_intermediate_layers'):
+            # replace out_indices after they've been normalized, -ve indices will be invalid after prune
+            out_indices = model.prune_intermediate_layers(
+                out_indices,
+                prune_norm=not norm,
+            )
+        out_indices = list(out_indices)
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.model = model
+        self.out_indices = out_indices
+        self.out_map = out_map
+        self.return_dict = return_dict
+        self.output_fmt = output_fmt
+        self.norm = norm
+
+    def forward(self, x):
+        features = self.model.forward_intermediates(
+            x,
+            indices=self.out_indices,
+            norm=self.norm,
+            output_fmt=self.output_fmt,
+            intermediates_only=True,
+        )
+        return features
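FeatureGetterNet can also be applied by hand to any model exposing forward_intermediates(); a minimal sketch, model name and shapes illustrative:

```python
import torch
import timm
from timm.models._features import FeatureGetterNet

backbone = timm.create_model('deit3_small_patch16_224', pretrained=False)
wrapped = FeatureGetterNet(backbone, out_indices=4, norm=True)  # last 4 blocks, final norm applied
feats = wrapped(torch.randn(1, 3, 224, 224))
print(len(feats), feats[0].shape)  # 4 NCHW maps, e.g. torch.Size([1, 384, 14, 14])
```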
timm/models/_features_fx.py

@@ -1,7 +1,7 @@
 """ PyTorch FX Based Feature Extraction Helpers
 Using https://pytorch.org/vision/stable/feature_extraction.html
 """
-from typing import Callable, List, Dict, Union, Type
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type
 
 import torch
 from torch import nn

@@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
 class FeatureGraphNet(nn.Module):
     """ A FX Graph based feature extractor that works with the model feature_info metadata
     """
-    def __init__(self, model, out_indices, out_map=None):
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...],
+            out_map: Optional[Dict] = None,
+    ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)

@@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module):
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
     """
-    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
+    def __init__(
+            self,
+            model: nn.Module,
+            return_nodes: Union[Dict[str, str], List[str]],
+            squeeze_out: bool = True,
+    ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)

timm/models/beit.py

@@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
 # --------------------------------------------------------'
 
 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel
 
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import checkpoint_filter_fn
 
 __all__ = ['Beit']

@@ -302,6 +302,7 @@ class Beit(nn.Module):
             embed_dim=embed_dim,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

@@ -333,6 +334,8 @@ class Beit(nn.Module):
                 window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
 
         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@@ -398,6 +401,89 @@ class Beit(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, shared_rel_pos_bias=rel_pos_bias)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not support in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
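A usage sketch for the two new methods (not part of the diff; model name illustrative, the pattern is the same for the other models touched by this PR):

```python
import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)

# full forward plus the selected intermediates
final, inter = model.forward_intermediates(x, indices=(2, 5, 8, 11))

# or trim the model down to just what those intermediates need
model.prune_intermediate_layers((2, 5, 8, 11), prune_norm=True, prune_head=True)
inter_only = model.forward_intermediates(x, indices=(2, 5, 8, 11), intermediates_only=True)
```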
@@ -547,14 +633,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia
 
 
 def _create_beit(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for BEiT models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
         pretrained_filter_fn=_beit_checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/byobnet.py

@@ -1251,7 +1251,7 @@ class ByobNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/cspnet.py

@@ -710,7 +710,7 @@ class CspNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/deit.py

@@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer):
 
 
 def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 3)
     model_cls = VisionTransformerDistilled if distilled else VisionTransformer
     model = build_model_with_cfg(
         model_cls,
         variant,
         pretrained,
         pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
     return model

timm/models/eva.py

@@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
 """
 # EVA models Copyright (c) 2022 BAAI-Vision
 # EVA02 models Copyright (c) 2023 BAAI-Vision
-
 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
     to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Eva']

@@ -424,6 +424,7 @@ class Eva(nn.Module):
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
 

@@ -469,6 +470,8 @@ class Eva(nn.Module):
                 init_values=init_values,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
 
         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)

@@ -559,6 +562,81 @@ class Eva(nn.Module):
             rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
         return x, rot_pos_embed
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, rope=rot_pos_embed)
+            if i in take_indices:
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not support in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x, rot_pos_embed = self._pos_embed(x)

@@ -663,13 +741,13 @@ def checkpoint_filter_fn(
 
 
 def _create_eva(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Eva models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Eva, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/focalnet.py

@@ -464,7 +464,7 @@ class FocalNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/gcvit.py

@@ -501,7 +501,7 @@ class GlobalContextVit(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)

timm/models/maxxvit.py

@@ -1258,7 +1258,7 @@ class MaxxVit(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/mvitv2.py

@@ -26,6 +26,7 @@ from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._registry import register_model, register_model_deprecations, generate_default_cfgs

@@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module):
 
         num_stages = len(cfg.embed_dim)
         feat_size = patch_dims
+        curr_stride = max(cfg.patch_stride)
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         self.stages = nn.ModuleList()
+        self.feature_info = []
         for i in range(num_stages):
             if cfg.expand_attn:
                 dim_out = cfg.embed_dim[i]

@@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module):
                 norm_layer=norm_layer,
                 drop_path=dpr[i],
             )
+            curr_stride *= max(cfg.stride_q[i])
+            self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
             embed_dim = dim_out
             feat_size = stage.feat_size
             self.stages.append(stage)

@@ -829,6 +834,63 @@ class MultiScaleVit(nn.Module):
             ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
         ]))
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # FIXME slice block/pos_block if < max
+
+        # forward pass
+        x, feat_size = self.patch_embed(x)
+        B = x.shape[0]
+        if self.cls_token is not None:
+            cls_tokens = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        for i, stage in enumerate(self.stages):
+            x, feat_size = stage(x, feat_size)
+            if i in take_indices:
+                if norm and i == (len(self.stages) - 1):
+                    x_inter = self.norm(x)  # applying final norm last intermediate
+                else:
+                    x_inter = x
+                if reshape:
+                    if self.cls_token is not None:
+                        # possible to allow return of class tokens, TBD
+                        x_inter = x_inter[:, 1:]
+                    x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
     def forward_features(self, x):
         x, feat_size = self.patch_embed(x)
         B, N, C = x.shape

@@ -862,6 +924,18 @@ class MultiScaleVit(nn.Module):
 
 def checkpoint_filter_fn(state_dict, model):
+    if 'stages.0.blocks.0.norm1.weight' in state_dict:
+        # native checkpoint, look for rel_pos interpolations
+        for k in state_dict.keys():
+            if 'rel_pos' in k:
+                rel_pos = state_dict[k]
+                dest_rel_pos_shape = model.state_dict()[k].shape
+                if rel_pos.shape[0] != dest_rel_pos_shape[0]:
+                    rel_pos_resized = torch.nn.functional.interpolate(
+                        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+                        size=dest_rel_pos_shape[0],
+                        mode="linear",
+                    )
+                    state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
+        return state_dict
+
     import re

@@ -892,16 +966,6 @@ def checkpoint_filter_fn(state_dict, model):
             k = k.replace('head.projection', 'head.fc')
         out_dict[k] = v
 
-    # for k, v in state_dict.items():
-    #     if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
-    #         # To resize pos embedding when using model at different size from pretrained weights
-    #         v = resize_pos_embed(
-    #             v,
-    #             model.pos_embed,
-    #             0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-    #             model.patch_embed.grid_size
-    #         )
-
     return out_dict

@@ -948,16 +1012,14 @@ model_cfgs = dict(
 
 
 def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 4)
     return build_model_with_cfg(
         MultiScaleVit,
         variant,
         pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(flatten_sequential=True),
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
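Since MViT-V2 is hierarchical, the getter path yields a true multi-scale pyramid, unlike the flat ViTs above. A sketch (not part of the diff; model name and values illustrative, strides per the curr_stride bookkeeping):

```python
import torch
import timm

model = timm.create_model('mvitv2_tiny', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.reduction())  # e.g. [4, 8, 16, 32]
print([f.shape[-1] for f in feats])    # e.g. spatial sizes 56, 28, 14, 7 at 224x224
```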
timm/models/regnet.py

@@ -525,7 +525,7 @@ class RegNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/resnetv2.py

@@ -469,7 +469,7 @@ class ResNetV2(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)

timm/models/twins.py

@@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Attention

@@ -324,6 +325,7 @@ class Twins(nn.Module):
             patch_size = 2
 
         self.blocks = nn.ModuleList()
+        self.feature_info = []
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
         cur = 0
         for k in range(len(depths)):

@@ -339,6 +341,7 @@ class Twins(nn.Module):
                 ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
             )
             self.blocks.append(_block)
+            self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))]
             cur += depths[k]
 
         self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])

@@ -401,6 +404,61 @@ class Twins(nn.Module):
                 if m.bias is not None:
                     m.bias.data.zero_()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # FIXME slice block/pos_block if < max
+
+        # forward pass
+        B, _, height, width = x.shape
+        for i, (embed, drop, blocks, pos_blk) in enumerate(zip(
+                self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
+        ):
+            x, size = embed(x)
+            x = drop(x)
+            for j, blk in enumerate(blocks):
+                x = blk(x, size)
+                if j == 0:
+                    x = pos_blk(x, size)  # PEG here
+
+            if i < len(self.depths) - 1:
+                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
+                if i in take_indices:
+                    intermediates.append(x)
+            else:
+                if i in take_indices:
+                    # only last feature can be normed
+                    x_feat = self.norm(x) if norm else x
+                    intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous())
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
     def forward_features(self, x):
         B = x.shape[0]
         for i, (embed, drop, blocks, pos_blk) in enumerate(
@@ -429,10 +487,12 @@ class Twins(nn.Module):
 
 
 def _create_twins(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
-    model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
+    out_indices = kwargs.pop('out_indices', 4)
+    model = build_model_with_cfg(
+        Twins, variant, pretrained,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/vision_transformer.py

@@ -27,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
 try:
     from typing import Literal
 except ImportError:

@@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm,
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
     get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -488,6 +489,7 @@ class VisionTransformer(nn.Module):
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None

@@ -520,6 +522,8 @@ class VisionTransformer(nn.Module):
                 mlp_layer=mlp_layer,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
 
         # Classifier Head
@@ -628,58 +632,107 @@ class VisionTransformer(nn.Module):
 
         return self.pos_drop(x)
 
-    def _intermediate_layers(
+    def forward_intermediates(
             self,
             x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ) -> List[torch.Tensor]:
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-        last_index_to_take = max(take_indices)
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
 
+        # forward pass
+        B, _, height, width = x.shape
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
             x = blk(x)
             if i in take_indices:
-                outputs.append(x)
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
 
-        return outputs
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not support in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            if self.attn_pool is not None:
+                self.attn_pool = None
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
 
     def get_intermediate_layers(
             self,
             x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
+            n: Union[int, List[int], Tuple[int]] = 1,
             reshape: bool = False,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
-        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
-        Inspired by DINO / DINOv2 interface
+    ) -> List[torch.Tensor]:
+        """ Intermediate layer accessor inspired by DINO / DINOv2 interface.
+        NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
         """
-        # take last n blocks if n is an int, if in is a sequence, select by matching indices
-        outputs = self._intermediate_layers(x, n)
-        if norm:
-            outputs = [self.norm(out) for out in outputs]
-        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
-        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
-
-        if reshape:
-            patch_size = self.patch_embed.patch_size
-            batch, _, height, width = x.size()
-            outputs = [
-                out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
-                .permute(0, 3, 1, 2)
-                .contiguous()
-                for out in outputs
-            ]
-
-        if return_prefix_tokens:
-            return tuple(zip(outputs, prefix_tokens))
-        return tuple(outputs)
+        return self.forward_intermediates(
+            x, n,
+            return_prefix_tokens=return_prefix_tokens,
+            norm=norm,
+            output_fmt='NCHW' if reshape else 'NLC',
+            intermediates_only=True,
+        )
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
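The compat shim above means existing DINO-style callers keep working; a sketch of roughly equivalent calls (not part of the diff):

```python
import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)

# old-style accessor, kept for backwards compat
old = model.get_intermediate_layers(x, n=2, reshape=True, norm=True)
# new API, same outputs via intermediates_only
new = model.forward_intermediates(x, indices=2, norm=True, output_fmt='NCHW', intermediates_only=True)
assert all(torch.equal(a, b) for a, b in zip(old, new))
```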
@@ -1770,9 +1823,7 @@ default_cfgs = generate_default_cfgs(default_cfgs)
 
 
 def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     if 'flexi' in variant:
         # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
         # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.

@@ -1791,6 +1842,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         pretrained,
         pretrained_filter_fn=_filter_fn,
         pretrained_strict=strict,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
@@ -2483,7 +2535,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo
 def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2493,7 +2545,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
 def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2503,7 +2555,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo
 def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2519,7 +2571,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
-        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))

timm/models/vision_transformer_hybrid.py

@@ -13,8 +13,9 @@ They were moved here to keep file sizes sane.
 
 Hacked together by / Copyright 2020, Ross Wightman
 """
+import math
 from functools import partial
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -41,6 +42,7 @@ class HybridEmbed(nn.Module):
             img_size=224,
             patch_size=1,
             feature_size=None,
+            feature_ratio=None,
             in_chans=3,
             embed_dim=768,
             bias=True,
@@ -68,15 +70,20 @@ class HybridEmbed(nn.Module):
                 feature_size = o.shape[-2:]
                 feature_dim = o.shape[1]
                 backbone.train(training)
+            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
         else:
             feature_size = to_2tuple(feature_size)
+            feature_ratio = to_2tuple(feature_ratio or 16)
             if hasattr(self.backbone, 'feature_info'):
                 feature_dim = self.backbone.feature_info.channels()[-1]
             else:
                 feature_dim = self.backbone.num_features
+        if not dynamic_img_pad:
             assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+        self.feature_size = feature_size
+        self.feature_ratio = feature_ratio
+        self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         if output_fmt is not None:
             self.flatten = False
@ -90,6 +97,25 @@ class HybridEmbed(nn.Module):

         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        total_reduction = (
+            self.feature_ratio[0] * self.patch_size[0],
+            self.feature_ratio[1] * self.patch_size[1]
+        )
+        if as_scalar:
+            return max(total_reduction)
+        else:
+            return total_reduction
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        """ Get feature grid size, taking account of dynamic padding and backbone network feat reduction
+        """
+        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
+        if self.dynamic_img_pad:
+            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
+        else:
+            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
+
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
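The new `feat_ratio()` / `dynamic_feat_size()` helpers let callers recover the hybrid embed's total stride and output grid without running a forward pass. A worked sketch with assumed numbers (a backbone with 32x reduction and `patch_size=1`, as in typical ResNet hybrids):

```python
# Hypothetical HybridEmbed configuration: backbone reduces 224 -> 7
feature_ratio, patch_size = (32, 32), (1, 1)

# feat_ratio(): total reduction = backbone reduction * patch size
total_reduction = tuple(r * p for r, p in zip(feature_ratio, patch_size))
assert max(total_reduction) == 32  # the as_scalar=True form

# dynamic_feat_size(): token grid for a given input, here without dynamic padding
img_size = (224, 224)
feat_size = tuple(s // r for s, r in zip(img_size, feature_ratio))
grid = tuple(f // p for f, p in zip(feat_size, patch_size))
assert grid == (7, 7)
```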
@ -7,7 +7,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import logging
 import math
 from functools import partial
-from typing import Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Type, Union

 try:
     from typing import Literal
@ -22,6 +22,7 @@ from torch.utils.checkpoint import checkpoint
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply
 from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import get_init_weights_vit
@ -297,6 +298,7 @@ class VisionTransformerRelPos(nn.Module):
             embed_dim=embed_dim,
         )
         feat_size = self.patch_embed.grid_size
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
         if rel_pos_type.startswith('mlp'):
@ -332,6 +334,8 @@ class VisionTransformerRelPos(nn.Module):
                 act_layer=act_layer,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
         self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()

         # Classifier Head
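With `self.feature_info` populated, every block of `VisionTransformerRelPos` is advertised as a feature level at constant width and reduction, which is what the `features_only` wrapper reads. A sketch of inspecting it (real timm entrypoint; exact values depend on the variant):

```python
import timm

model = timm.create_model('vit_relpos_base_patch16_224', features_only=True)
print(model.feature_info.channels())   # e.g. [768, 768, 768] for the default out_indices
print(model.feature_info.reduction())  # e.g. [16, 16, 16], constant for a ViT
```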
@ -384,6 +388,88 @@ class VisionTransformerRelPos(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        if self.cls_token is not None:
+            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+        shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, shared_rel_pos=shared_rel_pos)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.cls_token is not None:
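How the new method is meant to be called, as a sketch (real entrypoint; shapes are illustrative):

```python
import torch
import timm

model = timm.create_model('vit_relpos_base_patch16_224').eval()
x = torch.randn(1, 3, 224, 224)

# final normed features plus the last two block outputs, reshaped to NCHW
final, intermediates = model.forward_intermediates(x, indices=2)
print([t.shape for t in intermediates])  # e.g. [1, 768, 14, 14] each

# skip the trailing norm/head work entirely when only features are needed
feats = model.forward_intermediates(x, indices=[2, 5, 8, 11], intermediates_only=True)
```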
@ -412,10 +498,12 @@ class VisionTransformerRelPos(nn.Module):


 def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
-    model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)
+    out_indices = kwargs.pop('out_indices', 3)
+    model = build_model_with_cfg(
+        VisionTransformerRelPos, variant, pretrained,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model
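`features_only` therefore stops being a hard error for relpos ViTs: the builder now hands `out_indices` (default: last 3 blocks) to a `'getter'`-style feature wrapper built on `forward_intermediates()`. A sketch:

```python
import torch
import timm

# previously raised 'features_only not implemented for Vision Transformer models.'
model = timm.create_model('vit_relpos_base_patch16_224', features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print(len(feats))  # 3, one NCHW feature map per default out_indices entry
```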
@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in:
 """
 import logging
 from functools import partial
-from typing import Callable, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
-    Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
 from torch.jit import Final

+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
+    Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
-from ._features_fx import register_notrace_function

 # model_registry will add each entrypoint fn to this
 __all__ = ['VisionTransformerSAM']
@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module):
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             weight_init: str = '',
-            embed_layer: Callable = partial(
-                PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
+            embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
             norm_layer: Optional[Callable] = nn.LayerNorm,
             act_layer: Optional[Callable] = nn.GELU,
             block_fn: Callable = Block,
@ -408,6 +408,8 @@ class VisionTransformerSAM(nn.Module):
             bias=not pre_norm,  # disable bias if pre-norm is used
         )
         grid_size = self.patch_embed.grid_size
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
+
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
             self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
@ -469,6 +471,8 @@ class VisionTransformerSAM(nn.Module):
                 rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

         if neck_chans:
             self.neck = nn.Sequential(
@ -536,6 +540,78 @@ class VisionTransformerSAM(nn.Module):
     def reset_classifier(self, num_classes=0, global_pool=None):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass, collect intermediates
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            # dynamically resize abs pos embedding if needed
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
+        x = self.pos_drop(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # make output BCHW
+                if norm:
+                    # norm is intertwined with neck convs so apply both, changes the dim
+                    # FIXME only apply to final? Need experiments
+                    intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
+                else:
+                    intermediates.append(x.permute(0, 3, 1, 2))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            # neck is being treated as equivalent to final norm here
+            self.neck = nn.Identity()
+        if prune_head:
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.pos_embed is not None:
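The SAM variant only supports NCHW because its tokens are kept NHWC internally and permuted on collection; with `norm=True` each intermediate is additionally passed through the neck convs, which changes its channel dim (per the FIXME above). A usage sketch (real entrypoint; shapes illustrative):

```python
import torch
import timm

model = timm.create_model('samvit_base_patch16', num_classes=0).eval()
# last block only, without running the neck (norm defaults to False)
feats = model.forward_intermediates(
    torch.randn(1, 3, 1024, 1024), indices=1, intermediates_only=True)
print(feats[0].shape)  # e.g. torch.Size([1, 768, 64, 64]) before the neck
```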
@ -618,15 +694,13 @@ default_cfgs = generate_default_cfgs({


 def _create_vision_transformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError(
-            'features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     return build_model_with_cfg(
         VisionTransformerSAM,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
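As with the relpos builder, the SAM factory swaps the `features_only` RuntimeError for a `feature_cfg` using the `'getter'` wrapper. A sketch:

```python
import torch
import timm

# int out_indices selects the last n blocks
model = timm.create_model('samvit_base_patch16', features_only=True, out_indices=2)
feats = model(torch.randn(1, 3, 1024, 1024))
print([f.shape for f in feats])  # two NCHW feature maps from the last two blocks
```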
@ -259,7 +259,7 @@ class VovNet(nn.Module):
         return self.stages(x)

     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

     def forward(self, x):
         x = self.forward_features(x)
@ -286,7 +286,7 @@ class XceptionAligned(nn.Module):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

     def forward(self, x):
         x = self.forward_features(x)
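The VovNet / XceptionAligned tweak looks small but matters once heads can be pruned: when `self.head` has been replaced by `nn.Identity()`, the old unconditional `pre_logits=pre_logits` call would fail, since `nn.Identity.forward` takes no such kwarg. A sketch of the failure mode this guards against (an assumed reading of the change, not stated in the PR):

```python
import torch
import torch.nn as nn

head = nn.Identity()  # e.g. after the classifier head has been pruned away
x = torch.randn(1, 1024)
head(x)                      # fine: Identity passes features through
# head(x, pre_logits=False)  # TypeError: unexpected keyword argument 'pre_logits'
```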