mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove min channels for SelectiveKernel, divisor should cover cases well enough.
This commit is contained in:
parent
a27f4aec4a
commit
bda8ab015a
@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module):
|
||||
class SelectiveKernel(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||
rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True,
|
||||
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
|
||||
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
|
||||
""" Selective Kernel Convolution Module
|
||||
|
||||
@ -68,7 +68,6 @@ class SelectiveKernel(nn.Module):
|
||||
dilation (int): dilation for module as a whole, impacts dilation of each branch
|
||||
groups (int): number of groups for each branch
|
||||
rd_ratio (int, float): reduction factor for attention features
|
||||
min_rd_channels (int): minimum attention feature channels
|
||||
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
|
||||
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
|
||||
can be viewed as grouping by path, output expands to module out_channels count
|
||||
@ -103,8 +102,7 @@ class SelectiveKernel(nn.Module):
|
||||
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||
for k, d in zip(kernel_size, dilation)])
|
||||
|
||||
attn_channels = rd_channels or make_divisible(
|
||||
out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor)
|
||||
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
|
||||
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
|
||||
self.drop_block = drop_block
|
||||
|
||||
|
@ -153,7 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
|
||||
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||
variation splits the input channels to the selective convolutions to keep param count down.
|
||||
"""
|
||||
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
|
||||
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
@ -167,7 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
|
||||
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||
variation splits the input channels to the selective convolutions to keep param count down.
|
||||
"""
|
||||
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
|
||||
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
@ -207,7 +207,7 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
|
||||
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
|
||||
the SKNet-50 model in the Select Kernel Paper
|
||||
"""
|
||||
sk_kwargs = dict(min_rd_channels=32, rd_ratio=1/16, split_input=False)
|
||||
sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
|
||||
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user