Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
More forward_intermediates() & features_only work
* forward_intermediates() added to beit, deit, eva, mvitv2, twins, vit, vit_sam
* add features_only to forward_intermediates() to allow returning just the intermediate features
* fix #2060
* fix #1374
* fix #657
This commit is contained in:
parent 5fdc0b4e93
commit 679daef76a
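As a quick orientation before the diff, a minimal usage sketch of the API this commit adds (model name, input shape, and index choices are illustrative, assuming timm at this revision):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False)
    x = torch.randn(2, 3, 224, 224)

    # final normed features plus the last 3 block outputs as NCHW feature maps
    final, intermediates = model.forward_intermediates(x, n=3)

    # intermediate features only, selected by block index (negative indices count from the end)
    feats = model.forward_intermediates(x, n=[2, 5, 8, -1], features_only=True)
    print([f.shape for f in feats])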
timm/models/_features.py

@@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

 import torch
 import torch.nn as nn
@@ -20,7 +20,39 @@ from torch.utils.checkpoint import checkpoint
 from timm.layers import Format


-__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
+__all__ = [
+    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
+    'feature_take_indices'
+]
+
+
+def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]:
+    if isinstance(n, int):
+        assert n >= 0
+        take_indices = {x for x in range(num_blocks - n, num_blocks)}
+    else:
+        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
+    return take_indices, max(take_indices)
+
+
+def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
+    if isinstance(n, int):
+        assert n >= 0
+        take_indices = [num_blocks - n + i for i in range(n)]
+    elif isinstance(n, tuple):
+        # splitting this up is silly, but needed for torchscript type resolution of n
+        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
+    else:
+        take_indices = [num_blocks + idx if idx < 0 else idx for idx in n]
+    return take_indices, max(take_indices)
+
+
+def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]:
+    if torch.jit.is_scripting():
+        return _take_indices_jit(n, num_blocks)
+    else:
+        # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno
+        return _take_indices(n, num_blocks)


 def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
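To illustrate the index resolution above, a sketch of calling the helper directly (outside torchscript, so the set-returning path is taken):

    from timm.models._features import feature_take_indices

    # int n -> last n of num_blocks; sequence n -> explicit indices, negatives wrap
    print(feature_take_indices(3, 12))           # ({9, 10, 11}, 11)
    print(feature_take_indices([0, 5, -1], 12))  # ({0, 5, 11}, 11)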
@@ -397,29 +429,38 @@ class FeatureGetterNet(nn.ModuleDict):
             out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
+            norm: bool = False,
+            prune: bool = True,
     ):
+        """
+
+        Args:
+            model: Model to wrap.
+            out_indices: Indices of features to extract.
+            out_map: Remap feature names for dict output (WIP, not supported).
+            return_dict: Return features as dictionary instead of list (WIP, not supported).
+            norm: Apply final model norm to all output features (if possible).
+            prune: Prune layers not needed for the requested intermediates (if possible).
+        """
         super().__init__()
-        self.model = model
+        if prune and hasattr(model, 'prune_intermediate_layers'):
+            model.prune_intermediate_layers(
+                out_indices,
+                prune_norm=not norm,
+            )
         self.feature_info = _get_feature_info(model, out_indices)
+        self.model = model
         self.out_indices = out_indices
         self.out_map = out_map
         self.return_dict = return_dict
         self.output_fmt = output_fmt
+        self.norm = norm

-    def forward(self, *args, **kwargs):
-        """
-        def get_intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-            reshape: bool = False,
-            return_prefix_tokens: bool = False,
-            norm: bool = False,
-        """
-        out = self.model.get_intermediate_layers(
-            *args,
-            n=self.out_indices,
-            reshape=True,
-            **kwargs,
-        )
-        return out
+    def forward(self, x):
+        features = self.model.forward_intermediates(
+            x,
+            n=self.out_indices,
+            norm=self.norm,
+            output_fmt=self.output_fmt,
+            features_only=True,
+        )
+        return features
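With the 'getter' feature class above wired into the per-model create functions below, features_only extraction routes through FeatureGetterNet; a hedged sketch (model name illustrative):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', features_only=True, out_indices=[2, 5, 8, 11])
    feats = model(torch.randn(1, 3, 224, 224))  # list of NCHW tensors, one per out_index
    print(model.feature_info.channels(), [f.shape for f in feats])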
timm/models/beit.py

@@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
 # --------------------------------------------------------'

 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table


 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import generate_default_cfgs, register_model
-from .vision_transformer import checkpoint_filter_fn

 __all__ = ['Beit']

@@ -333,6 +333,8 @@ class Beit(nn.Module):
                 window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]

         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@@ -398,6 +400,93 @@ class Beit(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int, if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, shared_rel_pos_bias=rel_pos_bias)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(n, len(self.blocks))
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
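prune_intermediate_layers() lets the features-only wrapper (or a user) drop trailing blocks and the classifier that the requested intermediates never use; a sketch of direct use (model name illustrative):

    import timm

    model = timm.create_model('beit_base_patch16_224')
    model.prune_intermediate_layers(n=[0, 3, 5], prune_head=True)  # head/fc_norm become Identity
    assert len(model.blocks) == 6  # truncated to max_index + 1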
@@ -547,14 +636,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):


 def _create_beit(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for BEiT models.')
+    out_indices = kwargs.pop('out_indices', 3)

     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
         pretrained_filter_fn=_beit_checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/deit.py

@@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer):


 def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 3)
     model_cls = VisionTransformerDistilled if distilled else VisionTransformer
     model = build_model_with_cfg(
         model_cls,
         variant,
         pretrained,
         pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
     return model
timm/models/eva.py

@@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
 """
 # EVA models Copyright (c) 2022 BAAI-Vision
 # EVA02 models Copyright (c) 2023 BAAI-Vision

 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
     to_2tuple, use_fused_attn

 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['Eva']
@@ -469,6 +469,8 @@ class Eva(nn.Module):
                 init_values=init_values,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]

         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
@@ -559,6 +561,85 @@ class Eva(nn.Module):
             rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
         return x, rot_pos_embed

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int, if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, rope=rot_pos_embed)
+            if i in take_indices:
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(n, len(self.blocks))
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x, rot_pos_embed = self._pos_embed(x)
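For models with prefix (class/register) tokens, return_prefix_tokens pairs each spatial map with its prefix tokens; a sketch (model name and exact shapes are illustrative assumptions):

    import torch
    import timm

    model = timm.create_model('eva02_tiny_patch14_224')
    x = torch.randn(1, 3, 224, 224)
    pairs = model.forward_intermediates(x, n=2, return_prefix_tokens=True, features_only=True)
    for spatial, prefix in pairs:
        print(spatial.shape, prefix.shape)  # (B, C, H, W) and (B, num_prefix_tokens, C)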
@@ -663,13 +744,13 @@ def checkpoint_filter_fn(


 def _create_eva(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Eva models.')
+    out_indices = kwargs.pop('out_indices', 3)

     model = build_model_with_cfg(
         Eva, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/mvitv2.py

@@ -26,6 +26,7 @@ from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._registry import register_model, register_model_deprecations, generate_default_cfgs

@@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module):

         num_stages = len(cfg.embed_dim)
         feat_size = patch_dims
+        curr_stride = max(cfg.patch_stride)
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         self.stages = nn.ModuleList()
+        self.feature_info = []
         for i in range(num_stages):
             if cfg.expand_attn:
                 dim_out = cfg.embed_dim[i]
@@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module):
                 norm_layer=norm_layer,
                 drop_path=dpr[i],
             )
+            curr_stride *= max(cfg.stride_q[i])
+            self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
             embed_dim = dim_out
             feat_size = stage.feat_size
             self.stages.append(stage)
@@ -829,6 +834,51 @@ class MultiScaleVit(nn.Module):
             ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
         ]))

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_stages = len(self.stages)  # block list is two-tiered, first tier == stage
+        if n is None:
+            n = num_stages
+        take_indices, max_index = feature_take_indices(n, num_stages)
+
+        # FIXME slice block/pos_block if < max
+
+        # forward pass
+        x, feat_size = self.patch_embed(x)
+        B = x.shape[0]
+        if self.cls_token is not None:
+            cls_tokens = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        for i, stage in enumerate(self.stages):
+            x, feat_size = stage(x, feat_size)
+            if i in take_indices:
+                if norm and i == (len(self.stages) - 1):
+                    x_inter = self.norm(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                if reshape:
+                    x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
+                intermediates.append(x_inter)
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
     def forward_features(self, x):
         x, feat_size = self.patch_embed(x)
         B, N, C = x.shape
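Note that for hierarchical models like MViT-V2, n indexes stages rather than individual blocks, so each intermediate comes out at a different stride and channel count (per the feature_info entries added above); a sketch (model name illustrative):

    import torch
    import timm

    model = timm.create_model('mvitv2_tiny')
    feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), n=4, features_only=True)
    print([f.shape for f in feats])  # channels grow / resolution shrinks per stage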
@@ -862,6 +912,18 @@ class MultiScaleVit(nn.Module):

 def checkpoint_filter_fn(state_dict, model):
     if 'stages.0.blocks.0.norm1.weight' in state_dict:
+        # native checkpoint, look for rel_pos interpolations
+        for k in state_dict.keys():
+            if 'rel_pos' in k:
+                rel_pos = state_dict[k]
+                dest_rel_pos_shape = model.state_dict()[k].shape
+                if rel_pos.shape[0] != dest_rel_pos_shape[0]:
+                    rel_pos_resized = torch.nn.functional.interpolate(
+                        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+                        size=dest_rel_pos_shape[0],
+                        mode="linear",
+                    )
+                    state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
         return state_dict

     import re
@@ -892,16 +954,6 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('head.projection', 'head.fc')
         out_dict[k] = v

-    # for k, v in state_dict.items():
-    #     if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
-    #         # To resize pos embedding when using model at different size from pretrained weights
-    #         v = resize_pos_embed(
-    #             v,
-    #             model.pos_embed,
-    #             0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-    #             model.patch_embed.grid_size
-    #         )
-
     return out_dict

@@ -948,16 +1000,14 @@ model_cfgs = dict(


 def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 4)

     return build_model_with_cfg(
         MultiScaleVit,
         variant,
         pretrained,
         model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
         pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(flatten_sequential=True),
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )

timm/models/twins.py

@@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Attention
@@ -324,6 +325,7 @@ class Twins(nn.Module):
         patch_size = 2

         self.blocks = nn.ModuleList()
+        self.feature_info = []
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
         cur = 0
         for k in range(len(depths)):
@@ -339,6 +341,7 @@ class Twins(nn.Module):
                 ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
             )
             self.blocks.append(_block)
+            self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))]
             cur += depths[k]

         self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
@@ -401,6 +404,53 @@ class Twins(nn.Module):
             if m.bias is not None:
                 m.bias.data.zero_()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.'
+        intermediates = []
+        num_stages = len(self.blocks)  # block list is two-tiered, first tier == stage
+        if n is None:
+            n = num_stages
+        take_indices, max_index = feature_take_indices(n, num_stages)
+
+        # FIXME slice block/pos_block if < max
+
+        # forward pass
+        B, _, height, width = x.shape
+        for i, (embed, drop, blocks, pos_blk) in enumerate(zip(
+                self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
+        ):
+            x, size = embed(x)
+            x = drop(x)
+            for j, blk in enumerate(blocks):
+                x = blk(x, size)
+                if j == 0:
+                    x = pos_blk(x, size)  # PEG here
+
+            if i < len(self.depths) - 1:
+                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
+                if i in take_indices:
+                    intermediates.append(x)
+            else:
+                if i in take_indices:
+                    # only the last feature can be normed
+                    x_feat = self.norm(x) if norm else x
+                    intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous())
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
     def forward_features(self, x):
         B = x.shape[0]
         for i, (embed, drop, blocks, pos_blk) in enumerate(
@@ -429,10 +479,12 @@ class Twins(nn.Module):


 def _create_twins(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-    model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
+    out_indices = kwargs.pop('out_indices', 4)
+    model = build_model_with_cfg(
+        Twins, variant, pretrained,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model

timm/models/vision_transformer.py

@@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
     get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -473,7 +474,6 @@ class VisionTransformer(nn.Module):
         self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
-        self.feature_info = []

         embed_args = {}
         if dynamic_img_size:
@@ -631,58 +631,111 @@ class VisionTransformer(nn.Module):

         return self.pos_drop(x)

-    def _intermediate_layers(
+    def forward_intermediates(
             self,
             x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ) -> List[torch.Tensor]:
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-        last_index_to_take = max(take_indices)
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            n: Take last n blocks if n is an int, if n is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            features_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)

         # forward pass
+        B, _, height, width = x.shape
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
             x = blk(x)
             if i in take_indices:
-                outputs.append(x)
-
-        return outputs
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+        if reshape:
+            # reshape == True => BCHW output format
+            patch_size = self.patch_embed.patch_size
+            H = int(math.ceil(height / patch_size[0]))
+            W = int(math.ceil(width / patch_size[1]))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if features_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(n, len(self.blocks))
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            if self.attn_pool is not None:
+                self.attn_pool = None
+            self.fc_norm = nn.Identity()
+            self.head = nn.Identity()

     def get_intermediate_layers(
             self,
             x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
+            n: Union[int, List[int], Tuple[int]] = 1,
             reshape: bool = False,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
-        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
-        Inspired by DINO / DINOv2 interface
-        """
-        # take last n blocks if n is an int, if in is a sequence, select by matching indices
-        outputs = self._intermediate_layers(x, n)
-        if norm:
-            outputs = [self.norm(out) for out in outputs]
-        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
-        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
-
-        if reshape:
-            patch_size = self.patch_embed.patch_size
-            batch, _, height, width = x.size()
-            outputs = [
-                out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
-                .permute(0, 3, 1, 2)
-                .contiguous()
-                for out in outputs
-            ]
-
-        if return_prefix_tokens:
-            return tuple(zip(outputs, prefix_tokens))
-        return tuple(outputs)
+    ) -> List[torch.Tensor]:
+        """ Intermediate layer accessor inspired by DINO / DINOv2 interface.
+        NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
+        """
+        return self.forward_intermediates(
+            x, n,
+            return_prefix_tokens=return_prefix_tokens,
+            norm=norm,
+            output_fmt='NCHW' if reshape else 'NLC',
+            features_only=True,
+        )

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
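The DINO-style accessor is kept for backwards compatibility but now delegates to forward_intermediates(); the two calls below should be equivalent (a sketch, model name illustrative):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224')
    x = torch.randn(1, 3, 224, 224)
    old_style = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
    new_style = model.forward_intermediates(x, n=4, output_fmt='NCHW', norm=True, features_only=True)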
@@ -2485,7 +2538,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2495,7 +2548,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2505,7 +2558,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
    model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2521,7 +2574,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
-        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
timm/models/vision_transformer_sam.py

@@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in:
 """
 import logging
 from functools import partial
-from typing import Callable, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
+    Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
 from torch.jit import Final

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\
-    Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
+from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
-from ._features_fx import register_notrace_function

 # model_registry will add each entrypoint fn to this
 __all__ = ['VisionTransformerSAM']
@@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module):
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             weight_init: str = '',
-            embed_layer: Callable = partial(
-                PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
+            embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
             norm_layer: Optional[Callable] = nn.LayerNorm,
             act_layer: Optional[Callable] = nn.GELU,
             block_fn: Callable = Block,
@@ -469,6 +469,8 @@ class VisionTransformerSAM(nn.Module):
                 rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]

         if neck_chans:
             self.neck = nn.Sequential(
@@ -536,6 +538,52 @@ class VisionTransformerSAM(nn.Module):
     def reset_classifier(self, num_classes=0, global_pool=None):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            n: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            features_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
+        intermediates = []
+        num_blocks = len(self.blocks)
+        if n is None:
+            n = num_blocks
+        take_indices, max_index = feature_take_indices(n, num_blocks)
+
+        # forward pass, collect intermediates
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            # dynamically resize abs pos embedding if needed
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
+        x = self.pos_drop(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # make output BCHW
+                if norm:
+                    # norm is intertwined with neck convs so apply both, changes the dim
+                    # FIXME only apply to final? Need experiments
+                    intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
+                else:
+                    intermediates.append(x.permute(0, 3, 1, 2))
+
+        if features_only:
+            return intermediates
+
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        return x, intermediates
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.pos_embed is not None:
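For ViT-SAM the final norm is folded into the neck convolutions, so norm=True runs each intermediate through the whole neck and changes its channel dim; a sketch (model name, input size, and the resulting channel count are assumptions):

    import torch
    import timm

    model = timm.create_model('samvit_base_patch16')
    feats = model.forward_intermediates(torch.randn(1, 3, 1024, 1024), n=1, norm=True, features_only=True)
    print(feats[0].shape)  # neck output channels (e.g. 256), not the transformer width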
@@ -618,15 +666,13 @@ default_cfgs = generate_default_cfgs({


 def _create_vision_transformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError(
-            'features_only not implemented for Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 3)

     return build_model_with_cfg(
         VisionTransformerSAM,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
