https://github.com/huggingface/pytorch-image-models.git
Merge pull request #2298 from huggingface/preact_resnet18

Add resnet18/18d pre-act model configs for potential training.

Commit c3052fa19e
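The PR wires a pre-activation basic block (two 3x3 convs, ResNet-18/34 style) into timm's ResNetV2, which previously only built bottleneck blocks. After this commit the four new entrypoints are reachable through the normal registry API; a minimal sketch, assuming a timm build that includes this change:

import timm

# The new configs ship untrained ('.untrained' cfgs), so no pretrained weights are expected yet.
print(timm.list_models('resnetv2_18*') + timm.list_models('resnetv2_34*'))
model = timm.create_model('resnetv2_18', pretrained=False)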
@@ -46,6 +46,69 @@ from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 __all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this
 
 
+class PreActBasic(nn.Module):
+    """ Pre-activation basic block (not in typical 'v2' implementations)
+    """
+
+    def __init__(
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=1.0,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
+        super().__init__()
+        first_dilation = first_dilation or dilation
+        conv_layer = conv_layer or StdConv2d
+        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
+        out_chs = out_chs or in_chs
+        mid_chs = make_divisible(out_chs * bottle_ratio)
+
+        if proj_layer is not None and (stride != 1 or first_dilation != dilation or in_chs != out_chs):
+            self.downsample = proj_layer(
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.downsample = None
+
+        self.norm1 = norm_layer(in_chs)
+        self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+        self.norm2 = norm_layer(mid_chs)
+        self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+    def zero_init_last(self):
+        nn.init.zeros_(self.conv2.weight)  # conv2 is the final conv in this block (it has no conv3)
+
+    def forward(self, x):
+        x_preact = self.norm1(x)
+
+        # shortcut branch
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x_preact)
+
+        # residual branch
+        x = self.conv1(x_preact)
+        x = self.conv2(self.norm2(x))
+        x = self.drop_path(x)
+        return x + shortcut
+
+
 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
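As a quick sanity check on the block above: with proj_layer left at None there is no downsample path, so in_chs must equal out_chs and stride must stay 1 for the identity shortcut to line up. A small smoke test, assuming this commit is installed (PreActBasic is importable from the module even though it is not in __all__):

import torch
from timm.models.resnetv2 import PreActBasic

block = PreActBasic(in_chs=64, out_chs=64)  # defaults: StdConv2d + GroupNormAct(num_groups=32)
y = block(torch.randn(2, 64, 56, 56))
print(y.shape)  # torch.Size([2, 64, 56, 56]) -- shape preserved, identity shortcut applies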
@@ -80,8 +143,15 @@ class PreActBottleneck(nn.Module):
 
         if proj_layer is not None:
             self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
         else:
             self.downsample = None
 
@@ -140,8 +210,14 @@ class Bottleneck(nn.Module):
 
         if proj_layer is not None:
             self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                preact=False,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
         else:
             self.downsample = None
 
@@ -339,6 +415,8 @@ class ResNetV2(nn.Module):
             stem_type='',
             avg_down=False,
             preact=True,
+            basic=False,
+            bottle_ratio=0.25,
             act_layer=nn.ReLU,
             norm_layer=partial(GroupNormAct, num_groups=32),
             conv_layer=StdConv2d,
@@ -390,7 +468,11 @@ class ResNetV2(nn.Module):
         curr_stride = 4
         dilation = 1
         block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
-        block_fn = PreActBottleneck if preact else Bottleneck
+        if preact:
+            block_fn = PreActBasic if basic else PreActBottleneck
+        else:
+            assert not basic
+            block_fn = Bottleneck
         self.stages = nn.Sequential()
         for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
             out_chs = make_divisible(c * wf)
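The selection logic above: preact together with basic picks the new PreActBasic, preact alone keeps PreActBottleneck, and the non-preact path still supports only bottlenecks (hence assert not basic). A sketch constructing the model class directly rather than through the registry, mirroring the arguments the resnetv2_18 entrypoint passes:

from timm.models.resnetv2 import ResNetV2

# bottle_ratio=1.0 makes mid_chs == out_chs inside each basic block; preact defaults to True
model = ResNetV2(layers=(2, 2, 2, 2), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0)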
@@ -404,6 +486,7 @@ class ResNetV2(nn.Module):
                 stride=stride,
                 dilation=dilation,
                 depth=d,
+                bottle_ratio=bottle_ratio,
                 avg_down=avg_down,
                 act_layer=act_layer,
                 conv_layer=conv_layer,
@@ -613,6 +696,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         num_classes=21843, custom_load=True),
 
+    'resnetv2_18.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_18d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
+    'resnetv2_34.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_34d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
     'resnetv2_50.a1h_in1k': _cfg(
         hf_hub_id='timm/',
         interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
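Although untrained, the new default cfgs still carry preprocessing hints: bicubic interpolation, 0.95 crop, and stem.conv1 flagged as the first conv for the deep-stem 'd' variants. One way to consume them, sketched with timm's data-config helper (present in recent timm releases):

import timm
from timm.data import resolve_model_data_config

model = timm.create_model('resnetv2_18d', pretrained=False)
print(resolve_model_data_config(model))  # expect interpolation='bicubic', crop_pct=0.95, ...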
@@ -679,6 +770,42 @@ def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2:
         'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
 
 
+@register_model
+def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_18', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_18d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_34', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_34d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
 @register_model
 def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2:
     model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
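Note that these entrypoints swap the BiT-style defaults (StdConv2d + GroupNormAct) for create_conv2d + BatchNormAct2d, yielding a conventional batch-norm pre-act ResNet, and the 'd' variants add the deep three-conv stem with avg-pool downsampling. A rough size comparison, as a sketch:

import timm

for name in ('resnetv2_18', 'resnetv2_18d', 'resnetv2_34', 'resnetv2_34d'):
    m = timm.create_model(name, pretrained=False)
    print(name, f'{sum(p.numel() for p in m.parameters()) / 1e6:.1f}M params')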