mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
A SelectiveKernelBasicBlock for more experiments
This commit is contained in:
parent
ad087b4b17
commit
a93bae6dc5
@ -265,7 +265,7 @@ class SelectiveKernelAttn(nn.Module):
|
||||
class SelectiveKernelConv(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16,
|
||||
min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
|
||||
min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||
super(SelectiveKernelConv, self).__init__()
|
||||
if not isinstance(kernel_size, list):
|
||||
@ -316,6 +316,53 @@ class SelectiveKernelConv(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SelectiveKernelBasicBlock(nn.Module):
    """ResNet BasicBlock variant that uses a SelectiveKernelConv as its second conv.

    Layout mirrors the standard BasicBlock: 3x3 conv -> norm -> act, then a
    selective-kernel conv -> norm, optional SE attention, residual add, final act.
    Drop-in compatible with ResNet's block interface (expansion == 1).

    Args:
        inplanes: number of input channels.
        planes: base number of output channels (output is planes * expansion).
        stride: stride applied in the first 3x3 conv.
        downsample: optional module applied to the identity path to match shapes.
        cardinality: must be 1 (BasicBlock has no grouped-conv variant).
        base_width: must be 64 (BasicBlock width is fixed).
        use_se: if True, apply an SEModule after the second conv.
        reduce_first: divisor for the first conv's output channels.
        dilation: dilation (and padding) of the first 3x3 conv.
        previous_dilation: dilation passed to the selective-kernel conv.
        act_layer: activation layer class, instantiated with inplace=True.
        norm_layer: normalization layer class.
    """
    expansion = 1  # BasicBlock does not expand channels

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=64, use_se=False,
                 reduce_first=1, dilation=1, previous_dilation=1,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(SelectiveKernelBasicBlock, self).__init__()

        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        # FIX: corrected typo in the assertion message ("doest not" -> "does not")
        assert base_width == 64, 'BasicBlock does not support changing base width'
        first_planes = planes // reduce_first
        outplanes = planes * self.expansion

        self.conv1 = nn.Conv2d(
            inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
            dilation=dilation, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)
        # Second conv is the selective-kernel conv (multi-branch kernels + attention).
        self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation)
        self.bn2 = norm_layer(outplanes)
        self.se = SEModule(outplanes, planes // 4) if use_se else None
        self.act2 = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.se is not None:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.act2(out)

        return out
|
||||
|
||||
|
||||
class SelectiveKernelBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
@ -581,6 +628,18 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a Selective Kernel ResNet-18 model.

    Same stage layout as ResNet-18 ([2, 2, 2, 2]) but built from
    SelectiveKernelBasicBlock instead of the standard BasicBlock.
    Reuses the plain 'resnet18' default config for input/normalization
    settings.

    NOTE(review): ``pretrained=True`` loads the plain ResNet-18 checkpoint;
    the selective-kernel conv parameters presumably will not match that
    state dict — verify before relying on pretrained weights.
    """
    default_cfg = default_cfgs['resnet18']
    model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-34 model.
|
||||
|
Loading…
x
Reference in New Issue
Block a user