mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add non-local and BAT attention. Merge attn and self-attn factories into one. Add attention references to README. Add mlp 'mode' to ECA.
This commit is contained in:
parent
17dc47c8e6
commit
307a935b79
16
README.md
16
README.md
@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
|
||||
* SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
|
||||
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
|
||||
* DropBlock (https://arxiv.org/abs/1810.12890)
|
||||
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
|
||||
* Blur Pooling (https://arxiv.org/abs/1904.11486)
|
||||
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
|
||||
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
|
||||
* An extensive selection of channel and/or spatial attention modules:
|
||||
* Bottleneck Transformer - https://arxiv.org/abs/2101.11605
|
||||
* CBAM - https://arxiv.org/abs/1807.06521
|
||||
* Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
|
||||
* Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
|
||||
* Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
|
||||
* Global Context (GC) - https://arxiv.org/abs/1904.11492
|
||||
* Halo - https://arxiv.org/abs/2103.12731
|
||||
* Involution - https://arxiv.org/abs/2103.06255
|
||||
* Lambda Layer - https://arxiv.org/abs/2102.08602
|
||||
* Non-Local (NL) - https://arxiv.org/abs/1711.07971
|
||||
* Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
|
||||
* Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
|
||||
* Split (SPLAT) - https://arxiv.org/abs/2004.08955
|
||||
* Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
|
||||
|
||||
## Results
|
||||
|
||||
|
@ -35,7 +35,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
|
||||
create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple
|
||||
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
|
||||
@ -935,7 +935,7 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
|
||||
else:
|
||||
self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
|
||||
self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
|
||||
self_attn_layer = partial(get_self_attn(self_attn_layer), *self_attn_kwargs) \
|
||||
self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \
|
||||
if self_attn_layer is not None else None
|
||||
layer_fns = replace(layer_fns, self_attn=self_attn_layer)
|
||||
|
||||
@ -1010,7 +1010,7 @@ def get_layer_fns(cfg: ByoModelCfg):
|
||||
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
|
||||
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
|
||||
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||
self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
|
||||
self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
|
||||
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
|
||||
return layer_fn
|
||||
|
||||
|
@ -1234,7 +1234,8 @@ def eca_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0 w/ ECA attn """
|
||||
# NOTE experimental config
|
||||
model = _gen_efficientnet(
|
||||
'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
||||
'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
|
||||
pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@ -1243,7 +1244,8 @@ def gc_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0 w/ GlobalContext """
|
||||
# NOTE experimental config
|
||||
model = _gen_efficientnet(
|
||||
'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
||||
'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
|
||||
pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -12,7 +12,6 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import get_attn, create_attn
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
|
||||
from .create_self_attn import get_self_attn, create_self_attn
|
||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
@ -24,16 +23,17 @@ from .involution import Involution
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, LayerNorm2d
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .patch_embed import PatchEmbed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
||||
from .selective_kernel import SelectiveKernelConv
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .separable_conv import SeparableConv2d, SeparableConvBnAct
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .split_attn import SplitAttnConv2d
|
||||
from .split_attn import SplitAttn
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
|
@ -1,14 +1,23 @@
|
||||
""" Select AttentionFactory Method
|
||||
""" Attention Factory
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
from functools import partial
|
||||
|
||||
from .bottleneck_attn import BottleneckAttn
|
||||
from .cbam import CbamModule, LightCbamModule
|
||||
from .eca import EcaModule, CecaModule
|
||||
from .gather_excite import GatherExcite
|
||||
from .global_context import GlobalContext
|
||||
from .halo_attn import HaloAttn
|
||||
from .involution import Involution
|
||||
from .lambda_layer import LambdaLayer
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .split_attn import SplitAttn
|
||||
from .squeeze_excite import SEModule, EffectiveSEModule
|
||||
from .swin_attn import WindowAttention
|
||||
|
||||
|
||||
def get_attn(attn_type):
|
||||
@ -18,12 +27,16 @@ def get_attn(attn_type):
|
||||
if attn_type is not None:
|
||||
if isinstance(attn_type, str):
|
||||
attn_type = attn_type.lower()
|
||||
# Lightweight attention modules (channel and/or coarse spatial).
|
||||
# Typically added to existing network architecture blocks in addition to existing convolutions.
|
||||
if attn_type == 'se':
|
||||
module_cls = SEModule
|
||||
elif attn_type == 'ese':
|
||||
module_cls = EffectiveSEModule
|
||||
elif attn_type == 'eca':
|
||||
module_cls = EcaModule
|
||||
elif attn_type == 'ecam':
|
||||
module_cls = partial(EcaModule, use_mlp=True)
|
||||
elif attn_type == 'ceca':
|
||||
module_cls = CecaModule
|
||||
elif attn_type == 'ge':
|
||||
@ -34,6 +47,34 @@ def get_attn(attn_type):
|
||||
module_cls = CbamModule
|
||||
elif attn_type == 'lcbam':
|
||||
module_cls = LightCbamModule
|
||||
|
||||
# Attention / attention-like modules w/ significant params
|
||||
# Typically replace some of the existing workhorse convs in a network architecture.
|
||||
# All of these accept a stride argument and can spatially downsample the input.
|
||||
elif attn_type == 'sk':
|
||||
module_cls = SelectiveKernel
|
||||
elif attn_type == 'splat':
|
||||
module_cls = SplitAttn
|
||||
|
||||
# Self-attention / attention-like modules w/ significant compute and/or params
|
||||
# Typically replace some of the existing workhorse convs in a network architecture.
|
||||
# All of these accept a stride argument and can spatially downsample the input.
|
||||
elif attn_type == 'lambda':
|
||||
return LambdaLayer
|
||||
elif attn_type == 'bottleneck':
|
||||
return BottleneckAttn
|
||||
elif attn_type == 'halo':
|
||||
return HaloAttn
|
||||
elif attn_type == 'swin':
|
||||
return WindowAttention
|
||||
elif attn_type == 'involution':
|
||||
return Involution
|
||||
elif attn_type == 'nl':
|
||||
module_cls = NonLocalAttn
|
||||
elif attn_type == 'bat':
|
||||
module_cls = BatNonLocalAttn
|
||||
|
||||
# Woops!
|
||||
else:
|
||||
assert False, "Invalid attn module (%s)" % attn_type
|
||||
elif isinstance(attn_type, bool):
|
||||
|
@ -1,25 +0,0 @@
|
||||
from .bottleneck_attn import BottleneckAttn
|
||||
from .halo_attn import HaloAttn
|
||||
from .involution import Involution
|
||||
from .lambda_layer import LambdaLayer
|
||||
from .swin_attn import WindowAttention
|
||||
|
||||
|
||||
def get_self_attn(attn_type):
    """Return the self-attention module class selected by ``attn_type``.

    Args:
        attn_type: one of ``'bottleneck'``, ``'halo'``, ``'lambda'``, ``'swin'``
            or ``'involution'``.

    Returns:
        The matching module class (not an instance).

    Raises:
        AssertionError: if ``attn_type`` is not one of the names above.
    """
    if attn_type == 'bottleneck':
        return BottleneckAttn
    if attn_type == 'halo':
        return HaloAttn
    if attn_type == 'lambda':
        return LambdaLayer
    if attn_type == 'swin':
        return WindowAttention
    if attn_type == 'involution':
        return Involution
    # Raise explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would make an unknown name silently return None.
    raise AssertionError(f"Unknown attn type ({attn_type})")
|
||||
|
||||
|
||||
def create_self_attn(attn_type, dim, stride=1, **kwargs):
    """Look up the self-attention class for ``attn_type`` and instantiate it.

    The class returned by ``get_self_attn(attn_type)`` is called with ``dim`` as its
    only positional argument, ``stride`` as a keyword, and any extra ``kwargs``
    forwarded unchanged.
    """
    return get_self_attn(attn_type)(dim, stride=stride, **kwargs)
|
@ -39,6 +39,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
from .create_act import create_act_layer
|
||||
from .helpers import make_divisible
|
||||
|
||||
|
||||
class EcaModule(nn.Module):
|
||||
@ -56,21 +57,36 @@ class EcaModule(nn.Module):
|
||||
act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
|
||||
gate_layer: gating non-linearity to use
|
||||
"""
|
||||
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
|
||||
def __init__(
|
||||
self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
|
||||
rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
|
||||
super(EcaModule, self).__init__()
|
||||
if channels is not None:
|
||||
t = int(abs(math.log(channels, 2) + beta) / gamma)
|
||||
kernel_size = max(t if t % 2 else t + 1, 3)
|
||||
assert kernel_size % 2 == 1
|
||||
has_act = act_layer is not None
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act)
|
||||
self.act = create_act_layer(act_layer) if has_act else nn.Identity()
|
||||
padding = (kernel_size - 1) // 2
|
||||
if use_mlp:
|
||||
# NOTE 'mlp' mode is a timm experiment, not in paper
|
||||
assert channels is not None
|
||||
if rd_channels is None:
|
||||
rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
|
||||
act_layer = act_layer or nn.ReLU
|
||||
self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
|
||||
self.act = create_act_layer(act_layer)
|
||||
self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
|
||||
else:
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
|
||||
self.act = None
|
||||
self.conv2 = None
|
||||
self.gate = create_act_layer(gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
|
||||
y = self.conv(y)
|
||||
y = self.act(y) # NOTE: usually a no-op, added for experimentation
|
||||
if self.conv2 is not None:
|
||||
y = self.act(y)
|
||||
y = self.conv2(y)
|
||||
y = self.gate(y).view(x.shape[0], -1, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
@ -115,7 +131,6 @@ class CecaModule(nn.Module):
|
||||
# implement manual circular padding
|
||||
self.padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
|
||||
self.act = create_act_layer(act_layer) if has_act else nn.Identity()
|
||||
self.gate = create_act_layer(gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
@ -123,7 +138,6 @@ class CecaModule(nn.Module):
|
||||
# Manually implement circular padding, F.pad does not seem to be bugged
|
||||
y = F.pad(y, (self.padding, self.padding), mode='circular')
|
||||
y = self.conv(y)
|
||||
y = self.act(y) # NOTE: usually a no-op, added for experimentation
|
||||
y = self.gate(y).view(x.shape[0], -1, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
145
timm/models/layers/non_local_attn.py
Normal file
145
timm/models/layers/non_local_attn.py
Normal file
@ -0,0 +1,145 @@
|
||||
""" Bilinear-Attention-Transform and Non-Local Attention
|
||||
|
||||
Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
|
||||
- https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
|
||||
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .helpers import make_divisible
|
||||
|
||||
|
||||
class NonLocalAttn(nn.Module):
    """Spatial NL block for image classification.

    This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.

    Args:
        in_channels: number of channels of the input (and output) feature map
        use_scale: if True, attention logits are multiplied by ``in_channels ** -0.5``
        rd_ratio: fraction of ``in_channels`` used for the reduced projections when
            ``rd_channels`` is None
        rd_channels: explicit reduced channel count; takes precedence over ``rd_ratio``
        rd_divisor: divisor handed to ``make_divisible`` when deriving ``rd_channels``
        **kwargs: accepted and ignored
    """

    def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
        super(NonLocalAttn, self).__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.scale = in_channels ** -0.5 if use_scale else 1.0
        # 1x1 projections into the reduced space (t, p, g) and back out again (z)
        self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        """Apply non-local attention over all spatial positions and add the result to ``x``."""
        shortcut = x

        t = self.t(x)
        p = self.p(x)
        g = self.g(x)

        # C here is the reduced channel count produced by the projections
        B, C, H, W = t.size()
        t = t.view(B, C, -1).permute(0, 2, 1)  # (B, H*W, C)
        p = p.view(B, C, -1)  # (B, C, H*W)
        g = g.view(B, C, -1).permute(0, 2, 1)  # (B, H*W, C)

        att = torch.bmm(t, p) * self.scale  # (B, H*W, H*W) pairwise position affinities
        att = F.softmax(att, dim=2)
        x = torch.bmm(att, g)  # (B, H*W, C)

        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.z(x)
        x = self.norm(x) + shortcut

        return x

    def reset_parameters(self):
        """Initialize conv and norm parameters.

        Conv weights use Kaiming-normal (fan_out, relu) init with zero bias. Norm
        weight and bias are both zeroed, so the normalized branch outputs zeros
        right after construction and ``forward`` returns its input unchanged.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class BilinearAttnTransform(nn.Module):
    """Bilinear attentional transform: computes ``P @ x @ Q`` per channel.

    ``P`` and ``Q`` are predicted from the input as ``block_size x block_size``
    matrices per group, normalized, shared across the channels of each group,
    and expanded to the full spatial size by ``resize_mat``.

    Args:
        in_channels: number of input/output channels
        block_size: side length of the predicted attention matrices; input height
            and width must be multiples of it
        groups: number of channel groups that each get their own ``P`` / ``Q``
        act_layer: activation layer passed to the ConvBnAct layers
        norm_layer: normalization layer passed to the ConvBnAct layers
    """

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(BilinearAttnTransform, self).__init__()

        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t):
        """Expand each entry of a square matrix into a ``t x t`` scaled identity block.

        Args:
            x: tensor of shape (B, C, block_size, block_size)
            t: expansion factor; for ``t <= 1`` the input is returned as-is

        Returns:
            Tensor of shape (B, C, block_size * t, block_size * t) in which entry
            ``x[..., i, j]`` occupies the diagonal of block ``(i, j)``.
        """
        B, C, block_size, block_size1 = x.shape
        assert block_size == block_size1, 'resize_mat expects square matrices'
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
        # broadcast every scalar entry against an identity -> (B*C, bs*bs, t, t)
        x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
        x = x.view(B * C, block_size, block_size, t, t)
        # lay the per-entry blocks out along rows, then along columns
        x = torch.cat(torch.split(x, 1, dim=1), dim=3)
        x = torch.cat(torch.split(x, 1, dim=2), dim=4)
        x = x.view(B, C, block_size * t, block_size * t)
        return x

    def forward(self, x):
        """Apply the predicted row (P) and column (Q) transforms to ``x``, then ``conv2``."""
        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0, \
            'input height and width must be divisible by block_size'
        B, C, H, W = x.shape
        bs = self.block_size
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (bs, 1))
        cp = F.adaptive_max_pool2d(out, (1, bs))
        # Tensor.sigmoid() replaces the deprecated F.sigmoid
        p = self.conv_p(rp).view(B, self.groups, bs, bs).sigmoid()
        q = self.conv_q(cp).view(B, self.groups, bs, bs).sigmoid()
        # normalize P over its last dim and Q over its second-to-last dim
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)
        # share each group's matrix across the C // groups channels of that group
        p = p.view(B, self.groups, 1, bs, bs).expand(B, self.groups, C // self.groups, bs, bs).contiguous()
        p = p.view(B, C, bs, bs)
        q = q.view(B, self.groups, 1, bs, bs).expand(B, self.groups, C // self.groups, bs, bs).contiguous()
        q = q.view(B, C, bs, bs)
        p = self.resize_mat(p, H // bs)
        q = self.resize_mat(q, W // bs)
        y = p.matmul(x)
        y = y.matmul(q)

        y = self.conv2(y)
        return y
|
||||
|
||||
|
||||
class BatNonLocalAttn(nn.Module):
    """ BAT
    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification

    Reduces the input with a 1x1 ConvBnAct, applies a BilinearAttnTransform in the
    reduced space, restores the channel count with a second 1x1 ConvBnAct, applies
    2D dropout and adds the result to the input.
    """

    def __init__(
            self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
            drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
        super().__init__()
        # derive the reduced width from rd_ratio unless it was given explicitly
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        self.conv1 = ConvBnAct(in_channels, rd_channels, 1, **layer_args)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, **layer_args)
        self.conv2 = ConvBnAct(rd_channels, in_channels, 1, **layer_args)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        # reduce -> bilinear attention transform -> expand
        attn = self.conv2(self.ba(self.conv1(x)))
        # dropout on the attention branch only, then the residual add
        return self.dropout(attn) + x
|
@ -8,6 +8,7 @@ import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .helpers import make_divisible
|
||||
|
||||
|
||||
def _kernel_valid(k):
|
||||
@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SelectiveKernelConv(nn.Module):
|
||||
class SelectiveKernel(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||
rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True,
|
||||
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
|
||||
""" Selective Kernel Convolution Module
|
||||
|
||||
@ -66,8 +67,8 @@ class SelectiveKernelConv(nn.Module):
|
||||
stride (int): stride for convolutions
|
||||
dilation (int): dilation for module as a whole, impacts dilation of each branch
|
||||
groups (int): number of groups for each branch
|
||||
attn_reduction (int, float): reduction factor for attention features
|
||||
min_attn_channels (int): minimum attention feature channels
|
||||
rd_ratio (int, float): reduction factor for attention features
|
||||
min_rd_channels (int): minimum attention feature channels
|
||||
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
|
||||
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
|
||||
can be viewed as grouping by path, output expands to module out_channels count
|
||||
@ -75,7 +76,8 @@ class SelectiveKernelConv(nn.Module):
|
||||
act_layer (nn.Module): activation layer to use
|
||||
norm_layer (nn.Module): batchnorm/norm layer to use
|
||||
"""
|
||||
super(SelectiveKernelConv, self).__init__()
|
||||
super(SelectiveKernel, self).__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
|
||||
_kernel_valid(kernel_size)
|
||||
if not isinstance(kernel_size, list):
|
||||
@ -101,7 +103,8 @@ class SelectiveKernelConv(nn.Module):
|
||||
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||
for k, d in zip(kernel_size, dilation)])
|
||||
|
||||
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||
attn_channels = rd_channels or make_divisible(
|
||||
out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor)
|
||||
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
|
||||
self.drop_block = drop_block
|
||||
|
||||
|
@ -10,6 +10,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .helpers import make_divisible
|
||||
|
||||
|
||||
class RadixSoftmax(nn.Module):
|
||||
def __init__(self, radix, cardinality):
|
||||
@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SplitAttnConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d
|
||||
class SplitAttn(nn.Module):
|
||||
"""Split-Attention (aka Splat)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||||
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
|
||||
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
|
||||
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
|
||||
super(SplitAttnConv2d, self).__init__()
|
||||
super(SplitAttn, self).__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
self.radix = radix
|
||||
self.drop_block = drop_block
|
||||
mid_chs = out_channels * radix
|
||||
attn_chs = max(in_channels * radix // reduction_factor, 32)
|
||||
if rd_channels is None:
|
||||
attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
|
||||
else:
|
||||
attn_chs = rd_channels * radix
|
||||
|
||||
padding = kernel_size // 2 if padding is None else padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, mid_chs, kernel_size, stride, padding, dilation,
|
||||
groups=groups * radix, bias=bias, **kwargs)
|
||||
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
|
||||
self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
|
||||
self.act0 = act_layer(inplace=True)
|
||||
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
|
||||
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
|
||||
self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
|
||||
self.rsoftmax = RadixSoftmax(radix, groups)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
return self.conv.in_channels
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
return self.fc1.out_channels
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn0 is not None:
|
||||
x = self.bn0(x)
|
||||
x = self.bn0(x)
|
||||
if self.drop_block is not None:
|
||||
x = self.drop_block(x)
|
||||
x = self.act0(x)
|
||||
@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module):
|
||||
x_gap = x.sum(dim=1)
|
||||
else:
|
||||
x_gap = x
|
||||
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
|
||||
x_gap = x_gap.mean((2, 3), keepdim=True)
|
||||
x_gap = self.fc1(x_gap)
|
||||
if self.bn1 is not None:
|
||||
x_gap = self.bn1(x_gap)
|
||||
x_gap = self.bn1(x_gap)
|
||||
x_gap = self.act1(x_gap)
|
||||
x_attn = self.fc2(x_gap)
|
||||
|
||||
|
@ -56,7 +56,7 @@ class EffectiveSEModule(nn.Module):
|
||||
""" 'Effective Squeeze-Excitation
|
||||
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
|
||||
"""
|
||||
def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
|
||||
def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
|
||||
super(EffectiveSEModule, self).__init__()
|
||||
self.add_maxpool = add_maxpool
|
||||
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
||||
|
@ -11,7 +11,7 @@ from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import SplitAttnConv2d
|
||||
from .layers import SplitAttn
|
||||
from .registry import register_model
|
||||
from .resnet import ResNet
|
||||
|
||||
@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module):
|
||||
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
|
||||
|
||||
if self.radix >= 1:
|
||||
self.conv2 = SplitAttnConv2d(
|
||||
self.conv2 = SplitAttn(
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
|
||||
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
|
||||
self.act2 = None
|
||||
self.bn2 = nn.Identity()
|
||||
self.act2 = nn.Identity()
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module):
|
||||
out = self.avd_first(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
if self.bn2 is not None:
|
||||
out = self.bn2(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act2(out)
|
||||
out = self.bn2(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act2(out)
|
||||
|
||||
if self.avd_last is not None:
|
||||
out = self.avd_last(out)
|
||||
|
@ -14,7 +14,7 @@ from torch import nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
|
||||
from .layers import SelectiveKernel, ConvBnAct, create_attn
|
||||
from .registry import register_model
|
||||
from .resnet import ResNet
|
||||
|
||||
@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module):
|
||||
outplanes = planes * self.expansion
|
||||
first_dilation = first_dilation or dilation
|
||||
|
||||
self.conv1 = SelectiveKernelConv(
|
||||
self.conv1 = SelectiveKernel(
|
||||
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
|
||||
conv_kwargs['act_layer'] = None
|
||||
self.conv2 = ConvBnAct(
|
||||
@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module):
|
||||
first_dilation = first_dilation or dilation
|
||||
|
||||
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
|
||||
self.conv2 = SelectiveKernelConv(
|
||||
self.conv2 = SelectiveKernel(
|
||||
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
|
||||
**conv_kwargs, **sk_kwargs)
|
||||
conv_kwargs['act_layer'] = None
|
||||
@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
|
||||
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||
variation splits the input channels to the selective convolutions to keep param count down.
|
||||
"""
|
||||
sk_kwargs = dict(
|
||||
min_attn_channels=16,
|
||||
attn_reduction=8,
|
||||
split_input=True)
|
||||
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
|
||||
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||
variation splits the input channels to the selective convolutions to keep param count down.
|
||||
"""
|
||||
sk_kwargs = dict(
|
||||
min_attn_channels=16,
|
||||
attn_reduction=8,
|
||||
split_input=True)
|
||||
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user