Remove JIT activations and take jit out of the ME activations. Remove other instances of torch.jit.script; scripting breaks torch.compile and is much less performant. Remove SpaceToDepthModule

Ross Wightman 2024-05-06 16:32:49 -07:00
parent 07535f408a
commit 2bfa5e5d74
8 changed files with 52 additions and 212 deletions

View File

@@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame

View File

@@ -1,90 +0,0 @@
-""" Activations
-
-A collection of jit-scripted activations fn and modules with a common interface so that they can
-easily be swapped. All have an `inplace` arg even if not used.
-
-All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
-currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
-versions if they contain in-place ops.
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-@torch.jit.script
-def swish_jit(x, inplace: bool = False):
-    """Swish - Described in: https://arxiv.org/abs/1710.05941
-    """
-    return x.mul(x.sigmoid())
-
-
-@torch.jit.script
-def mish_jit(x, _inplace: bool = False):
-    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    """
-    return x.mul(F.softplus(x).tanh())
-
-
-class SwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(SwishJit, self).__init__()
-
-    def forward(self, x):
-        return swish_jit(x)
-
-
-class MishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(MishJit, self).__init__()
-
-    def forward(self, x):
-        return mish_jit(x)
-
-
-@torch.jit.script
-def hard_sigmoid_jit(x, inplace: bool = False):
-    # return F.relu6(x + 3.) / 6.
-    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSigmoidJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSigmoidJit, self).__init__()
-
-    def forward(self, x):
-        return hard_sigmoid_jit(x)
-
-
-@torch.jit.script
-def hard_swish_jit(x, inplace: bool = False):
-    # return x * (F.relu6(x + 3.) / 6)
-    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSwishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_swish_jit(x)
-
-
-@torch.jit.script
-def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
-    Experimental, based on notes by Mish author Diganta Misra at
-    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
-    """
-    return 0.5 * x * (x + 2).clamp(min=0, max=2)
-
-
-class HardMishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardMishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_mish_jit(x)
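
The commit message's rationale is that these scripted kernels are redundant: torch.compile fuses plain eager element-wise chains on its own, and @torch.jit.script gets in its way. A minimal illustrative sketch of that replacement path (not part of the commit, assumes PyTorch 2.x):

import torch

def swish(x: torch.Tensor) -> torch.Tensor:
    # same math as the removed swish_jit, written as plain eager code
    return x.mul(x.sigmoid())

# torch.compile handles element-wise fusion itself; no @torch.jit.script needed
compiled_swish = torch.compile(swish)

x = torch.randn(8, 64)
assert torch.allclose(compiled_swish(x), swish(x))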

View File

@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.
 
-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 
-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))
 
 
-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 
 
-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
@@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)
 
 
 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)
 
 
 class SwishMe(nn.Module):
@@ -54,38 +52,36 @@ class SwishMe(nn.Module):
         super(SwishMe, self).__init__()
 
     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))
 
 
-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
 
 
-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)
 
 
 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)
 
 
 class MishMe(nn.Module):
@@ -93,34 +89,32 @@ class MishMe(nn.Module):
         super(MishMe, self).__init__()
 
     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m
 
 
-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)
 
 
 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)
 
 
 class HardSigmoidMe(nn.Module):
@@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
         super(HardSigmoidMe, self).__init__()
 
     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)
 
 
-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m
 
 
-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)
 
     @staticmethod
     def symbolic(g, self):
@@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):
 
 
 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)
 
 
 class HardSwishMe(nn.Module):
@@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
         super(HardSwishMe, self).__init__()
 
     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)
 
 
-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)
 
 
-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m
 
 
-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)
 
 
 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)
 
 
 class HardMishMe(nn.Module):
@@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
         super(HardMishMe, self).__init__()
 
     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
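
With the decorators and Jit names gone, the memory-efficient variants are ordinary torch.autograd.Function subclasses. A quick sanity check that the hand-written backward still agrees with autograd is gradcheck against the native SiLU; a sketch assuming the renamed SwishAutoFn in timm.layers.activations_me (not part of the commit):

import torch
import torch.nn.functional as F
from timm.layers.activations_me import SwishAutoFn

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)

# forward matches the native SiLU/Swish
assert torch.allclose(SwishAutoFn.apply(x), F.silu(x))

# hand-written backward matches numerical gradients
assert torch.autograd.gradcheck(SwishAutoFn.apply, (x,))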

View File

@@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Union, Callable, Type
 
 from .activations import *
-from .activations_jit import *
 from .activations_me import *
-from .config import is_exportable, is_scriptable, is_no_jit
+from .config import is_exportable, is_scriptable
 
 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
     hard_mish=hard_mish,
 )
 
-_ACT_FN_JIT = dict(
-    silu=F.silu if _has_silu else swish_jit,
-    swish=F.silu if _has_silu else swish_jit,
-    mish=F.mish if _has_mish else mish_jit,
-    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
-    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit,
-)
-
 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,
@@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
     hard_mish=hard_mish_me,
 )
 
-_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
 for a in _ACT_FNS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
     identity=nn.Identity,
 )
 
-_ACT_LAYER_JIT = dict(
-    silu=nn.SiLU if _has_silu else SwishJit,
-    swish=nn.SiLU if _has_silu else SwishJit,
-    mish=nn.Mish if _has_mish else MishJit,
-    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
-    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit,
-)
-
 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,
@@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
     hard_mish=HardMishMe,
 )
 
-_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
 for a in _ACT_LAYERS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_FN_JIT:
-            return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
 
 
@@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_LAYER_JIT:
-            return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
 
 
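
Net effect on the factories: with the JIT tier gone, get_act_fn/get_act_layer return the memory-efficient implementations unless the model is being exported or scripted, and the defaults (native PyTorch ops where available) otherwise. A small usage sketch, not from the commit:

import torch
from timm.layers import get_act_fn, get_act_layer

act_layer = get_act_layer('hard_mish')   # HardMishMe unless exporting/scripting
act = act_layer(inplace=True)            # `inplace` arg accepted even where unused

act_fn = get_act_fn('swish')             # resolves to F.silu on recent PyTorch

x = torch.randn(2, 16)
y = act(x)
z = act_fn(x)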

View File

@@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
         return tgt
 
 
-# @torch.jit.script
 # class ExtrapClasses(object):
 #     def __init__(self, num_queries: int, group_size: int):
 #         self.num_queries = num_queries
@@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
 #         out = out.view((h.shape[0], self.group_size * self.num_queries))
 #         return out
 
-@torch.jit.script
-class GroupFC(object):
-    def __init__(self, embed_len_decoder: int):
-        self.embed_len_decoder = embed_len_decoder
-
-    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
-        for i in range(self.embed_len_decoder):
-            h_i = h[:, i, :]
-            w_i = duplicate_pooling[i, :, :]
-            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
-
-
 class MLDecoder(nn.Module):
     def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
         super(MLDecoder, self).__init__()
         embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
         if embed_len_decoder > num_classes:
             embed_len_decoder = num_classes
+        self.embed_len_decoder = embed_len_decoder
 
         # switching to 768 initial embeddings
         decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
@@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
         self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
         torch.nn.init.xavier_normal_(self.duplicate_pooling)
         torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
-        self.group_fc = GroupFC(embed_len_decoder)
 
     def forward(self, x):
         if len(x.shape) == 4:  # [bs,2048, 7,7]
@@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
         h = h.transpose(0, 1)
 
         out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
-        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        for i in range(self.embed_len_decoder):  # group FC
+            h_i = h[:, i, :]
+            w_i = self.duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
         h_out = out_extrap.flatten(1)[:, :self.num_classes]
         h_out += self.duplicate_pooling_bias
         logits = h_out
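
For reference, the inlined loop performs one matmul per group: after the transpose h is (B, K, D) and duplicate_pooling is (K, D, G), so group i projects h[:, i, :] through its own (D, G) weight. The same computation can be written as a single einsum; a self-contained equivalence sketch with dummy shapes (not part of the commit):

import torch

B, K, D, G = 2, 100, 768, 10          # batch, groups, embed dim, duplicate factor
h = torch.randn(B, K, D)
duplicate_pooling = torch.randn(K, D, G)

# the loop as written in the new forward()
out_loop = torch.zeros(B, K, G)
for i in range(K):
    out_loop[:, i, :] = torch.matmul(h[:, i, :], duplicate_pooling[i, :, :])

# batched equivalent in one call
out_einsum = torch.einsum('bkd,kdg->bkg', h, duplicate_pooling)
assert torch.allclose(out_loop, out_einsum, atol=1e-5)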

View File

@@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
         return tensor.is_contiguous(memory_format=torch.contiguous_format)
 
 
-@torch.jit.script
 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
     x = (x - u) * torch.rsqrt(s + eps)
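
The un-decorated _layer_norm_cf is plain eager math: it normalizes an NCHW tensor over the channel dim, equivalent to permuting to channels-last and calling F.layer_norm. A sketch of that equivalence; the affine step at the end is assumed from context since the hunk truncates the function body:

import torch
import torch.nn.functional as F

def layer_norm_cf(x, weight, bias, eps: float = 1e-6):
    # normalize over dim=1 of an NCHW tensor, as in _layer_norm_cf above
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    # affine step assumed from context (not shown in the hunk)
    return x * weight[:, None, None] + bias[:, None, None]

x = torch.randn(2, 64, 7, 7)
w, b = torch.ones(64), torch.zeros(64)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), w, b, 1e-6).permute(0, 3, 1, 2)
assert torch.allclose(layer_norm_cf(x, w, b), ref, atol=1e-5)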

View File

@@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
         return x
 
 
-@torch.jit.script
-class SpaceToDepthJit:
-    def __call__(self, x: torch.Tensor):
-        # assuming hard-coded that block_size==4 for acceleration
-        N, C, H, W = x.size()
-        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
-        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
-        return x
-
-
-class SpaceToDepthModule(nn.Module):
-    def __init__(self, no_jit=False):
-        super().__init__()
-        if not no_jit:
-            self.op = SpaceToDepthJit()
-        else:
-            self.op = SpaceToDepth()
-
-    def forward(self, x):
-        return self.op(x)
-
-
 class DepthToSpace(nn.Module):
 
     def __init__(self, block_size):
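
Migration note: callers that used SpaceToDepthModule() can switch to the surviving SpaceToDepth module, which performs the same block_size=4 rearrangement of (N, C, H, W) into (N, 16*C, H/4, W/4). A hedged sketch, assuming SpaceToDepth keeps 4 as its default block size:

import torch
from timm.layers import SpaceToDepth

s2d = SpaceToDepth()                    # assumed default block_size=4
x = torch.randn(2, 3, 224, 224)
y = s2d(x)
assert y.shape == (2, 3 * 16, 224 // 4, 224 // 4)   # (2, 48, 56, 56)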

View File

@@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
 from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from timm.layers.selective_kernel import SelectiveKernel
 from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
-from timm.layers.space_to_depth import SpaceToDepthModule
 from timm.layers.split_attn import SplitAttn
 from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame