Merge pull request #2168 from huggingface/more_vit_better_getter_redux

A few more models supporting forward_intermediates(), an AttentionExtract helper, and related minor cleanup.
Ross Wightman 2024-05-11 08:13:50 -07:00 committed by GitHub
commit b2c10fec05
34 changed files with 479 additions and 288 deletions


@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
FEAT_INTER_FILTERS = [
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
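
Illustrative usage sketch (not part of the diff) of the forward_intermediates() API that the newly listed families gain; the model name 'regnety_032' and the index values are assumptions chosen for the example.

import torch
import timm

model = timm.create_model('regnety_032', pretrained=False).eval()  # assumed example model
x = torch.randn(1, 3, 224, 224)

# Take only the last two feature stages as NCHW intermediates.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])

# Or return the final features alongside selected intermediates.
final, feats = model.forward_intermediates(x, indices=[1, 3])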


@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
from .space_to_depth import SpaceToDepth, DepthToSpace
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame


@ -1,90 +0,0 @@
""" Activations
A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul(x.sigmoid())
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class SwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishJit, self).__init__()
def forward(self, x):
return swish_jit(x)
class MishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(MishJit, self).__init__()
def forward(self, x):
return mish_jit(x)
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
# return F.relu6(x + 3.) / 6.
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSigmoidJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidJit, self).__init__()
def forward(self, x):
return hard_sigmoid_jit(x)
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
# return x * (F.relu6(x + 3.) / 6)
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishJit, self).__init__()
def forward(self, x):
return hard_swish_jit(x)
@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
""" Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
return 0.5 * x * (x + 2).clamp(min=0, max=2)
class HardMishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMishJit, self).__init__()
def forward(self, x):
return hard_mish_jit(x)


@ -3,8 +3,8 @@
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.
These activations are not compatible with jit scripting or ONNX export of the model, please use
basic versions of the activations.
Hacked together by / Copyright 2020 Ross Wightman
"""
@ -14,19 +14,17 @@ from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit_fwd(x):
def swish_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
def swish_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
class SwishAutoFn(torch.autograd.Function):
""" optimised Swish w/ memory-efficient checkpoint
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
"""
@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
return swish_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
return swish_bwd(x, grad_output)
def swish_me(x, inplace=False):
return SwishJitAutoFn.apply(x)
return SwishAutoFn.apply(x)
class SwishMe(nn.Module):
@ -54,38 +52,36 @@ class SwishMe(nn.Module):
super(SwishMe, self).__init__()
def forward(self, x):
return SwishJitAutoFn.apply(x)
return SwishAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
def mish_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
def mish_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
class MishAutoFn(torch.autograd.Function):
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
A memory efficient, jit scripted variant of Mish
A memory efficient variant of Mish
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
return mish_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
return mish_bwd(x, grad_output)
def mish_me(x, inplace=False):
return MishJitAutoFn.apply(x)
return MishAutoFn.apply(x)
class MishMe(nn.Module):
@ -93,34 +89,32 @@ class MishMe(nn.Module):
super(MishMe, self).__init__()
def forward(self, x):
return MishJitAutoFn.apply(x)
return MishAutoFn.apply(x)
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
def hard_sigmoid_fwd(x, inplace: bool = False):
return (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
def hard_sigmoid_bwd(x, grad_output):
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
return grad_output * m
class HardSigmoidJitAutoFn(torch.autograd.Function):
class HardSigmoidAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_sigmoid_jit_fwd(x)
return hard_sigmoid_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_sigmoid_jit_bwd(x, grad_output)
return hard_sigmoid_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
return HardSigmoidJitAutoFn.apply(x)
return HardSigmoidAutoFn.apply(x)
class HardSigmoidMe(nn.Module):
@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
super(HardSigmoidMe, self).__init__()
def forward(self, x):
return HardSigmoidJitAutoFn.apply(x)
return HardSigmoidAutoFn.apply(x)
@torch.jit.script
def hard_swish_jit_fwd(x):
def hard_swish_fwd(x):
return x * (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
def hard_swish_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= 3.)
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
return grad_output * m
class HardSwishJitAutoFn(torch.autograd.Function):
"""A memory efficient, jit-scripted HardSwish activation"""
class HardSwishAutoFn(torch.autograd.Function):
"""A memory efficient HardSwish activation"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_swish_jit_fwd(x)
return hard_swish_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_swish_jit_bwd(x, grad_output)
return hard_swish_bwd(x, grad_output)
@staticmethod
def symbolic(g, self):
@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):
def hard_swish_me(x, inplace=False):
return HardSwishJitAutoFn.apply(x)
return HardSwishAutoFn.apply(x)
class HardSwishMe(nn.Module):
@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
super(HardSwishMe, self).__init__()
def forward(self, x):
return HardSwishJitAutoFn.apply(x)
return HardSwishAutoFn.apply(x)
@torch.jit.script
def hard_mish_jit_fwd(x):
def hard_mish_fwd(x):
return 0.5 * x * (x + 2).clamp(min=0, max=2)
@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
def hard_mish_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= -2.)
m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
return grad_output * m
class HardMishJitAutoFn(torch.autograd.Function):
""" A memory efficient, jit scripted variant of Hard Mish
class HardMishAutoFn(torch.autograd.Function):
""" A memory efficient variant of Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_mish_jit_fwd(x)
return hard_mish_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_mish_jit_bwd(x, grad_output)
return hard_mish_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):
return HardMishJitAutoFn.apply(x)
return HardMishAutoFn.apply(x)
class HardMishMe(nn.Module):
@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
super(HardMishMe, self).__init__()
def forward(self, x):
return HardMishJitAutoFn.apply(x)
return HardMishAutoFn.apply(x)


@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
from typing import Union, Callable, Type
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
from .config import is_exportable, is_scriptable
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
hard_mish=hard_mish,
)
_ACT_FN_JIT = dict(
silu=F.silu if _has_silu else swish_jit,
swish=F.silu if _has_silu else swish_jit,
mish=F.mish if _has_mish else mish_jit,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
hard_mish=hard_mish_jit,
)
_ACT_FN_ME = dict(
silu=F.silu if _has_silu else swish_me,
swish=F.silu if _has_silu else swish_me,
@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
hard_mish=hard_mish_me,
)
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
identity=nn.Identity,
)
_ACT_LAYER_JIT = dict(
silu=nn.SiLU if _has_silu else SwishJit,
swish=nn.SiLU if _has_silu else SwishJit,
mish=nn.Mish if _has_mish else MishJit,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
hard_mish=HardMishJit,
)
_ACT_LAYER_ME = dict(
silu=nn.SiLU if _has_silu else SwishMe,
swish=nn.SiLU if _has_silu else SwishMe,
@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
hard_mish=HardMishMe,
)
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
return None
if isinstance(name, Callable):
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
if not (is_exportable() or is_scriptable()):
# If not exporting or scripting the model, first look for a memory-efficient version with
# custom autograd, then fallback
if name in _ACT_FN_ME:
return _ACT_FN_ME[name]
if not (is_no_jit() or is_exportable()):
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
return _ACT_FN_DEFAULT[name]
@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
return name
if not name:
return None
if not (is_no_jit() or is_exportable() or is_scriptable()):
if not (is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
if not (is_no_jit() or is_exportable()):
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]
return _ACT_LAYER_DEFAULT[name]
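
With the JIT tier removed above, get_act_layer() / get_act_fn() now resolve to the memory-efficient custom-autograd implementations unless the model is being scripted or exported. A minimal sketch, assuming set_layer_config is importable from timm.layers as in current timm:

from timm.layers import get_act_layer, set_layer_config  # set_layer_config assumed available here

act = get_act_layer('hard_mish')         # resolves to the memory-efficient HardMishMe
print(act)

with set_layer_config(scriptable=True):  # scripting/export paths fall back to the basic HardMish
    act_scripted = get_act_layer('hard_mish')
print(act_scripted)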


@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
return tgt
# @torch.jit.script
# class ExtrapClasses(object):
# def __init__(self, num_queries: int, group_size: int):
# self.num_queries = num_queries
@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
# out = out.view((h.shape[0], self.group_size * self.num_queries))
# return out
@torch.jit.script
class GroupFC(object):
def __init__(self, embed_len_decoder: int):
self.embed_len_decoder = embed_len_decoder
def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
for i in range(self.embed_len_decoder):
h_i = h[:, i, :]
w_i = duplicate_pooling[i, :, :]
out_extrap[:, i, :] = torch.matmul(h_i, w_i)
class MLDecoder(nn.Module):
def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
super(MLDecoder, self).__init__()
embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
if embed_len_decoder > num_classes:
embed_len_decoder = num_classes
self.embed_len_decoder = embed_len_decoder
# switching to 768 initial embeddings
decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
torch.nn.init.xavier_normal_(self.duplicate_pooling)
torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
self.group_fc = GroupFC(embed_len_decoder)
def forward(self, x):
if len(x.shape) == 4: # [bs,2048, 7,7]
@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
h = h.transpose(0, 1)
out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
self.group_fc(h, self.duplicate_pooling, out_extrap)
for i in range(self.embed_len_decoder): # group FC
h_i = h[:, i, :]
w_i = self.duplicate_pooling[i, :, :]
out_extrap[:, i, :] = torch.matmul(h_i, w_i)
h_out = out_extrap.flatten(1)[:, :self.num_classes]
h_out += self.duplicate_pooling_bias
logits = h_out
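
The inlined per-group FC loop above computes the same result as a single batched contraction over the group dimension; a small equivalence sketch with made-up sizes (not from the diff):

import torch

B, G, D, D_out = 2, 10, 64, 32               # made-up batch, groups, embed dim, per-group outputs
h = torch.randn(B, G, D)
duplicate_pooling = torch.randn(G, D, D_out)

out_loop = torch.zeros(B, G, D_out)
for i in range(G):                           # same structure as the loop in MLDecoder.forward
    out_loop[:, i, :] = h[:, i, :] @ duplicate_pooling[i]

out_einsum = torch.einsum('bgd,gdo->bgo', h, duplicate_pooling)
assert torch.allclose(out_loop, out_einsum, atol=1e-5)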


@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + eps)


@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
return x
@torch.jit.script
class SpaceToDepthJit:
def __call__(self, x: torch.Tensor):
# assuming hard-coded that block_size==4 for acceleration
N, C, H, W = x.size()
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
return x
class SpaceToDepthModule(nn.Module):
def __init__(self, no_jit=False):
super().__init__()
if not no_jit:
self.op = SpaceToDepthJit()
else:
self.op = SpaceToDepth()
def forward(self, x):
return self.op(x)
class DepthToSpace(nn.Module):
def __init__(self, block_size):


@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint


@ -158,7 +158,7 @@ class FeatureHooks:
def __init__(
self,
hooks: Sequence[str],
hooks: Sequence[Union[str, Dict]],
named_modules: dict,
out_map: Sequence[Union[int, str]] = None,
default_hook_type: str = 'forward',
@ -168,11 +168,13 @@ class FeatureHooks:
self._handles = []
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
hook_name = h if isinstance(h, str) else h['module']
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h.get('hook_type', default_hook_type)
hook_type = default_hook_type
if isinstance(h, dict):
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
handle = m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':
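
FeatureHooks now accepts plain module-name strings as well as dicts with 'module' / 'hook_type' keys. A hedged example; the model and layer names below are illustrative, not from the diff:

import torch
import timm
from timm.models._features import FeatureHooks

model = timm.create_model('resnet18')        # illustrative model
hooks = FeatureHooks(
    ['layer1', 'layer2'],                    # plain strings, previously dicts like {'module': 'layer1'}
    model.named_modules(),
    default_hook_type='forward',
)
_ = model(torch.randn(1, 3, 224, 224))
feats = hooks.get_output(torch.device('cpu'))
print({k: v.shape for k, v in feats.items()})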


@ -9,7 +9,9 @@ from torch import nn
from ._features import _get_feature_info, _get_return_layers
try:
# NOTE we wrap torchvision fns to use timm leaf / no trace definitions
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
@ -30,7 +32,7 @@ from timm.layers.norm_act import (
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
@ -92,6 +94,13 @@ def get_notrace_functions():
return list(_autowrap_functions)
def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
return _get_graph_node_names(
model,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
)
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(
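
A short sketch of the new get_graph_node_names() wrapper, which applies timm's leaf-module / autowrap definitions before listing FX node names; the model and node choices are illustrative assumptions:

import timm
from timm.models._features_fx import get_graph_node_names, create_feature_extractor

model = timm.create_model('resnet50')                  # illustrative model
train_nodes, eval_nodes = get_graph_node_names(model)
print(eval_nodes[-5:])                                 # inspect candidate node names
# torchvision-style truncated node names can be used as return_nodes
extractor = create_feature_extractor(model, {'layer3': 'mid', 'layer4': 'late'})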


@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@ -948,25 +949,37 @@ class Stem(nn.Sequential):
stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
prev_chs = in_chs
curr_stride = 1
last_feat_idx = -1
for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
layer_fn = layers.conv_norm_act if na else create_conv2d
conv_name = f'conv{i + 1}'
if i > 0 and s > 1:
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
last_feat_idx = i - 1
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
prev_chs = ch
curr_stride *= s
prev_feat = conv_name
if pool and 'max' in pool.lower():
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
last_feat_idx = i
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
self.add_module('pool', nn.MaxPool2d(3, 2, 1))
curr_stride *= 2
prev_feat = 'pool'
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
assert curr_stride == stride
def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
intermediate: Optional[torch.Tensor] = None
for i, m in enumerate(self):
x = m(x)
if self.last_feat_idx is not None and i == self.last_feat_idx:
intermediate = x
return x, intermediate
def create_byob_stem(
in_chs: int,
@ -1008,7 +1021,7 @@ def create_byob_stem(
if isinstance(stem, Stem):
feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
else:
feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
return stem, feature_info
@ -1122,7 +1135,7 @@ def create_byob_stages(
feat_size = reduce_feat_size(feat_size, stride)
stages += [nn.Sequential(*blocks)]
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
feature_info.append(prev_feat)
return nn.Sequential(*stages), feature_info
@ -1198,6 +1211,7 @@ class ByobNet(nn.Module):
feat_size=feat_size,
)
self.feature_info.extend(stage_feat[:-1])
reduction = stage_feat[-1]['reduction']
prev_chs = stage_feat[-1]['num_chs']
if cfg.num_features:
@ -1207,7 +1221,8 @@ class ByobNet(nn.Module):
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.feature_info += [
dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
self.stage_ends = [f['stage'] for f in self.feature_info]
self.head = ClassifierHead(
self.num_features,
@ -1241,6 +1256,83 @@ class ByobNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
exclude_final_conv: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
exclude_final_conv: Exclude final_conv from last intermediate
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
# forward pass
feat_idx = 0 # stem is index 0
if hasattr(self.stem, 'forward_intermediates'):
# returns last intermediate features in stem (before final stride in stride > 2 stems)
x, x_inter = self.stem.forward_intermediates(x)
else:
x, x_inter = self.stem(x), None
if feat_idx in take_indices:
intermediates.append(x if x_inter is None else x_inter)
last_idx = self.stage_ends[-1]
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
x = stage(x)
if not exclude_final_conv and feat_idx == last_idx:
# default feature_info for this model uses final_conv as the last feature output (if present)
x = self.final_conv(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if exclude_final_conv and feat_idx == last_idx:
x = self.final_conv(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if max_index < self.stage_ends[-1]:
self.final_conv = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
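
Usage sketch for the new ByobNet forward_intermediates() / prune_intermediate_layers() pair; 'repvgg_a2' and the index values are assumed example choices, not from the diff:

import torch
import timm

model = timm.create_model('repvgg_a2').eval()          # assumed ByobNet-family model
x = torch.randn(1, 3, 224, 224)

final, feats = model.forward_intermediates(x, indices=[0, 1, 2])
kept = model.prune_intermediate_layers(indices=[0, 1, 2], prune_head=True)  # drop unused deeper stages + head
feats_only = model.forward_intermediates(x, indices=[0, 1, 2], intermediates_only=True)
print(kept, [f.shape for f in feats_only])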


@ -411,7 +411,7 @@ class ConvNeXt(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -126,9 +126,9 @@ class ChannelAttention(nn.Module):
q, k, v = qkv.unbind(0)
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = attention.softmax(dim=-1)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
attn = k.transpose(-1, -2) @ v
attn = attn.softmax(dim=-1)
x = (attn @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x


@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientF
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -463,7 +463,7 @@ class EfficientFormer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -162,7 +162,7 @@ class EfficientNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -183,8 +183,6 @@ class EfficientNet(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
if stop_early:
assert intermediates_only, 'Must use intermediates_only for early stopping.'
intermediates = []
if extra_blocks:
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
@ -212,8 +210,9 @@ class EfficientNet(nn.Module):
if intermediates_only:
return intermediates
x = self.conv_head(x)
x = self.bn2(x)
if feat_idx == self.stage_ends[-1]:
x = self.conv_head(x)
x = self.bn2(x)
return x, intermediates


@ -717,12 +717,6 @@ def checkpoint_filter_fn(
# fixed embedding no need to load buffer from checkpoint
continue
# FIXME here while import new weights, to remove
# if k == 'cls_token':
# print('DEBUG: cls token -> reg')
# k = 'reg_token'
# #v = v + state_dict['pos_embed'][0, :]
if 'patch_embed.proj.weight' in k:
_, _, H, W = model.patch_embed.proj.weight.shape
if v.shape[-1] != W or v.shape[-2] != H:
@ -951,26 +945,22 @@ default_cfgs = generate_default_cfgs({
num_classes=0,
),
'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
#file='vit_medium_gap1_rope-in1k-20230920-5.pth',
'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
#file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
#file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
),
'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
#file='vit_base_gap1_rope-in1k-20230930-5.pth',
'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),


@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from timm.layers.selective_kernel import SelectiveKernel
from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
from timm.layers.space_to_depth import SpaceToDepthModule
from timm.layers.split_attn import SplitAttn
from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame


@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# Copyright 2020 Ross Wightman, Apache-2.0 License
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -638,7 +638,7 @@ class Levit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -1255,7 +1255,7 @@ class MaxxVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from functools import partial
from typing import List, Optional, Union, Tuple
import torch
import torch.nn as nn
@ -47,6 +48,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -211,6 +213,7 @@ class MlpMixer(nn.Module):
embed_dim=embed_dim,
norm_layer=norm_layer if stem_norm else None,
)
reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
# FIXME drop_path (stochastic depth scaling rule or all the same?)
self.blocks = nn.Sequential(*[
block_layer(
@ -224,6 +227,8 @@ class MlpMixer(nn.Module):
drop_path=drop_path_rate,
)
for _ in range(num_blocks)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
self.norm = norm_layer(embed_dim)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@ -257,6 +262,76 @@ class MlpMixer(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# forward pass
B, _, height, width = x.shape
x = self.stem(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# process intermediates
if reshape:
# reshape to BCHW output format
H, W = self.stem.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):
def _create_mixer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
MlpMixer,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
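
With the 'getter' feature_cfg wired up above, MLP-Mixer models can now be created in features_only mode; the model name and out_indices values below are illustrative:

import torch
import timm

fx_model = timm.create_model('mixer_b16_224', features_only=True, out_indices=(2, 5, 8, 11))
feats = fx_model(torch.randn(1, 3, 224, 224))
print(fx_model.feature_info.channels(), [f.shape for f in feats])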


@ -154,7 +154,7 @@ class MobileNetV3(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import math
from dataclasses import dataclass, replace
from functools import partial
from typing import Optional, Union, Callable
from typing import Callable, List, Optional, Union, Tuple
import numpy as np
import torch
@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -515,6 +516,73 @@ class RegNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
layer_names = ('s1', 's2', 's3', 's4')
if stop_early:
layer_names = layer_names[:max_index]
for n in layer_names:
feat_idx += 1
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == 4:
x = self.final_conv(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
layer_names = ('s1', 's2', 's3', 's4')
layer_names = layer_names[max_index:]
for n in layer_names:
setattr(self, n, nn.Identity())
if max_index < 4:
self.final_conv = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.s1(x)


@ -557,7 +557,7 @@ class ResNet(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -576,8 +576,6 @@ class ResNet(nn.Module):
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
if stop_early:
assert intermediates_only, 'Must use intermediates_only for early stopping.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)


@ -611,7 +611,7 @@ class SwinTransformer(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -407,7 +407,7 @@ class Twins(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',


@ -489,7 +489,7 @@ class VisionTransformer(nn.Module):
**embed_args,
)
num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
@ -523,7 +523,7 @@ class VisionTransformer(nn.Module):
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
@ -1790,23 +1790,39 @@ default_cfgs = {
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_wee_patch16_reg1_gap_256': _cfg(
#file='',
'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
input_size=(3, 256, 256), crop_pct=0.95),
'vit_little_patch16_reg4_gap_256': _cfg(
#file='',
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg1_gap_256': _cfg(
#file='vit_medium_gap1-in1k-20231118-8.pth',
'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256': _cfg(
#file='vit_medium_gap4-in1k-20231115-8.pth',
'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg1_gap_256': _cfg(
#file='vit_betwixt_gap1-in1k-20231121-8.pth',
'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256': _cfg(
#file='vit_betwixt_gap4-in1k-20231106-8.pth',
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
@ -2755,10 +2771,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
)
model = _create_vision_transformer(
'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
)
model = _create_vision_transformer(
'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2769,7 +2796,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
)
model = _create_vision_transformer(
'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2795,6 +2822,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
return model
@register_model
def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
)
model = _create_vision_transformer(
'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(


@ -543,7 +543,7 @@ class VisionTransformerSAM(nn.Module):
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
@ -598,7 +598,7 @@ class VisionTransformerSAM(nn.Module):
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
prune_norm: bool = False,
prune_head: bool = True,
):


@ -1,4 +1,5 @@
from .agc import adaptive_clip_grad
from .attention_extract import AttentionExtract
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler


@ -0,0 +1,79 @@
import fnmatch
from collections import OrderedDict
from typing import Union, Optional, List
import torch
class AttentionExtract(torch.nn.Module):
# defaults should cover a significant number of timm models with attention maps.
default_node_names = ['*attn.softmax']
default_module_names = ['*attn_drop']
def __init__(
self,
model: Union[torch.nn.Module],
names: Optional[List[str]] = None,
mode: str = 'eval',
method: str = 'fx',
hook_type: str = 'forward',
):
""" Extract attention maps (or other activations) from a model by name.
Args:
model: Instantiated model to extract from.
names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
mode: 'train' or 'eval' model mode.
method: 'fx' or 'hook' extraction method.
hook_type: 'forward' or 'forward_pre' hooks used.
"""
super().__init__()
assert mode in ('train', 'eval')
if mode == 'train':
model = model.train()
else:
model = model.eval()
assert method in ('fx', 'hook')
if method == 'fx':
# names are activation node names
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
matched = []
names = names or self.default_node_names
for n in names:
matched.extend(fnmatch.filter(node_names, n))
if not matched:
raise RuntimeError(f'No node names found matching {names}.')
self.model = GraphExtractNet(model, matched)
self.hooks = None
else:
# names are module names
assert hook_type in ('forward', 'forward_pre')
from timm.models._features import FeatureHooks
module_names = [n for n, m in model.named_modules()]
matched = []
names = names or self.default_module_names
for n in names:
matched.extend(fnmatch.filter(module_names, n))
if not matched:
raise RuntimeError(f'No module names found matching {names}.')
self.model = model
self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
self.names = matched
self.mode = mode
self.method = method
def forward(self, x):
if self.hooks is not None:
self.model(x)
output = self.hooks.get_output(device=x.device)
else:
output = self.model(x)
output = OrderedDict(zip(self.names, output))
return output
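
Usage sketch for the new AttentionExtract helper; the model name is an assumption, and extracting softmax attention via FX may require fused attention to be disabled so the '*attn.softmax' nodes actually exist:

import torch
import timm
from timm.utils import AttentionExtract

timm.layers.set_fused_attn(False)                 # assumed needed so softmax nodes are present in the graph
model = timm.create_model('vit_base_patch16_224', pretrained=False)
extractor = AttentionExtract(model, method='fx')  # or method='hook' to hook '*attn_drop' modules
attn_maps = extractor(torch.randn(1, 3, 224, 224))
for name, attn in attn_maps.items():
    print(name, tuple(attn.shape))                # e.g. blocks.0.attn.softmax -> (B, heads, N, N)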


@ -1 +1 @@
__version__ = '1.0.0.dev0'
__version__ = '1.0.1.dev0'