Merge pull request #2168 from huggingface/more_vit_better_getter_redux

A few more features_intermediate() models, AttentionExtract helper, related minor cleanup.

commit b2c10fec05

@@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
@@ -1,90 +0,0 @@
-""" Activations
-
-A collection of jit-scripted activations fn and modules with a common interface so that they can
-easily be swapped. All have an `inplace` arg even if not used.
-
-All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
-currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
-versions if they contain in-place ops.
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-@torch.jit.script
-def swish_jit(x, inplace: bool = False):
-    """Swish - Described in: https://arxiv.org/abs/1710.05941
-    """
-    return x.mul(x.sigmoid())
-
-
-@torch.jit.script
-def mish_jit(x, _inplace: bool = False):
-    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    """
-    return x.mul(F.softplus(x).tanh())
-
-
-class SwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(SwishJit, self).__init__()
-
-    def forward(self, x):
-        return swish_jit(x)
-
-
-class MishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(MishJit, self).__init__()
-
-    def forward(self, x):
-        return mish_jit(x)
-
-
-@torch.jit.script
-def hard_sigmoid_jit(x, inplace: bool = False):
-    # return F.relu6(x + 3.) / 6.
-    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSigmoidJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSigmoidJit, self).__init__()
-
-    def forward(self, x):
-        return hard_sigmoid_jit(x)
-
-
-@torch.jit.script
-def hard_swish_jit(x, inplace: bool = False):
-    # return x * (F.relu6(x + 3.) / 6)
-    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSwishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_swish_jit(x)
-
-
-@torch.jit.script
-def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
-    Experimental, based on notes by Mish author Diganta Misra at
-      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
-    """
-    return 0.5 * x * (x + 2).clamp(min=0, max=2)
-
-
-class HardMishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardMishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_mish_jit(x)
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.

-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.

 Hacked together by / Copyright 2020 Ross Wightman
 """

@@ -14,19 +14,17 @@ from torch import nn as nn
 from torch.nn import functional as F


-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))


-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """

@@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)


 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)


 class SwishMe(nn.Module):

@@ -54,38 +52,36 @@ class SwishMe(nn.Module):
         super(SwishMe, self).__init__()

     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)


-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))


-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)


 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)


 class MishMe(nn.Module):

@@ -93,34 +89,32 @@ class MishMe(nn.Module):
         super(MishMe, self).__init__()

     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)


-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m


-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)


 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)


 class HardSigmoidMe(nn.Module):

@@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
         super(HardSigmoidMe, self).__init__()

     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)


-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m


-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)

     @staticmethod
     def symbolic(g, self):

@@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):


 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)


 class HardSwishMe(nn.Module):

@@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
         super(HardSwishMe, self).__init__()

     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)


-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)


-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m


-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
       https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)


 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)


 class HardMishMe(nn.Module):

@@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
         super(HardMishMe, self).__init__()

     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
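A quick sanity check one could run after the rename above (a sketch, not part of the PR; it assumes the `timm.layers.activations_me` module path): `SwishAutoFn` recomputes the sigmoid in `backward` rather than storing intermediates, so its output and gradient should still match the reference `F.silu`.

import torch
import torch.nn.functional as F
from timm.layers.activations_me import swish_me  # wraps SwishAutoFn.apply

x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
y_ref, y_me = F.silu(x), swish_me(x)
torch.testing.assert_close(y_me, y_ref)

g = torch.randn_like(x)
grad_ref, = torch.autograd.grad(y_ref, x, g)   # reference SiLU gradient
grad_me, = torch.autograd.grad(y_me, x, g)     # memory-efficient custom autograd gradient
torch.testing.assert_close(grad_me, grad_ref)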
@@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Union, Callable, Type

 from .activations import *
-from .activations_jit import *
 from .activations_me import *
-from .config import is_exportable, is_scriptable, is_no_jit
+from .config import is_exportable, is_scriptable

 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.

@@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
     hard_mish=hard_mish,
 )

-_ACT_FN_JIT = dict(
-    silu=F.silu if _has_silu else swish_jit,
-    swish=F.silu if _has_silu else swish_jit,
-    mish=F.mish if _has_mish else mish_jit,
-    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
-    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit,
-)
-
 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,

@@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
     hard_mish=hard_mish_me,
 )

-_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
 for a in _ACT_FNS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))

@@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
     identity=nn.Identity,
 )

-_ACT_LAYER_JIT = dict(
-    silu=nn.SiLU if _has_silu else SwishJit,
-    swish=nn.SiLU if _has_silu else SwishJit,
-    mish=nn.Mish if _has_mish else MishJit,
-    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
-    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit,
-)
-
 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,

@@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
     hard_mish=HardMishMe,
 )

-_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
 for a in _ACT_LAYERS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))

@@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_FN_JIT:
-            return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]


@@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_LAYER_JIT:
-            return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
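With the `_JIT` tiers removed, `get_act_fn()` / `get_act_layer()` resolve a name to the memory-efficient ('me') implementation first and then fall back to the default table (or the native PyTorch op when available). An illustrative lookup, not part of the diff:

from timm.layers import get_act_fn, get_act_layer

act_fn = get_act_fn('swish')            # F.silu on PyTorch >= 1.7, else swish_me
act_layer = get_act_layer('hard_mish')  # HardMishMe unless exporting/scripting the model
print(act_fn, act_layer)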
@@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
         return tgt


-# @torch.jit.script
 # class ExtrapClasses(object):
 #     def __init__(self, num_queries: int, group_size: int):
 #         self.num_queries = num_queries

@@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
 #         out = out.view((h.shape[0], self.group_size * self.num_queries))
 #         return out

-@torch.jit.script
-class GroupFC(object):
-    def __init__(self, embed_len_decoder: int):
-        self.embed_len_decoder = embed_len_decoder
-
-    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
-        for i in range(self.embed_len_decoder):
-            h_i = h[:, i, :]
-            w_i = duplicate_pooling[i, :, :]
-            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
-

 class MLDecoder(nn.Module):
     def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
         super(MLDecoder, self).__init__()
         embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
         if embed_len_decoder > num_classes:
             embed_len_decoder = num_classes
+        self.embed_len_decoder = embed_len_decoder

         # switching to 768 initial embeddings
         decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding

@@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
         self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
         torch.nn.init.xavier_normal_(self.duplicate_pooling)
         torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
-        self.group_fc = GroupFC(embed_len_decoder)

     def forward(self, x):
         if len(x.shape) == 4:  # [bs,2048, 7,7]

@@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
         h = h.transpose(0, 1)

         out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
-        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        for i in range(self.embed_len_decoder):  # group FC
+            h_i = h[:, i, :]
+            w_i = self.duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
         h_out = out_extrap.flatten(1)[:, :self.num_classes]
         h_out += self.duplicate_pooling_bias
         logits = h_out
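The removed `GroupFC` helper is inlined above as a plain loop over decoder groups. As a side note (my sketch, not from the PR), the loop is numerically equivalent to a single batched contraction, which makes the intent easier to see:

import torch

B, G, D, F_ = 2, 100, 768, 80   # batch, groups (embed_len_decoder), embed dim, duplicate_factor
h = torch.randn(B, G, D, dtype=torch.double)
duplicate_pooling = torch.randn(G, D, F_, dtype=torch.double)

out_loop = torch.zeros(B, G, F_, dtype=torch.double)
for i in range(G):  # the inlined group-FC loop, one matmul per group
    out_loop[:, i, :] = torch.matmul(h[:, i, :], duplicate_pooling[i, :, :])

out_batched = torch.einsum('bgd,gdf->bgf', h, duplicate_pooling)
torch.testing.assert_close(out_loop, out_batched)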
@@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
     return tensor.is_contiguous(memory_format=torch.contiguous_format)


-@torch.jit.script
 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
     x = (x - u) * torch.rsqrt(s + eps)
@@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
         return x


-@torch.jit.script
-class SpaceToDepthJit:
-    def __call__(self, x: torch.Tensor):
-        # assuming hard-coded that block_size==4 for acceleration
-        N, C, H, W = x.size()
-        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
-        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
-        return x
-
-
-class SpaceToDepthModule(nn.Module):
-    def __init__(self, no_jit=False):
-        super().__init__()
-        if not no_jit:
-            self.op = SpaceToDepthJit()
-        else:
-            self.op = SpaceToDepth()
-
-    def forward(self, x):
-        return self.op(x)
-
-
 class DepthToSpace(nn.Module):

     def __init__(self, block_size):
@@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
     set_pretrained_download_progress, set_pretrained_check_hash
 from ._factory import create_model, parse_model_name, safe_model_name
 from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
-from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
@@ -158,7 +158,7 @@ class FeatureHooks:

     def __init__(
             self,
-            hooks: Sequence[str],
+            hooks: Sequence[Union[str, Dict]],
             named_modules: dict,
             out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',

@@ -168,11 +168,13 @@ class FeatureHooks:
         self._handles = []
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
-            hook_name = h['module']
+            hook_name = h if isinstance(h, str) else h['module']
             m = modules[hook_name]
             hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
-            hook_type = h.get('hook_type', default_hook_type)
+            hook_type = default_hook_type
+            if isinstance(h, dict):
+                hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
                 handle = m.register_forward_pre_hook(hook_fn)
             elif hook_type == 'forward':
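`FeatureHooks` is internal API, but the change above means hook specs can now be plain module-name strings as well as dicts with a `'module'` (and optional `'hook_type'`) key. A minimal sketch under that assumption:

import torch
import torch.nn as nn
from timm.models._features import FeatureHooks

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
hooks = FeatureHooks(
    hooks=['0', '2'],                    # plain strings; {'module': '2', 'hook_type': 'forward'} also works
    named_modules=net.named_modules(),
)
_ = net(torch.randn(1, 3, 32, 32))
feats = hooks.get_output(torch.device('cpu'))  # dict keyed by module name
print({k: v.shape for k, v in feats.items()})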
@@ -9,7 +9,9 @@ from torch import nn
 from ._features import _get_feature_info, _get_return_layers

 try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
     from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
     has_fx_feature_extraction = True
 except ImportError:
     has_fx_feature_extraction = False

@@ -30,7 +32,7 @@ from timm.layers.norm_act import (

 __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
            'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
-           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
+           'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']


 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here

@@ -92,6 +94,13 @@ def get_notrace_functions():
     return list(_autowrap_functions)


+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+    )
+
+
 def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
     assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
     return _create_feature_extractor(
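The new `get_graph_node_names()` wrapper mirrors the torchvision function of the same name but passes timm's registered leaf modules and autowrap functions to the tracer, so the node names it reports line up with what `create_feature_extractor()` will accept. Illustrative usage (model choice is arbitrary):

import timm
from timm.models import create_feature_extractor, get_graph_node_names

model = timm.create_model('resnet18', pretrained=False)
train_nodes, eval_nodes = get_graph_node_names(model)
print(eval_nodes[:5], '...', eval_nodes[-3:])

# build an FX feature extractor from one of the reported node names
extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-2]])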
@@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -948,25 +949,37 @@ class Stem(nn.Sequential):
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name

         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'

-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride

+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+

 def create_byob_stem(
         in_chs: int,

@@ -1008,7 +1021,7 @@ def create_byob_stem(
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
     return stem, feature_info


@@ -1122,7 +1135,7 @@ def create_byob_stages(
         feat_size = reduce_feat_size(feat_size, stride)

         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)

     feature_info.append(prev_feat)
     return nn.Sequential(*stages), feature_info

@@ -1198,6 +1211,7 @@ class ByobNet(nn.Module):
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']

         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:

@@ -1207,7 +1221,8 @@ class ByobNet(nn.Module):
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]

         self.head = ClassifierHead(
             self.num_features,

@@ -1241,6 +1256,83 @@ class ByobNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
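With `forward_intermediates()` / `prune_intermediate_layers()` added, byobnet-family models (repvgg, gernet, the byoanet/regnetz variants, etc.) can return per-stage feature maps directly. A usage sketch; the model name and shapes are illustrative, not taken from the PR:

import torch
import timm

model = timm.create_model('repvgg_a2', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# only the last two feature stages
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])

# final features plus all intermediates
final, intermediates = model.forward_intermediates(x)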
@@ -411,7 +411,7 @@ class ConvNeXt(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -126,9 +126,9 @@ class ChannelAttention(nn.Module):
         q, k, v = qkv.unbind(0)

         k = k * self.scale
-        attention = k.transpose(-1, -2) @ v
-        attention = attention.softmax(dim=-1)
-        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
+        attn = k.transpose(-1, -2) @ v
+        attn = attn.softmax(dim=-1)
+        x = (attn @ q.transpose(-1, -2)).transpose(-1, -2)
         x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         return x
@@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientFormer

 Modifications and timm support by / Copyright 2022, Ross Wightman
 """
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -463,7 +463,7 @@ class EfficientFormer(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -162,7 +162,7 @@ class EfficientNet(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',

@@ -183,8 +183,6 @@ class EfficientNet(nn.Module):

         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
-        if stop_early:
-            assert intermediates_only, 'Must use intermediates_only for early stopping.'
         intermediates = []
         if extra_blocks:
             take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)

@@ -212,8 +210,9 @@ class EfficientNet(nn.Module):
         if intermediates_only:
             return intermediates

-        x = self.conv_head(x)
-        x = self.bn2(x)
+        if feat_idx == self.stage_ends[-1]:
+            x = self.conv_head(x)
+            x = self.bn2(x)

         return x, intermediates
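One consequence of the change above: `conv_head` / `bn2` are now only applied when the forward actually reaches the last stage, so truncated runs (early stop or pruned models) return raw stage features. Illustrative only; model choice and indices are arbitrary:

import torch
import timm

model = timm.create_model('efficientnet_b0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
# take the last three feature stages without the classifier head path
feats = model.forward_intermediates(x, indices=3, intermediates_only=True)
print([f.shape for f in feats])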
@@ -717,12 +717,6 @@ def checkpoint_filter_fn(
             # fixed embedding no need to load buffer from checkpoint
             continue

-        # FIXME here while import new weights, to remove
-        # if k == 'cls_token':
-        #     print('DEBUG: cls token -> reg')
-        #     k = 'reg_token'
-        #     #v = v + state_dict['pos_embed'][0, :]
-
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
             if v.shape[-1] != W or v.shape[-2] != H:

@@ -951,26 +945,22 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),

-    'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
-        #hf_hub_id='timm/',
-        #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+    'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
-        #hf_hub_id='timm/',
-        #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+    'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
-        #hf_hub_id='timm/',
-        #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+    'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
-    'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
-        #hf_hub_id='timm/',
-        #file='vit_base_gap1_rope-in1k-20230930-5.pth',
+    'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
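The rope/register ViT configs above are re-tagged from `.in1k` to `.sbb_in1k` and now point at the Hugging Face Hub instead of local weight files, so pretrained weights load by the new name (illustrative):

import timm

model = timm.create_model('vit_medium_patch16_rope_reg1_gap_256.sbb_in1k', pretrained=True)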
@@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
 from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from timm.layers.selective_kernel import SelectiveKernel
 from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
-from timm.layers.space_to_depth import SpaceToDepthModule
 from timm.layers.split_attn import SplitAttn
 from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
@@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Copyright 2020 Ross Wightman, Apache-2.0 License
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -638,7 +638,7 @@ class Levit(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -1255,7 +1255,7 @@ class MaxxVit(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -47,6 +48,7 @@ import torch.nn as nn
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import named_apply, checkpoint_seq
|
from ._manipulate import named_apply, checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
|
|
||||||
|
@ -211,6 +213,7 @@ class MlpMixer(nn.Module):
|
||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
norm_layer=norm_layer if stem_norm else None,
|
norm_layer=norm_layer if stem_norm else None,
|
||||||
)
|
)
|
||||||
|
reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
|
||||||
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
||||||
self.blocks = nn.Sequential(*[
|
self.blocks = nn.Sequential(*[
|
||||||
block_layer(
|
block_layer(
|
||||||
|
@ -224,6 +227,8 @@ class MlpMixer(nn.Module):
|
||||||
drop_path=drop_path_rate,
|
drop_path=drop_path_rate,
|
||||||
)
|
)
|
||||||
for _ in range(num_blocks)])
|
for _ in range(num_blocks)])
|
||||||
|
self.feature_info = [
|
||||||
|
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
|
||||||
self.norm = norm_layer(embed_dim)
|
self.norm = norm_layer(embed_dim)
|
||||||
self.head_drop = nn.Dropout(drop_rate)
|
self.head_drop = nn.Dropout(drop_rate)
|
||||||
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
@@ -257,6 +262,76 @@ class MlpMixer(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.stem.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
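For orientation, a minimal sketch of how the new MlpMixer.forward_intermediates() can be called once this branch is installed; 'mixer_b16_224' is simply an existing MLP-Mixer config used for illustration:

    import torch
    import timm

    model = timm.create_model('mixer_b16_224').eval()
    x = torch.randn(1, 3, 224, 224)
    # an int for indices means "last n blocks"; NCHW output reshapes tokens back to a spatial grid
    final, intermediates = model.forward_intermediates(x, indices=2, output_fmt='NCHW')
    for t in intermediates:
        print(t.shape)  # (1, 768, 14, 14) for this config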
@@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_mixer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         MlpMixer,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
     return model
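With _create_mixer now supplying a feature_cfg, features_only=True should no longer raise for the MLP-Mixer family; a hedged sketch using the same illustrative model name:

    import torch
    import timm

    fmodel = timm.create_model('mixer_b16_224', features_only=True, out_indices=(2, 5, 8))
    feats = fmodel(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])
    print(fmodel.feature_info.channels(), fmodel.feature_info.reduction())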
@@ -154,7 +154,7 @@ class MobileNetV3(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import math
 from dataclasses import dataclass, replace
 from functools import partial
-from typing import Optional, Union, Callable
+from typing import Callable, List, Optional, Union, Tuple

 import numpy as np
 import torch
@@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
 from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq, named_apply
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -515,6 +516,73 @@ class RegNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, pool_type=global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(5, indices)
+
+        # forward pass
+        feat_idx = 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        layer_names = ('s1', 's2', 's3', 's4')
+        if stop_early:
+            layer_names = layer_names[:max_index]
+        for n in layer_names:
+            feat_idx += 1
+            x = getattr(self, n)(x)  # won't work with torchscript, but keeps code reasonable, FML
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == 4:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(5, indices)
+        layer_names = ('s1', 's2', 's3', 's4')
+        layer_names = layer_names[max_index:]
+        for n in layer_names:
+            setattr(self, n, nn.Identity())
+        if max_index < 4:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.s1(x)
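A similar sketch for the RegNet version above; 'regnety_032' is just an existing config chosen for illustration, and index 0 maps to the stem while 1..4 map to stages s1..s4:

    import torch
    import timm

    model = timm.create_model('regnety_032').eval()
    x = torch.randn(1, 3, 224, 224)
    feats = model.forward_intermediates(x, indices=(3, 4), intermediates_only=True)
    print([f.shape for f in feats])

    # keep only what is needed for the first three feature levels
    model.prune_intermediate_layers(indices=(0, 1, 2), prune_head=True)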
@@ -557,7 +557,7 @@ class ResNet(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -576,8 +576,6 @@ class ResNet(nn.Module):

         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
-        if stop_early:
-            assert intermediates_only, 'Must use intermediates_only for early stopping.'
         intermediates = []
         take_indices, max_index = feature_take_indices(5, indices)

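With the assertion above removed, stop_early is no longer forced to be paired with intermediates_only; the common early-exit use is still the intermediates-only path, sketched here on a stock config:

    import torch
    import timm

    model = timm.create_model('resnet50').eval()
    x = torch.randn(1, 3, 224, 224)
    # run only as deep as needed for the requested indices
    feats = model.forward_intermediates(x, indices=(1, 2), stop_early=True, intermediates_only=True)
    print([f.shape for f in feats])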
@@ -611,7 +611,7 @@ class SwinTransformer(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -407,7 +407,7 @@ class Twins(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -489,7 +489,7 @@ class VisionTransformer(nn.Module):
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
-        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
+        reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
@@ -523,7 +523,7 @@ class VisionTransformer(nn.Module):
             )
             for i in range(depth)])
         self.feature_info = [
-            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
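The r -> reduction rename only changes the name feeding the per-block feature metadata; the recorded entries can be inspected directly (model name chosen for illustration):

    import timm

    model = timm.create_model('vit_small_patch16_224')
    print(model.feature_info[0])
    # expected along the lines of {'module': 'blocks.0', 'num_chs': 384, 'reduction': 16}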
@@ -1790,23 +1790,39 @@ default_cfgs = {
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),

-    'vit_wee_patch16_reg1_gap_256': _cfg(
-        #file='',
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
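The new .sbb_* entries resolve like any other pretrained tag; a sketch, assuming the weights referenced by the hf_hub_id='timm/' fields above are published on the HF hub:

    import timm

    model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb_in1k', pretrained=True)
    print(model.pretrained_cfg['input_size'], model.pretrained_cfg['crop_pct'])  # (3, 256, 256) 0.95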
@@ -2755,10 +2771,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
 def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
         class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


@@ -2769,7 +2796,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


@@ -2795,6 +2822,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model


+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
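The added entry points register like any other model, so they can be listed and instantiated without pretrained weights; a quick sketch:

    import timm

    print(timm.list_models('vit_pwee*'))     # e.g. ['vit_pwee_patch16_reg1_gap_256']
    print(timm.list_models('vit_mediumd*'))  # e.g. ['vit_mediumd_patch16_reg4_gap_256']

    model = timm.create_model('vit_mediumd_patch16_reg4_gap_256')
    print(round(sum(p.numel() for p in model.parameters()) / 1e6, 1), 'M params')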
@@ -543,7 +543,7 @@ class VisionTransformerSAM(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
@@ -598,7 +598,7 @@ class VisionTransformerSAM(nn.Module):

     def prune_intermediate_layers(
             self,
-            indices: Union[int, List[int], Tuple[int]] = None,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
@@ -1,4 +1,5 @@
 from .agc import adaptive_clip_grad
+from .attention_extract import AttentionExtract
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
@@ -0,0 +1,79 @@
+import fnmatch
+from collections import OrderedDict
+from typing import Union, Optional, List
+
+import torch
+
+
+class AttentionExtract(torch.nn.Module):
+    # defaults should cover a significant number of timm models with attention maps.
+    default_node_names = ['*attn.softmax']
+    default_module_names = ['*attn_drop']
+
+    def __init__(
+            self,
+            model: Union[torch.nn.Module],
+            names: Optional[List[str]] = None,
+            mode: str = 'eval',
+            method: str = 'fx',
+            hook_type: str = 'forward',
+    ):
+        """ Extract attention maps (or other activations) from a model by name.
+
+        Args:
+            model: Instantiated model to extract from.
+            names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
+            mode: 'train' or 'eval' model mode.
+            method: 'fx' or 'hook' extraction method.
+            hook_type: 'forward' or 'forward_pre' hooks used.
+        """
+        super().__init__()
+        assert mode in ('train', 'eval')
+        if mode == 'train':
+            model = model.train()
+        else:
+            model = model.eval()
+
+        assert method in ('fx', 'hook')
+        if method == 'fx':
+            # names are activation node names
+            from timm.models._features_fx import get_graph_node_names, GraphExtractNet
+
+            node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
+            matched = []
+            names = names or self.default_node_names
+            for n in names:
+                matched.extend(fnmatch.filter(node_names, n))
+            if not matched:
+                raise RuntimeError(f'No node names found matching {names}.')
+
+            self.model = GraphExtractNet(model, matched)
+            self.hooks = None
+        else:
+            # names are module names
+            assert hook_type in ('forward', 'forward_pre')
+            from timm.models._features import FeatureHooks
+
+            module_names = [n for n, m in model.named_modules()]
+            matched = []
+            names = names or self.default_module_names
+            for n in names:
+                matched.extend(fnmatch.filter(module_names, n))
+            if not matched:
+                raise RuntimeError(f'No module names found matching {names}.')
+
+            self.model = model
+            self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
+
+        self.names = matched
+        self.mode = mode
+        self.method = method
+
+    def forward(self, x):
+        if self.hooks is not None:
+            self.model(x)
+            output = self.hooks.get_output(device=x.device)
+        else:
+            output = self.model(x)
+            output = OrderedDict(zip(self.names, output))
+        return output
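A usage sketch for the new helper; the model choice is illustrative, and set_fused_attn(False) is used on the assumption that a fused SDPA attention path would otherwise leave no '*attn.softmax' node for the default fx match:

    import torch
    import timm
    from timm.layers import set_fused_attn
    from timm.utils import AttentionExtract

    set_fused_attn(False)  # keep an explicit softmax node so the default node pattern matches
    model = timm.create_model('vit_small_patch16_224')

    extractor = AttentionExtract(model, method='fx')
    with torch.no_grad():
        maps = extractor(torch.randn(1, 3, 224, 224))
    for name, attn in maps.items():
        print(name, tuple(attn.shape))  # e.g. blocks.0.attn.softmax (1, 6, 197, 197)

The hook-based path (method='hook') matches module names such as '*attn_drop' instead, which may suit models that cannot be fx-traced.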
@@ -1 +1 @@
-__version__ = '1.0.0.dev0'
+__version__ = '1.0.1.dev0'