mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
SelectKernel split_input works best when input channels split like grouped conv, but output is full width. Disable zero_init for SK nets, seems a bad combo.
This commit is contained in:
parent
7d07ebb660
commit
13e8da2b46
@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module):
|
||||
kernel_size = [3] * len(kernel_size)
|
||||
else:
|
||||
dilation = [dilation] * len(kernel_size)
|
||||
num_paths = len(kernel_size)
|
||||
self.num_paths = num_paths
|
||||
self.split_input = split_input
|
||||
self.num_paths = len(kernel_size)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
if split_input:
|
||||
assert in_channels % num_paths == 0 and out_channels % num_paths == 0
|
||||
in_channels = in_channels // num_paths
|
||||
out_channels = out_channels // num_paths
|
||||
self.split_input = split_input
|
||||
if self.split_input:
|
||||
assert in_channels % self.num_paths == 0
|
||||
in_channels = in_channels // self.num_paths
|
||||
groups = min(out_channels, groups)
|
||||
|
||||
conv_kwargs = dict(
|
||||
@ -329,7 +327,7 @@ class SelectiveKernelConv(nn.Module):
|
||||
for k, d in zip(kernel_size, dilation)])
|
||||
|
||||
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
||||
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
|
||||
self.drop_block = drop_block
|
||||
|
||||
def forward(self, x):
|
||||
@ -338,16 +336,10 @@ class SelectiveKernelConv(nn.Module):
|
||||
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
||||
else:
|
||||
x_paths = [op(x) for op in self.paths]
|
||||
|
||||
x = torch.stack(x_paths, dim=1)
|
||||
x_attn = self.attn(x)
|
||||
x = x * x_attn
|
||||
|
||||
if self.split_input:
|
||||
B, N, C, H, W = x.shape
|
||||
x = x.reshape(B, N * C, H, W)
|
||||
else:
|
||||
x = torch.sum(x, dim=1)
|
||||
x = torch.sum(x, dim=1)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['skresnet18']
|
||||
sk_kwargs = dict(
|
||||
min_attn_channels=16,
|
||||
attn_reduction=8,
|
||||
split_input=True
|
||||
)
|
||||
model = ResNet(
|
||||
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
||||
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
||||
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
)
|
||||
model = ResNet(
|
||||
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False
|
||||
**kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['skresnet50']
|
||||
model = ResNet(
|
||||
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
||||
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['skresnet50d']
|
||||
model = ResNet(
|
||||
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['skresnext50_32x4d']
|
||||
model = ResNet(
|
||||
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
|
Loading…
x
Reference in New Issue
Block a user