mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
SelectiveKernelConv split_input works best when the input channels are split across the paths like a grouped conv while each path's output stays full width. Disable zero_init_last_bn for the SK nets, as it seems to be a bad combination with them.
This commit is contained in:
parent
7d07ebb660
commit
13e8da2b46
@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module):
|
|||||||
kernel_size = [3] * len(kernel_size)
|
kernel_size = [3] * len(kernel_size)
|
||||||
else:
|
else:
|
||||||
dilation = [dilation] * len(kernel_size)
|
dilation = [dilation] * len(kernel_size)
|
||||||
num_paths = len(kernel_size)
|
self.num_paths = len(kernel_size)
|
||||||
self.num_paths = num_paths
|
|
||||||
self.split_input = split_input
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if split_input:
|
self.split_input = split_input
|
||||||
assert in_channels % num_paths == 0 and out_channels % num_paths == 0
|
if self.split_input:
|
||||||
in_channels = in_channels // num_paths
|
assert in_channels % self.num_paths == 0
|
||||||
out_channels = out_channels // num_paths
|
in_channels = in_channels // self.num_paths
|
||||||
groups = min(out_channels, groups)
|
groups = min(out_channels, groups)
|
||||||
|
|
||||||
conv_kwargs = dict(
|
conv_kwargs = dict(
|
||||||
@ -329,7 +327,7 @@ class SelectiveKernelConv(nn.Module):
|
|||||||
for k, d in zip(kernel_size, dilation)])
|
for k, d in zip(kernel_size, dilation)])
|
||||||
|
|
||||||
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||||
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
|
||||||
self.drop_block = drop_block
|
self.drop_block = drop_block
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -338,15 +336,9 @@ class SelectiveKernelConv(nn.Module):
|
|||||||
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
||||||
else:
|
else:
|
||||||
x_paths = [op(x) for op in self.paths]
|
x_paths = [op(x) for op in self.paths]
|
||||||
|
|
||||||
x = torch.stack(x_paths, dim=1)
|
x = torch.stack(x_paths, dim=1)
|
||||||
x_attn = self.attn(x)
|
x_attn = self.attn(x)
|
||||||
x = x * x_attn
|
x = x * x_attn
|
||||||
|
|
||||||
if self.split_input:
|
|
||||||
B, N, C, H, W = x.shape
|
|
||||||
x = x.reshape(B, N * C, H, W)
|
|
||||||
else:
|
|
||||||
x = torch.sum(x, dim=1)
|
x = torch.sum(x, dim=1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
default_cfg = default_cfgs['skresnet18']
|
default_cfg = default_cfgs['skresnet18']
|
||||||
sk_kwargs = dict(
|
sk_kwargs = dict(
|
||||||
min_attn_channels=16,
|
min_attn_channels=16,
|
||||||
|
attn_reduction=8,
|
||||||
split_input=True
|
split_input=True
|
||||||
)
|
)
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
||||||
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
)
|
)
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
|
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
|
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False
|
||||||
**kwargs)
|
**kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
default_cfg = default_cfgs['skresnet50']
|
default_cfg = default_cfgs['skresnet50']
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||||
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
default_cfg = default_cfgs['skresnet50d']
|
default_cfg = default_cfgs['skresnet50d']
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
|
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
|
||||||
|
zero_init_last_bn=False, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
default_cfg = default_cfgs['skresnext50_32x4d']
|
default_cfg = default_cfgs['skresnext50_32x4d']
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
|
||||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user