diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index f4cc8c07..de077797 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
diff --git a/timm/layers/activations_jit.py b/timm/layers/activations_jit.py
deleted file mode 100644
index b4a51653..00000000
--- a/timm/layers/activations_jit.py
+++ /dev/null
@@ -1,90 +0,0 @@
-""" Activations
-
-A collection of jit-scripted activations fn and modules with a common interface so that they can
-easily be swapped. All have an `inplace` arg even if not used.
-
-All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
-currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
-versions if they contain in-place ops.
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-@torch.jit.script
-def swish_jit(x, inplace: bool = False):
-    """Swish - Described in: https://arxiv.org/abs/1710.05941
-    """
-    return x.mul(x.sigmoid())
-
-
-@torch.jit.script
-def mish_jit(x, _inplace: bool = False):
-    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    """
-    return x.mul(F.softplus(x).tanh())
-
-
-class SwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(SwishJit, self).__init__()
-
-    def forward(self, x):
-        return swish_jit(x)
-
-
-class MishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(MishJit, self).__init__()
-
-    def forward(self, x):
-        return mish_jit(x)
-
-
-@torch.jit.script
-def hard_sigmoid_jit(x, inplace: bool = False):
-    # return F.relu6(x + 3.) / 6.
-    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSigmoidJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSigmoidJit, self).__init__()
-
-    def forward(self, x):
-        return hard_sigmoid_jit(x)
-
-
-@torch.jit.script
-def hard_swish_jit(x, inplace: bool = False):
-    # return x * (F.relu6(x + 3.) / 6)
-    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSwishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_swish_jit(x)
-
-
-@torch.jit.script
-def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
-    Experimental, based on notes by Mish author Diganta Misra at
-      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
-    """
-    return 0.5 * x * (x + 2).clamp(min=0, max=2)
-
-
-class HardMishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardMishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_mish_jit(x)
diff --git a/timm/layers/activations_me.py b/timm/layers/activations_me.py
index 9a12bb7e..b0ddd5cb 100644
--- a/timm/layers/activations_me.py
+++ b/timm/layers/activations_me.py
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.
 
-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 
-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))
 
 
-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 
 
-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
@@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)
 
 
 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)
 
 
 class SwishMe(nn.Module):
@@ -54,38 +52,36 @@ class SwishMe(nn.Module):
         super(SwishMe, self).__init__()
 
     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))
 
 
-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
 
 
-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)
 
 
 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)
 
 
 class MishMe(nn.Module):
@@ -93,34 +89,32 @@ class MishMe(nn.Module):
         super(MishMe, self).__init__()
 
     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m
 
 
-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)
 
 
 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)
 
 
 class HardSigmoidMe(nn.Module):
@@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
         super(HardSigmoidMe, self).__init__()
 
     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m
 
 
-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)
 
     @staticmethod
     def symbolic(g, self):
@@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):
 
 
 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)
 
 
 class HardSwishMe(nn.Module):
@@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
         super(HardSwishMe, self).__init__()
 
     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)
 
 
-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m
 
 
-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
       https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)
 
 
 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)
 
 
 class HardMishMe(nn.Module):
@@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
         super(HardMishMe, self).__init__()
 
     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py
index 93bcbf0e..6bbbc14b 100644
--- a/timm/layers/create_act.py
+++ b/timm/layers/create_act.py
@@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Union, Callable, Type
 
 from .activations import *
-from .activations_jit import *
 from .activations_me import *
-from .config import is_exportable, is_scriptable, is_no_jit
+from .config import is_exportable, is_scriptable
 
 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
     hard_mish=hard_mish,
 )
 
-_ACT_FN_JIT = dict(
-    silu=F.silu if _has_silu else swish_jit,
-    swish=F.silu if _has_silu else swish_jit,
-    mish=F.mish if _has_mish else mish_jit,
-    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
-    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit,
-)
-
 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,
@@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
     hard_mish=hard_mish_me,
 )
 
-_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
 for a in _ACT_FNS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
     identity=nn.Identity,
 )
 
-_ACT_LAYER_JIT = dict(
-    silu=nn.SiLU if _has_silu else SwishJit,
-    swish=nn.SiLU if _has_silu else SwishJit,
-    mish=nn.Mish if _has_mish else MishJit,
-    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
-    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit,
-)
-
 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,
@@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
     hard_mish=HardMishMe,
 )
 
-_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
 for a in _ACT_LAYERS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_FN_JIT:
-            return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
 
 
@@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_LAYER_JIT:
-            return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
 
 
diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py
index 3f828c6d..cd7d5062 100644
--- a/timm/layers/ml_decoder.py
+++ b/timm/layers/ml_decoder.py
@@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
         return tgt
 
 
-# @torch.jit.script
 # class ExtrapClasses(object):
 #     def __init__(self, num_queries: int, group_size: int):
 #         self.num_queries = num_queries
@@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
 #         out = out.view((h.shape[0], self.group_size * self.num_queries))
 #         return out
 
-@torch.jit.script
-class GroupFC(object):
-    def __init__(self, embed_len_decoder: int):
-        self.embed_len_decoder = embed_len_decoder
-
-    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
-        for i in range(self.embed_len_decoder):
-            h_i = h[:, i, :]
-            w_i = duplicate_pooling[i, :, :]
-            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
-
-
 class MLDecoder(nn.Module):
     def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
         super(MLDecoder, self).__init__()
         embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
         if embed_len_decoder > num_classes:
             embed_len_decoder = num_classes
+        self.embed_len_decoder = embed_len_decoder
 
         # switching to 768 initial embeddings
         decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
@@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
         self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
         torch.nn.init.xavier_normal_(self.duplicate_pooling)
         torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
-        self.group_fc = GroupFC(embed_len_decoder)
 
     def forward(self, x):
         if len(x.shape) == 4:  # [bs,2048, 7,7]
@@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
         h = h.transpose(0, 1)
 
         out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
-        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        for i in range(self.embed_len_decoder):  # group FC
+            h_i = h[:, i, :]
+            w_i = self.duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
         h_out = out_extrap.flatten(1)[:, :self.num_classes]
         h_out += self.duplicate_pooling_bias
         logits = h_out
diff --git a/timm/layers/norm.py b/timm/layers/norm.py
index 504060c7..4b81dcef 100644
--- a/timm/layers/norm.py
+++ b/timm/layers/norm.py
@@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
         return tensor.is_contiguous(memory_format=torch.contiguous_format)
 
 
-@torch.jit.script
 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
     x = (x - u) * torch.rsqrt(s + eps)
diff --git a/timm/layers/space_to_depth.py b/timm/layers/space_to_depth.py
index 5867456c..45268154 100644
--- a/timm/layers/space_to_depth.py
+++ b/timm/layers/space_to_depth.py
@@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
         return x
 
 
-@torch.jit.script
-class SpaceToDepthJit:
-    def __call__(self, x: torch.Tensor):
-        # assuming hard-coded that block_size==4 for acceleration
-        N, C, H, W = x.size()
-        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
-        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
-        return x
-
-
-class SpaceToDepthModule(nn.Module):
-    def __init__(self, no_jit=False):
-        super().__init__()
-        if not no_jit:
-            self.op = SpaceToDepthJit()
-        else:
-            self.op = SpaceToDepth()
-
-    def forward(self, x):
-        return self.op(x)
-
-
 class DepthToSpace(nn.Module):
 
     def __init__(self, block_size):
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 50324597..705ebd25 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
 from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from timm.layers.selective_kernel import SelectiveKernel
 from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
-from timm.layers.space_to_depth import SpaceToDepthModule
 from timm.layers.split_attn import SplitAttn
 from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
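Note on the activations_me.py changes above: dropping the torch.jit.script decorators does not alter the math, since the renamed swish_fwd/swish_bwd pair (and the mish, hard-sigmoid, hard-swish and hard-mish equivalents) keep the same hand-derived forward and backward, and get_act_fn/get_act_layer now consult only the memory-efficient table before falling back to the default implementations. Below is a minimal sketch of how the renamed SwishAutoFn could be sanity-checked against PyTorch's native op; the tensor shape and the gradcheck call are illustrative assumptions, not part of this change.

import torch
from torch.nn import functional as F

from timm.layers.activations_me import SwishAutoFn  # renamed from SwishJitAutoFn in this diff

# double-precision input so finite-difference gradient checks are reliable
x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)

# forward parity: x * sigmoid(x) should match the native SiLU/Swish op
assert torch.allclose(SwishAutoFn.apply(x), F.silu(x))

# the hand-written swish_bwd gradient should agree with numerical gradients
assert torch.autograd.gradcheck(SwishAutoFn.apply, (x,))

The same kind of check applies to MishAutoFn, HardSigmoidAutoFn, HardSwishAutoFn and HardMishAutoFn, whose backward formulas are unchanged by this diff.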