Merge pull request #2298 from huggingface/preact_resnet18

Add resnet18/18d pre-act model configs for potential training.
This commit is contained in:
Ross Wightman 2024-10-14 19:39:04 -07:00 committed by GitHub
commit c3052fa19e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,69 @@ from ._registry import generate_default_cfgs, register_model, register_model_dep
__all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this
class PreActBasic(nn.Module):
""" Pre-activation basic block (not in typical 'v2' implementations)
"""
def __init__(
self,
in_chs,
out_chs=None,
bottle_ratio=1.0,
stride=1,
dilation=1,
first_dilation=None,
groups=1,
act_layer=None,
conv_layer=None,
norm_layer=None,
proj_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
conv_layer = conv_layer or StdConv2d
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
out_chs = out_chs or in_chs
mid_chs = make_divisible(out_chs * bottle_ratio)
if proj_layer is not None and (stride != 1 or first_dilation != dilation or in_chs != out_chs):
self.downsample = proj_layer(
in_chs,
out_chs,
stride=stride,
dilation=dilation,
first_dilation=first_dilation,
preact=True,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
else:
self.downsample = None
self.norm1 = norm_layer(in_chs)
self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
self.norm2 = norm_layer(mid_chs)
self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
nn.init.zeros_(self.conv3.weight)
def forward(self, x):
x_preact = self.norm1(x)
# shortcut branch
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x_preact)
# residual branch
x = self.conv1(x_preact)
x = self.conv2(self.norm2(x))
x = self.drop_path(x)
return x + shortcut
class PreActBottleneck(nn.Module):
"""Pre-activation (v2) bottleneck block.
@ -80,8 +143,15 @@ class PreActBottleneck(nn.Module):
if proj_layer is not None:
self.downsample = proj_layer(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
conv_layer=conv_layer, norm_layer=norm_layer)
in_chs,
out_chs,
stride=stride,
dilation=dilation,
first_dilation=first_dilation,
preact=True,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
else:
self.downsample = None
@ -140,8 +210,14 @@ class Bottleneck(nn.Module):
if proj_layer is not None:
self.downsample = proj_layer(
in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
conv_layer=conv_layer, norm_layer=norm_layer)
in_chs,
out_chs,
stride=stride,
dilation=dilation,
preact=False,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
else:
self.downsample = None
@ -339,6 +415,8 @@ class ResNetV2(nn.Module):
stem_type='',
avg_down=False,
preact=True,
basic=False,
bottle_ratio=0.25,
act_layer=nn.ReLU,
norm_layer=partial(GroupNormAct, num_groups=32),
conv_layer=StdConv2d,
@ -390,7 +468,11 @@ class ResNetV2(nn.Module):
curr_stride = 4
dilation = 1
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
block_fn = PreActBottleneck if preact else Bottleneck
if preact:
block_fn = PreActBasic if basic else PreActBottleneck
else:
assert not basic
block_fn = Bottleneck
self.stages = nn.Sequential()
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
out_chs = make_divisible(c * wf)
@ -404,6 +486,7 @@ class ResNetV2(nn.Module):
stride=stride,
dilation=dilation,
depth=d,
bottle_ratio=bottle_ratio,
avg_down=avg_down,
act_layer=act_layer,
conv_layer=conv_layer,
@ -613,6 +696,14 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
num_classes=21843, custom_load=True),
'resnetv2_18.untrained': _cfg(
interpolation='bicubic', crop_pct=0.95),
'resnetv2_18d.untrained': _cfg(
interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
'resnetv2_34.untrained': _cfg(
interpolation='bicubic', crop_pct=0.95),
'resnetv2_34d.untrained': _cfg(
interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
'resnetv2_50.a1h_in1k': _cfg(
hf_hub_id='timm/',
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
@ -679,6 +770,42 @@ def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2:
'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
@register_model
def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2:
model_args = dict(
layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
conv_layer=create_conv2d, norm_layer=BatchNormAct2d
)
return _create_resnetv2('resnetv2_18', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2:
model_args = dict(
layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
)
return _create_resnetv2('resnetv2_18d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2:
model_args = dict(
layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
conv_layer=create_conv2d, norm_layer=BatchNormAct2d
)
return _create_resnetv2('resnetv2_34', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2:
model_args = dict(
layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
)
return _create_resnetv2('resnetv2_34d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2:
model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)