Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Remove JIT activations, take jit out of ME activations. Remove other instances of torch.jit.script; it breaks torch.compile and is much less performant. Remove SpaceToDepthModule.
This commit is contained in:
parent 07535f408a
commit 2bfa5e5d74
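A minimal sketch, not part of this commit, of the rationale stated in the message: once the @torch.jit.script decorators are gone, the activations are plain Python functions that torch.compile can trace and fuse directly. The hard_mish function below copies the math from the diff; the names hard_mish/compiled_hard_mish and the torch.compile call are illustrative only.

    import torch

    def hard_mish(x):
        # same expression as the removed hard_mish_jit, now an ordinary eager function
        return 0.5 * x * (x + 2).clamp(min=0, max=2)

    compiled_hard_mish = torch.compile(hard_mish)  # a scripted function would block this path, per the commit message
    y = compiled_hard_mish(torch.randn(8, 16))
    print(y.shape)  # torch.Size([8, 16])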
@@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
@@ -1,90 +0,0 @@
-""" Activations
-
-A collection of jit-scripted activations fn and modules with a common interface so that they can
-easily be swapped. All have an `inplace` arg even if not used.
-
-All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
-currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
-versions if they contain in-place ops.
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-@torch.jit.script
-def swish_jit(x, inplace: bool = False):
-    """Swish - Described in: https://arxiv.org/abs/1710.05941
-    """
-    return x.mul(x.sigmoid())
-
-
-@torch.jit.script
-def mish_jit(x, _inplace: bool = False):
-    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    """
-    return x.mul(F.softplus(x).tanh())
-
-
-class SwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(SwishJit, self).__init__()
-
-    def forward(self, x):
-        return swish_jit(x)
-
-
-class MishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(MishJit, self).__init__()
-
-    def forward(self, x):
-        return mish_jit(x)
-
-
-@torch.jit.script
-def hard_sigmoid_jit(x, inplace: bool = False):
-    # return F.relu6(x + 3.) / 6.
-    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSigmoidJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSigmoidJit, self).__init__()
-
-    def forward(self, x):
-        return hard_sigmoid_jit(x)
-
-
-@torch.jit.script
-def hard_swish_jit(x, inplace: bool = False):
-    # return x * (F.relu6(x + 3.) / 6)
-    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
-
-
-class HardSwishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardSwishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_swish_jit(x)
-
-
-@torch.jit.script
-def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
-    Experimental, based on notes by Mish author Diganta Misra at
-      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
-    """
-    return 0.5 * x * (x + 2).clamp(min=0, max=2)
-
-
-class HardMishJit(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super(HardMishJit, self).__init__()
-
-    def forward(self, x):
-        return hard_mish_jit(x)
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.

-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.

 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@ from torch import nn as nn
 from torch.nn import functional as F


-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))


-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
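The hand-written swish_bwd above implements d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is what lets the autograd Function recompute the gradient from the saved input instead of keeping intermediate activations. A quick sketch (not part of the commit), using the renamed functions from this diff, that checks the manual backward against autograd:

    import torch

    def swish_fwd(x):
        return x.mul(torch.sigmoid(x))

    def swish_bwd(x, grad_output):
        x_sigmoid = torch.sigmoid(x)
        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))

    x = torch.randn(64, dtype=torch.double, requires_grad=True)
    swish_fwd(x).sum().backward()                       # autograd gradient
    manual = swish_bwd(x.detach(), torch.ones_like(x))  # hand-written gradient
    print(torch.allclose(x.grad, manual))               # expect True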
@@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)


 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)


 class SwishMe(nn.Module):
@@ -54,38 +52,36 @@ class SwishMe(nn.Module):
         super(SwishMe, self).__init__()

     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)


-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))


-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)


 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)


 class MishMe(nn.Module):
@@ -93,34 +89,32 @@ class MishMe(nn.Module):
         super(MishMe, self).__init__()

     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)


-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m


-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)


 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)


 class HardSigmoidMe(nn.Module):
@@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module):
         super(HardSigmoidMe, self).__init__()

     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)


-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m


-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)

     @staticmethod
     def symbolic(g, self):
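hard_swish_bwd above encodes the piecewise derivative of x * relu6(x + 3) / 6: 0 for x < -3, x / 3 + 0.5 on [-3, 3], and 1 for x > 3. A short sketch (not in the diff), sampling away from the exact breakpoints, that compares it to autograd through F.hardswish:

    import torch
    import torch.nn.functional as F

    def hard_swish_bwd(x, grad_output):
        m = torch.ones_like(x) * (x >= 3.)
        m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
        return grad_output * m

    x = (torch.rand(1000) * 10 - 5).requires_grad_()  # uniform in [-5, 5), avoids landing exactly on +/-3
    F.hardswish(x).sum().backward()
    print(torch.allclose(x.grad, hard_swish_bwd(x.detach(), torch.ones_like(x))))  # expect True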
@@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function):


 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)


 class HardSwishMe(nn.Module):
@@ -172,39 +164,37 @@ class HardSwishMe(nn.Module):
         super(HardSwishMe, self).__init__()

     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)


-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)


-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m


-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
       https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)


 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)


 class HardMishMe(nn.Module):
@@ -212,7 +202,7 @@ class HardMishMe(nn.Module):
         super(HardMishMe, self).__init__()

     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)


@@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Union, Callable, Type

 from .activations import *
-from .activations_jit import *
 from .activations_me import *
-from .config import is_exportable, is_scriptable, is_no_jit
+from .config import is_exportable, is_scriptable

 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict(
     hard_mish=hard_mish,
 )

-_ACT_FN_JIT = dict(
-    silu=F.silu if _has_silu else swish_jit,
-    swish=F.silu if _has_silu else swish_jit,
-    mish=F.mish if _has_mish else mish_jit,
-    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
-    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit,
-)
-
 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,
@@ -55,7 +45,7 @@ _ACT_FN_ME = dict(
     hard_mish=hard_mish_me,
 )

-_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
 for a in _ACT_FNS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict(
     identity=nn.Identity,
 )

-_ACT_LAYER_JIT = dict(
-    silu=nn.SiLU if _has_silu else SwishJit,
-    swish=nn.SiLU if _has_silu else SwishJit,
-    mish=nn.Mish if _has_mish else MishJit,
-    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
-    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit,
-)
-
 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,
@@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict(
     hard_mish=HardMishMe,
 )

-_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
 for a in _ACT_LAYERS:
     a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
     a.setdefault('hardswish', a.get('hard_swish'))
@@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         return None
     if isinstance(name, Callable):
         return name
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_FN_JIT:
-            return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]


@@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
         return name
     if not name:
         return None
-    if not (is_no_jit() or is_exportable() or is_scriptable()):
+    if not (is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not (is_no_jit() or is_exportable()):
-        if name in _ACT_LAYER_JIT:
-            return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]


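With the _JIT tables gone, the resolution order in get_act_fn/get_act_layer becomes: memory-efficient (ME) implementation when neither exporting nor scripting, otherwise the default table, which already prefers native ops such as nn.SiLU. A usage sketch, assuming the get_act_fn/get_act_layer entry points shown in this diff are exported from timm.layers:

    import torch.nn as nn
    from timm.layers import get_act_fn, get_act_layer

    act_fn = get_act_fn('hard_mish')    # -> hard_mish_me unless exportable/scriptable config is set
    act_layer = get_act_layer('swish')  # -> nn.SiLU on PyTorch >= 1.7 (native op preferred)
    print(act_fn, act_layer is nn.SiLU)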
@@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module):
         return tgt


-# @torch.jit.script
 # class ExtrapClasses(object):
 #     def __init__(self, num_queries: int, group_size: int):
 #         self.num_queries = num_queries
@@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module):
 #         out = out.view((h.shape[0], self.group_size * self.num_queries))
 #         return out

-@torch.jit.script
-class GroupFC(object):
-    def __init__(self, embed_len_decoder: int):
-        self.embed_len_decoder = embed_len_decoder
-
-    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
-        for i in range(self.embed_len_decoder):
-            h_i = h[:, i, :]
-            w_i = duplicate_pooling[i, :, :]
-            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
-
-
 class MLDecoder(nn.Module):
     def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
         super(MLDecoder, self).__init__()
         embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
         if embed_len_decoder > num_classes:
             embed_len_decoder = num_classes
+        self.embed_len_decoder = embed_len_decoder

         # switching to 768 initial embeddings
         decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
@@ -131,7 +119,6 @@ class MLDecoder(nn.Module):
         self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
         torch.nn.init.xavier_normal_(self.duplicate_pooling)
         torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
-        self.group_fc = GroupFC(embed_len_decoder)

     def forward(self, x):
         if len(x.shape) == 4:  # [bs,2048, 7,7]
@@ -149,7 +136,10 @@ class MLDecoder(nn.Module):
         h = h.transpose(0, 1)

         out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
-        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        for i in range(self.embed_len_decoder):  # group FC
+            h_i = h[:, i, :]
+            w_i = self.duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
         h_out = out_extrap.flatten(1)[:, :self.num_classes]
         h_out += self.duplicate_pooling_bias
         logits = h_out
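For reference, the inlined group-FC loop in forward() is a per-group matmul between h[:, i, :] and duplicate_pooling[i, :, :]. The sketch below (not part of the commit, shapes are illustrative) shows it is numerically equivalent to a single einsum contraction:

    import torch

    bs, n_groups, d, dup = 4, 100, 768, 8
    h = torch.randn(bs, n_groups, d)                   # decoder output: [bs, embed_len_decoder, d]
    duplicate_pooling = torch.randn(n_groups, d, dup)  # [embed_len_decoder, d, duplicate_factor]

    out_loop = torch.zeros(bs, n_groups, dup)
    for i in range(n_groups):  # the loop as written in the diff
        out_loop[:, i, :] = torch.matmul(h[:, i, :], duplicate_pooling[i, :, :])

    out_einsum = torch.einsum('bnd,ndf->bnf', h, duplicate_pooling)
    print(torch.allclose(out_loop, out_einsum, atol=1e-4))  # expect True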
@@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
     return tensor.is_contiguous(memory_format=torch.contiguous_format)


-@torch.jit.script
 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
     x = (x - u) * torch.rsqrt(s + eps)
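The _layer_norm_cf helper above (now plain Python rather than scripted) normalizes an NCHW tensor over its channel dimension. A sketch, covering only the normalization step shown in the hunk, comparing it to F.layer_norm applied after an NCHW -> NHWC permute:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 64, 7, 7)
    eps = 1e-6

    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    y_cf = (x - u) * torch.rsqrt(s + eps)          # channels-first normalization from the hunk

    y_ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), eps=eps).permute(0, 3, 1, 2)
    print(torch.allclose(y_cf, y_ref, atol=1e-5))  # expect True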
@@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module):
         return x


-@torch.jit.script
-class SpaceToDepthJit:
-    def __call__(self, x: torch.Tensor):
-        # assuming hard-coded that block_size==4 for acceleration
-        N, C, H, W = x.size()
-        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
-        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
-        return x
-
-
-class SpaceToDepthModule(nn.Module):
-    def __init__(self, no_jit=False):
-        super().__init__()
-        if not no_jit:
-            self.op = SpaceToDepthJit()
-        else:
-            self.op = SpaceToDepth()
-
-    def forward(self, x):
-        return self.op(x)
-
-
 class DepthToSpace(nn.Module):

     def __init__(self, block_size):
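The removed SpaceToDepthJit hard-coded block_size == 4; the generic SpaceToDepth module that remains in the file covers that case. As a reference, the same rearrangement written as a plain function (a hypothetical helper, not part of the diff):

    import torch

    def space_to_depth_bs4(x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.size()
        x = x.view(n, c, h // 4, 4, w // 4, 4)        # (N, C, H//4, 4, W//4, 4)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, 4, 4, C, H//4, W//4)
        return x.view(n, c * 16, h // 4, w // 4)      # (N, C*16, H//4, W//4)

    x = torch.randn(2, 3, 224, 224)
    print(space_to_depth_bs4(x).shape)  # torch.Size([2, 48, 56, 56])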
@@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
 from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from timm.layers.selective_kernel import SelectiveKernel
 from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
-from timm.layers.space_to_depth import SpaceToDepthModule
 from timm.layers.split_attn import SplitAttn
 from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame