Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Significant ResNet refactor:

* stage creation + make_layer moved to separate fn with more sensible dilation/output_stride calc
* drop path rate decay easy to impl with refactored block creation loops
* fix dilation + blur pool combo
parent a66df5fb91
commit f122f0274b
@@ -156,7 +156,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
     return output


-class DropPath(nn.ModuleDict):
+class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
     def __init__(self, drop_prob=None):
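The base-class fix matters because DropPath holds no child modules, so subclassing nn.ModuleDict was presumably a typo. For reference, a minimal sketch of the per-sample stochastic depth mechanism this module wraps, reconstructed under the assumption that drop_path() is the helper truncated in the hunk above (not part of this diff):

    import torch
    import torch.nn as nn

    def drop_path(x, drop_prob: float = 0., training: bool = False):
        # Identity at inference time or when the rate is zero.
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask = keep_prob + torch.rand((x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
        mask.floor_()  # binarize
        # Rescale surviving samples so the expected activation is unchanged.
        return x.div(keep_prob) * mask

    class DropPath(nn.Module):
        def __init__(self, drop_prob=None):
            super(DropPath, self).__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            return drop_path(x, self.drop_prob, self.training)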
@@ -205,14 +205,14 @@ class BasicBlock(nn.Module):
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

         self.conv1 = nn.Conv2d(
             inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None

         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -272,7 +272,7 @@ class Bottleneck(nn.Module):
         first_planes = width // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

         self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
         self.bn1 = norm_layer(first_planes)
@@ -283,7 +283,7 @@ class Bottleneck(nn.Module):
             padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=width, stride=stride) if use_aa else None

         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
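BasicBlock and Bottleneck get the same two-part fix: use_aa now also triggers when first_dilation != dilation, and the aa_layer is handed the block's stride instead of being gated on stride == 2, so a dilated (non-strided) stage no longer performs blur-pool downsampling. A minimal sketch of a blur-pool style aa_layer with the channels/stride signature these call sites assume (an illustrative stand-in, not timm's actual implementation):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BlurPool2d(nn.Module):
        # Anti-aliased downsampling: depthwise low-pass blur, then strided subsample.
        def __init__(self, channels, stride=2):
            super().__init__()
            self.channels = channels
            self.stride = stride
            coeffs = torch.tensor([1., 2., 1.])  # binomial filter taps
            filt = coeffs[:, None] * coeffs[None, :]
            filt = filt / filt.sum()
            # One copy of the 3x3 filter per channel for a depthwise conv.
            self.register_buffer('filt', filt[None, None, :, :].repeat(channels, 1, 1, 1))

        def forward(self, x):
            x = F.pad(x, (1, 1, 1, 1), mode='reflect')
            return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)

With stride=1 such a layer is just a blur, which is why passing the real stride through stays safe even when the new use_aa condition fires on first_dilation != dilation alone.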
@@ -336,14 +336,6 @@ class Bottleneck(nn.Module):
         return x


-def setup_drop_block(drop_block_rate=0.):
-    return [
-        None,
-        None,
-        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-
-
 def downsample_conv(
         in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
@@ -375,6 +367,57 @@ def downsample_avg(
     ])


+def drop_blocks(drop_block_rate=0.):
+    return [
+        None, None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
+def make_blocks(
+        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
+        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+    stages = []
+    feature_info = []
+    net_num_blocks = sum(block_repeats)
+    net_block_idx = 0
+    net_stride = 4
+    dilation = prev_dilation = 1
+    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
+        stride = 1 if stage_idx == 0 else 2
+        if net_stride >= output_stride:
+            dilation *= stride
+            stride = 1
+        else:
+            net_stride *= stride
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block_fn.expansion:
+            down_kwargs = dict(
+                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
+                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+        blocks = []
+        for block_idx in range(num_blocks):
+            downsample = downsample if block_idx == 0 else None
+            stride = stride if block_idx == 0 else 1
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            blocks.append(block_fn(
+                inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+                drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
+            prev_dilation = dilation
+            inplanes = planes * block_fn.expansion
+            net_block_idx += 1
+
+        stages.append((stage_name, nn.Sequential(*blocks)))
+        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+    return stages, feature_info
+
+
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net
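To make the new dilation/output_stride calculation concrete, here is a standalone sketch of the per-stage (stride, dilation) schedule the loop above computes; the printed values are worked out by hand from the code, not captured from a run:

    def stage_strides_dilations(output_stride=32, num_stages=4):
        net_stride, dilation = 4, 1  # stem has already reduced resolution by 4
        schedule = []
        for stage_idx in range(num_stages):
            stride = 1 if stage_idx == 0 else 2
            if net_stride >= output_stride:
                dilation *= stride  # trade stride for dilation past the target
                stride = 1
            else:
                net_stride *= stride
            schedule.append((stride, dilation))
        return schedule

    print(stage_strides_dilations(32))  # [(1, 1), (2, 1), (2, 1), (2, 1)]
    print(stage_strides_dilations(16))  # [(1, 1), (2, 1), (2, 1), (1, 2)]
    print(stage_strides_dilations(8))   # [(1, 1), (2, 1), (1, 2), (1, 4)]

The stochastic depth rule in the same loop decays linearly over the whole network: for a ResNet-50 (block_repeats summing to 16) with drop_path_rate=0.1, block_dpr runs from 0.0 on the first block up to 0.1 on the last.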
@@ -448,21 +491,18 @@ class ResNet(nn.Module):

     def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
-                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
+                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                  drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
         self.num_classes = num_classes
-        deep_stem = 'deep' in stem_type
-        self.inplanes = stem_width * 2 if deep_stem else 64
-        self.cardinality = cardinality
-        self.base_width = base_width
         self.drop_rate = drop_rate
-        self.expansion = block.expansion
         super(ResNet, self).__init__()

         # Stem
+        deep_stem = 'deep' in stem_type
+        inplanes = stem_width * 2 if deep_stem else 64
         if deep_stem:
             stem_chs_1 = stem_chs_2 = stem_width
             if 'tiered' in stem_type:
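The signature change only reorders keyword arguments (cardinality, base_width, and block.expansion stop being instance attributes since make_blocks receives them directly), so keyword-based callers are unaffected. A hypothetical construction exercising the reordered args together with the refactored stages (argument values are illustrative, not from the commit):

    # ResNet-50-D style: deep stem, avg-pool downsample, dilated to output_stride 8.
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep',
        avg_down=True, output_stride=8)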
@@ -475,43 +515,31 @@ class ResNet(nn.Module):
                 nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
                 norm_layer(stem_chs_2),
                 act_layer(inplace=True),
-                nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
+                nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
         else:
-            self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-            self.bn1 = norm_layer(self.inplanes)
+            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+            self.bn1 = norm_layer(inplanes)
         self.act1 = act_layer(inplace=True)
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
+        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]

         # Stem Pooling
         if aa_layer is not None:
             self.maxpool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=self.inplanes, stride=2)
-            ])
+                aa_layer(channels=inplanes, stride=2)])
         else:
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

         # Feature Blocks
         channels = [64, 128, 256, 512]
-        dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = setup_drop_block(drop_block_rate)
-        layer_kwargs = dict(
-            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        total_stride = 4
-        dilation = 1
-        for i in range(4):
-            layer_name = f'layer{i + 1}'
-            stride = 2 if i > 0 else 1
-            if total_stride >= output_stride:
-                dilation *= stride
-                stride = 1
-            else:
-                total_stride *= stride
-            self.add_module(layer_name, self._make_layer(
-                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
-            self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
+        stage_modules, stage_feature_info = make_blocks(
+            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
+            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
+            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+        for stage in stage_modules:
+            self.add_module(*stage)  # layer1, layer2, etc
+        self.feature_info.extend(stage_feature_info)

         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
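Given the make_blocks wiring above, self.feature_info for a default ResNet-50 (Bottleneck, layers [3, 4, 6, 3], output_stride=32) should come out as below; these entries are derived by hand from the stem and stage code, not captured from a run:

    [dict(num_chs=64,   reduction=2,  module='act1'),
     dict(num_chs=256,  reduction=4,  module='layer1'),
     dict(num_chs=512,  reduction=8,  module='layer2'),
     dict(num_chs=1024, reduction=16, module='layer3'),
     dict(num_chs=2048, reduction=32, module='layer4')]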
@@ -529,25 +557,6 @@ class ResNet(nn.Module):
             if hasattr(m, 'zero_init_last_bn'):
                 m.zero_init_last_bn()

-    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    avg_down=False, down_kernel_size=1, **kwargs):
-        downsample = None
-        first_dilation = 1 if dilation in (1, 2) else 2
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_args = dict(
-                in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
-            downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
-
-        block_kwargs = dict(
-            cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
-        self.inplanes = planes * block.expansion
-        layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
-
-        return nn.Sequential(*layers)
-
     def get_classifier(self):
         return self.fc