Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Remove the JIT activations and take jit out of the ME (memory-efficient) activations. Remove other instances of torch.jit.script, which breaks torch.compile and is much less performant. Remove SpaceToDepthModule.
This commit is contained in: parent 07535f408a, commit 2bfa5e5d74
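Aside (not part of the diff): the rationale in the commit message is that a torch.jit.script ScriptFunction is opaque to the torch.compile tracer, so scripted activation wrappers tend to force graph breaks instead of being fused, while plain Python (or the native ops) compiles cleanly. A minimal, illustrative sketch, assuming a recent PyTorch with torch.compile available:

    import torch
    import torch.nn.functional as F

    def swish(x):
        # plain Python: torch.compile can trace and fuse this
        return x.mul(x.sigmoid())

    swish_scripted = torch.jit.script(swish)  # ScriptFunction: opaque to the compile tracer

    compiled = torch.compile(swish)
    x = torch.randn(8, 16)
    assert torch.allclose(compiled(x), F.silu(x), atol=1e-6)
    # Calling swish_scripted from inside a compiled model would typically hit a
    # graph break rather than being fused -- the "breaks torch.compile" part above.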
@@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
@@ -1,90 +0,0 @@
-""" Activations
-
-A collection of jit-scripted activations fn and modules with a common interface so that they can
-easily be swapped. All have an `inplace` arg even if not used.
-
-All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
-currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
-versions if they contain in-place ops.
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-@torch.jit.script
-def swish_jit(x, inplace: bool = False):
-    """Swish - Described in: https://arxiv.org/abs/1710.05941
-    """
-    return x.mul(x.sigmoid())
-
-
-@torch.jit.script
-def mish_jit(x, _inplace: bool = False):
-    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    """
-    return x.mul(F.softplus(x).tanh())
-
-
-class SwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(SwishJit, self).__init__()
-
-    def forward(self, x):
-        return swish_jit(x)
-
-
-class MishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(MishJit, self).__init__()
-
-    def forward(self, x):
-        return mish_jit(x)
-
-
-@torch.jit.script
-def hard_sigmoid_jit(x, inplace: bool = False):
-    # return F.relu6(x + 3.) / 6.
-    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSigmoidJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSigmoidJit, self).__init__()
-
-    def forward(self, x):
-        return hard_sigmoid_jit(x)
-
-
-@torch.jit.script
-def hard_swish_jit(x, inplace: bool = False):
-    # return x * (F.relu6(x + 3.) / 6)
-    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSwishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_swish_jit(x)
-
-
-@torch.jit.script
-def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
-    Experimental, based on notes by Mish author Diganta Misra at
-      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
-    """
-    return 0.5 * x * (x + 2).clamp(min=0, max=2)
-
-
-class HardMishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardMishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_mish_jit(x)
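Aside (not part of the diff): the scripted variants deleted above have had native, fused counterparts in PyTorch for some time, which is what the create_act.py registries later in this commit fall back to. A rough mapping, assuming PyTorch >= 1.9 (HardMishJit has no native op; only the memory-efficient HardMishMe remains):

    import torch.nn as nn

    # removed scripted module -> native replacement (same math, implemented in C++/CUDA)
    # SwishJit       -> nn.SiLU()
    # MishJit        -> nn.Mish()
    # HardSigmoidJit -> nn.Hardsigmoid()
    # HardSwishJit   -> nn.Hardswish()
    act = nn.SiLU(inplace=True)  # e.g. a drop-in for SwishJit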
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.
 
-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 
-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))
 
 
-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 
 
-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
@@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)
 
 
 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)
 
 
 class SwishMe(nn.Module):
@@ -54,38 +52,36 @@ class SwishMe(nn.Module):
         super(SwishMe, self).__init__()
 
     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))
 
 
-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
 
 
-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)
 
 
 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)
 
 
 class MishMe(nn.Module):
@@ -93,34 +89,32 @@ class MishMe(nn.Module):
         super(MishMe, self).__init__()
 
     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m
 
 
-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)
 
 
 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)
 
 
 class HardSigmoidMe(nn.Module):
@@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
         super(HardSigmoidMe, self).__init__()
 
     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m
 
 
-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)
 
     @staticmethod
     def symbolic(g, self):
@@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):
 
 
 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)
 
 
 class HardSwishMe(nn.Module):
@@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
         super(HardSwishMe, self).__init__()
 
     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)
 
 
-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m
 
 
-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
       https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)
 
 
 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)
 
 
 class HardMishMe(nn.Module):
@@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
         super(HardMishMe, self).__init__()
 
     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
 
 
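Aside (not part of the diff): the *Me variants keep their custom autograd.Function even with jit removed; the point is memory, not scripting. forward() saves only the input x and backward() recomputes the sigmoid, instead of keeping intermediate activations alive. An illustrative check of the hand-derived Swish gradient used above (names here are local to the example):

    import torch

    def swish_ref(x):
        return x * torch.sigmoid(x)

    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    g = torch.randn_like(x)
    ref_grad, = torch.autograd.grad(swish_ref(x), x, grad_outputs=g)
    with torch.no_grad():
        s = torch.sigmoid(x)
        manual_grad = g * (s * (1 + x * (1 - s)))  # same formula as swish_bwd
    assert torch.allclose(ref_grad, manual_grad)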
@@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Union, Callable, Type
 
 from .activations import *
-from .activations_jit import *
 from .activations_me import *
-from .config import is_exportable, is_scriptable, is_no_jit
+from .config import is_exportable, is_scriptable
 
 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
     hard_mish=hard_mish,
 )
 
-_ACT_FN_JIT = dict(
-    silu=F.silu if _has_silu else swish_jit,
-    swish=F.silu if _has_silu else swish_jit,
-    mish=F.mish if _has_mish else mish_jit,
-    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
-    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit,
-)
-
 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,
@@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
     hard_mish=hard_mish_me,
 )
 
-_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
 for a in _ACT_FNS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
     identity=nn.Identity,
 )
 
-_ACT_LAYER_JIT = dict(
-    silu=nn.SiLU if _has_silu else SwishJit,
-    swish=nn.SiLU if _has_silu else SwishJit,
-    mish=nn.Mish if _has_mish else MishJit,
-    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
-    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit,
-)
-
 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,
@@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
     hard_mish=HardMishMe,
 )
 
-_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
 for a in _ACT_LAYERS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_FN_JIT:
-            return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
 
 
@@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_LAYER_JIT:
-            return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
 
 
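Aside (not part of the diff): with the JIT tier gone, name resolution in get_act_fn / get_act_layer is two-level: the memory-efficient custom-autograd table first (unless exporting or scripting), then the default table, and both prefer the native PyTorch op when it exists. A usage sketch; the import path is assumed from timm's public layers API:

    from timm.layers import get_act_layer, get_act_fn  # assumed import path

    act_layer = get_act_layer('hard_swish')  # nn.Hardswish on modern PyTorch, else HardSwishMe
    act_fn = get_act_fn('mish')              # F.mish if available, else mish_me
    print(act_layer, act_fn)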
@@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
         return tgt
 
 
-# @torch.jit.script
 # class ExtrapClasses(object):
 #     def __init__(self, num_queries: int, group_size: int):
 #         self.num_queries = num_queries
@@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
 #         out = out.view((h.shape[0], self.group_size * self.num_queries))
 #         return out
 
-@torch.jit.script
-class GroupFC(object):
-    def __init__(self, embed_len_decoder: int):
-        self.embed_len_decoder = embed_len_decoder
-
-    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
-        for i in range(self.embed_len_decoder):
-            h_i = h[:, i, :]
-            w_i = duplicate_pooling[i, :, :]
-            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
-
-
 class MLDecoder(nn.Module):
     def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
         super(MLDecoder, self).__init__()
         embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
         if embed_len_decoder > num_classes:
             embed_len_decoder = num_classes
+        self.embed_len_decoder = embed_len_decoder
 
         # switching to 768 initial embeddings
         decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
@@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
         self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
         torch.nn.init.xavier_normal_(self.duplicate_pooling)
         torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
-        self.group_fc = GroupFC(embed_len_decoder)
 
     def forward(self, x):
         if len(x.shape) == 4:  # [bs,2048, 7,7]
@@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
         h = h.transpose(0, 1)
 
         out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
-        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        for i in range(self.embed_len_decoder):  # group FC
+            h_i = h[:, i, :]
+            w_i = self.duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
         h_out = out_extrap.flatten(1)[:, :self.num_classes]
         h_out += self.duplicate_pooling_bias
         logits = h_out
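Aside (not part of the diff): the inlined loop reproduces what the scripted GroupFC helper did, with no change in math: group i projects h[:, i, :] through its own weight slice duplicate_pooling[i]. For reference only, the same per-group projection can be written as a single batched einsum (the commit keeps the explicit loop):

    import torch

    B, G, D, K = 2, 100, 768, 20  # illustrative sizes: batch, groups, embed dim, per-group outputs
    h = torch.randn(B, G, D)
    duplicate_pooling = torch.randn(G, D, K)

    out_loop = torch.zeros(B, G, K)
    for i in range(G):  # the loop form now used in MLDecoder.forward
        out_loop[:, i, :] = h[:, i, :] @ duplicate_pooling[i]

    out_einsum = torch.einsum('bgd,gdk->bgk', h, duplicate_pooling)
    assert torch.allclose(out_loop, out_einsum, atol=1e-5)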
@@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
         return tensor.is_contiguous(memory_format=torch.contiguous_format)
 
 
-@torch.jit.script
 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
     x = (x - u) * torch.rsqrt(s + eps)
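Aside (not part of the diff): _layer_norm_cf normalizes a channels-first NCHW tensor over dim=1; dropping @torch.jit.script changes nothing about the math. An illustrative cross-check against F.layer_norm applied in channels-last order (the reference helper below is hypothetical):

    import torch
    import torch.nn.functional as F

    def layer_norm_cf_ref(x, weight, bias, eps=1e-6):
        # permute NCHW -> NHWC, use the built-in op over the channel dim, permute back
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, x.shape[-1:], weight, bias, eps)
        return x.permute(0, 3, 1, 2)

    x = torch.randn(2, 8, 4, 4)
    w, b = torch.ones(8), torch.zeros(8)
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    direct = (x - u) * torch.rsqrt(s + 1e-6) * w[:, None, None] + b[:, None, None]
    assert torch.allclose(direct, layer_norm_cf_ref(x, w, b), atol=1e-5)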
@@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
         return x
 
 
-@torch.jit.script
-class SpaceToDepthJit:
-    def __call__(self, x: torch.Tensor):
-        # assuming hard-coded that block_size==4 for acceleration
-        N, C, H, W = x.size()
-        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
-        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
-        return x
-
-
-class SpaceToDepthModule(nn.Module):
-    def __init__(self, no_jit=False):
-        super().__init__()
-        if not no_jit:
-            self.op = SpaceToDepthJit()
-        else:
-            self.op = SpaceToDepth()
-
-    def forward(self, x):
-        return self.op(x)
-
-
 class DepthToSpace(nn.Module):
 
     def __init__(self, block_size):
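Aside (not part of the diff): with SpaceToDepthJit and the SpaceToDepthModule wrapper gone, the remaining SpaceToDepth module covers the same rearrangement (it is hard-wired to the same 4x4 block layout the jit class assumed). A migration sketch for downstream code such as the import change in the next hunk:

    from timm.layers.space_to_depth import SpaceToDepth

    # before this commit: op = SpaceToDepthModule()
    # after:
    op = SpaceToDepth()  # (N, C, H, W) -> (N, C*16, H//4, W//4), matching the removed jit class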
@@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
 from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from timm.layers.selective_kernel import SelectiveKernel
 from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
-from timm.layers.space_to_depth import SpaceToDepthModule
 from timm.layers.split_attn import SplitAttn
 from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame