mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
ResNet / Res2Net additions:
* ResNet torchscript compat * output_stride arg supported to limit network stride via dilations (support for dilation added to Res2Net) * allow activation layer to be changed via act_layer arg
This commit is contained in:
parent
f96b3e5e92
commit
53001dd292
@ -54,9 +54,8 @@ class Bottle2neck(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
cardinality=1, base_width=26, scale=4, use_se=False,
|
||||
norm_layer=None, dilation=1, previous_dilation=1, **_):
|
||||
act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
|
||||
super(Bottle2neck, self).__init__()
|
||||
assert dilation == 1 and previous_dilation == 1 # FIXME support dilation
|
||||
self.scale = scale
|
||||
self.is_first = stride > 1 or downsample is not None
|
||||
self.num_scales = max(1, scale - 1)
|
||||
@ -71,18 +70,20 @@ class Bottle2neck(nn.Module):
|
||||
bns = []
|
||||
for i in range(self.num_scales):
|
||||
convs.append(nn.Conv2d(
|
||||
width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False))
|
||||
width, width, kernel_size=3, stride=stride, padding=dilation,
|
||||
dilation=dilation, groups=cardinality, bias=False))
|
||||
bns.append(norm_layer(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
if self.is_first:
|
||||
# FIXME this should probably have count_include_pad=False, but hurts original weights
|
||||
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(outplanes)
|
||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.relu = act_layer(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -125,11 +125,12 @@ class SEModule(nn.Module):
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
cardinality=1, base_width=64, use_se=False,
|
||||
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||
super(BasicBlock, self).__init__()
|
||||
|
||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||
@ -141,12 +142,13 @@ class BasicBlock(nn.Module):
|
||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
|
||||
dilation=dilation, bias=False)
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
|
||||
dilation=previous_dilation, bias=False)
|
||||
self.bn2 = norm_layer(outplanes)
|
||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
@ -156,7 +158,7 @@ class BasicBlock(nn.Module):
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.act1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
@ -167,17 +169,18 @@ class BasicBlock(nn.Module):
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
out = self.act2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
cardinality=1, base_width=64, use_se=False,
|
||||
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||
@ -186,14 +189,16 @@ class Bottleneck(nn.Module):
|
||||
|
||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=stride,
|
||||
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(outplanes)
|
||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.act3 = act_layer(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
@ -203,11 +208,11 @@ class Bottleneck(nn.Module):
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.act1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.act2(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
@ -219,7 +224,7 @@ class Bottleneck(nn.Module):
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
out = self.act3(out)
|
||||
|
||||
return out
|
||||
|
||||
@ -284,9 +289,10 @@ class ResNet(nn.Module):
|
||||
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
|
||||
avg_down : bool, default False
|
||||
Whether to use average pooling for projection skip connection between stages/downsample.
|
||||
dilated : bool, default False
|
||||
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
|
||||
typically used in Semantic Segmentation.
|
||||
output_stride : int, default 32
|
||||
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
|
||||
act_layer : class, activation layer
|
||||
norm_layer : class, normalization layer
|
||||
drop_rate : float, default 0.
|
||||
Dropout probability before classifier, for training
|
||||
global_pool : str, default 'avg'
|
||||
@ -294,8 +300,8 @@ class ResNet(nn.Module):
|
||||
"""
|
||||
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
|
||||
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
||||
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
|
||||
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
|
||||
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
|
||||
zero_init_last_bn=True, block_args=None):
|
||||
block_args = block_args or dict()
|
||||
self.num_classes = num_classes
|
||||
@ -305,9 +311,9 @@ class ResNet(nn.Module):
|
||||
self.base_width = base_width
|
||||
self.drop_rate = drop_rate
|
||||
self.expansion = block.expansion
|
||||
self.dilated = dilated
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
# Stem
|
||||
if deep_stem:
|
||||
stem_chs_1 = stem_chs_2 = stem_width
|
||||
if 'tiered' in stem_type:
|
||||
@ -316,25 +322,37 @@ class ResNet(nn.Module):
|
||||
self.conv1 = nn.Sequential(*[
|
||||
nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
|
||||
norm_layer(stem_chs_1),
|
||||
nn.ReLU(inplace=True),
|
||||
act_layer(inplace=True),
|
||||
nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
|
||||
norm_layer(stem_chs_2),
|
||||
nn.ReLU(inplace=True),
|
||||
act_layer(inplace=True),
|
||||
nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
stride_3_4 = 1 if self.dilated else 2
|
||||
dilation_3 = 2 if self.dilated else 1
|
||||
dilation_4 = 4 if self.dilated else 1
|
||||
largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
|
||||
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
|
||||
|
||||
# Feature Blocks
|
||||
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
|
||||
if output_stride == 16:
|
||||
strides[3] = 1
|
||||
dilations[3] = 2
|
||||
elif output_stride == 8:
|
||||
strides[2:4] = [1, 1]
|
||||
dilations[2:4] = [2, 4]
|
||||
else:
|
||||
assert output_stride == 32
|
||||
llargs = list(zip(channels, layers, strides, dilations))
|
||||
lkwargs = dict(
|
||||
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
|
||||
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
|
||||
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
|
||||
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
|
||||
self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
|
||||
self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
|
||||
|
||||
# Head (Pooling and Classifier)
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.num_features = 512 * block.expansion
|
||||
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
@ -352,7 +370,8 @@ class ResNet(nn.Module):
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
|
||||
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d, **kwargs):
|
||||
use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
|
||||
norm_layer = kwargs.get('norm_layer')
|
||||
downsample = None
|
||||
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
@ -370,15 +389,15 @@ class ResNet(nn.Module):
|
||||
downsample = nn.Sequential(*downsample_layers)
|
||||
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
bargs = dict(
|
||||
bkwargs = dict(
|
||||
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
||||
use_se=use_se, norm_layer=norm_layer, **kwargs)
|
||||
use_se=use_se, **kwargs)
|
||||
layers = [block(
|
||||
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bargs)]
|
||||
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(
|
||||
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bargs))
|
||||
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
@ -394,7 +413,7 @@ class ResNet(nn.Module):
|
||||
def forward_features(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.act1(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user