Mirror of https://github.com/huggingface/pytorch-image-models.git
Add initial RedNet model / Involution layer impl for testing
commit 165fb354b2 (parent 715519a5ef)
@@ -58,6 +58,9 @@ default_cfgs = {
    'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),

    'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
    'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
}
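The swinnet entries keep fixed_input_size=True, presumably because window attention needs a known grid, while the new rednet entries leave it False since involution has no size-dependent state (see self_attn_fixed_size=False in the cfgs below). A minimal sketch of the input_size / pool_size relationship these entries encode; the 32x overall reduction is an assumption inferred from the 256 -> 8 ratio:

# Hedged sketch: pool_size is the final feature-map size implied by input_size.
input_size = (3, 256, 256)
reduction = 32                                        # assumed overall network stride (256 / 8)
pool_size = tuple(s // reduction for s in input_size[1:])
assert pool_size == (8, 8)                            # matches the pool_size in the cfgs above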
@@ -245,6 +248,38 @@ model_cfgs = dict(
        self_attn_fixed_size=True,
        self_attn_kwargs=dict(win_size=8)
    ),

    rednet26t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',  # FIXME RedNet uses involution in middle of stem
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='involution',
        self_attn_fixed_size=False,
        self_attn_kwargs=dict()
    ),
    rednet50ts=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        act_layer='silu',
        self_attn_layer='involution',
        self_attn_fixed_size=False,
        self_attn_kwargs=dict()
    ),
)
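For orientation, a short sketch of how one stage spec reads. The field meanings follow the Byo(a) block config definitions in timm's byobnet.py (d = depth, c = channels, s = stride, gs = group size, br = bottleneck ratio); the inner-width arithmetic is illustrative:

# Reading ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25):
#   d  -> two blocks in the stage
#   c  -> 256 output channels
#   s  -> stride 1 for the first block (later stages use s=2 to downsample)
#   gs -> group size 0, i.e. default grouping
#   br -> bottleneck ratio, so the inner width is roughly c * br
c, br = 256, 0.25
inner_width = int(c * br)   # 64 channels inside the self-attn bottleneck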
@@ -477,3 +512,17 @@ def swinnet50ts_256(pretrained=False, **kwargs):
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs)


@register_model
def rednet26t(pretrained=False, **kwargs):
    """
    """
    return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs)


@register_model
def rednet50ts(pretrained=False, **kwargs):
    """
    """
    return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs)
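Once registered, the new variants are reachable through timm's normal factory. A minimal usage sketch; pretrained weights are not published in this commit (url=''), so pretrained stays False:

import torch
import timm

model = timm.create_model('rednet26t', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head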
@@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .involution import Involution
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp
@@ -1,5 +1,6 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention

@@ -13,6 +14,8 @@ def get_self_attn(attn_type):
        return LambdaLayer
    elif attn_type == 'swin':
        return WindowAttention
    elif attn_type == 'involution':
        return Involution
    else:
        assert False, f"Unknown attn type ({attn_type})"
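This factory maps the self_attn_layer string used in the model cfgs to a layer class. A hedged sketch of the dispatch; the module path for get_self_attn is an assumption based on where the imports above appear to live:

from timm.models.layers.create_self_attn import get_self_attn  # assumed module path

attn_cls = get_self_attn('involution')              # resolves to the Involution class
layer = attn_cls(channels=256, kernel_size=3, stride=1)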
timm/models/layers/involution.py (new file, 50 lines)
@@ -0,0 +1,50 @@
""" PyTorch Involution Layer

Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py
Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255
"""
import torch.nn as nn
from .conv_bn_act import ConvBnAct
from .create_conv2d import create_conv2d


class Involution(nn.Module):

    def __init__(
            self,
            channels,
            kernel_size=3,
            stride=1,
            group_size=16,
            reduction_ratio=4,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
        super(Involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        self.group_size = group_size
        self.groups = self.channels // self.group_size
        self.conv1 = ConvBnAct(
            in_channels=channels,
            out_channels=channels // reduction_ratio,
            kernel_size=1,
            norm_layer=norm_layer,
            act_layer=act_layer)
        self.conv2 = self.conv = create_conv2d(
            in_channels=channels // reduction_ratio,
            out_channels=kernel_size**2 * self.groups,
            kernel_size=1,
            stride=1)
        self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity()
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride)

    def forward(self, x):
        weight = self.conv2(self.conv1(self.avgpool(x)))
        B, C, H, W = weight.shape
        KK = int(self.kernel_size ** 2)
        weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2)
        out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W)
        out = (weight * out).sum(dim=3).view(B, self.channels, H, W)
        return out
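A short sketch exercising the new layer directly; the package-level import follows the layers __init__.py change above, and the output shapes follow from the avgpool/unfold arithmetic in forward():

import torch
from timm.models.layers import Involution

inv = Involution(channels=64, kernel_size=3, stride=1, group_size=16)
x = torch.randn(2, 64, 32, 32)
print(inv(x).shape)      # torch.Size([2, 64, 32, 32]) - stride 1 keeps H, W

inv_s2 = Involution(channels=64, kernel_size=3, stride=2, group_size=16)
print(inv_s2(x).shape)   # torch.Size([2, 64, 16, 16]) - stride 2 halves H, W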