diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 8aa2facf..e857e7a9 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -46,6 +46,69 @@ from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this


+class PreActBasic(nn.Module):
+    """ Pre-activation basic block (not in typical 'v2' implementations)
+    """
+
+    def __init__(
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=1.0,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
+        super().__init__()
+        first_dilation = first_dilation or dilation
+        conv_layer = conv_layer or StdConv2d
+        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
+        out_chs = out_chs or in_chs
+        mid_chs = make_divisible(out_chs * bottle_ratio)
+
+        if proj_layer is not None and (stride != 1 or first_dilation != dilation or in_chs != out_chs):
+            self.downsample = proj_layer(
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.downsample = None
+
+        self.norm1 = norm_layer(in_chs)
+        self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+        self.norm2 = norm_layer(mid_chs)
+        self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+    def zero_init_last(self):
+        # the basic block has no conv3; the last conv in the residual branch is conv2
+        nn.init.zeros_(self.conv2.weight)
+
+    def forward(self, x):
+        x_preact = self.norm1(x)
+
+        # shortcut branch
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x_preact)
+
+        # residual branch
+        x = self.conv1(x_preact)
+        x = self.conv2(self.norm2(x))
+        x = self.drop_path(x)
+        return x + shortcut
+
+
 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
@@ -80,8 +143,15 @@ class PreActBottleneck(nn.Module):

         if proj_layer is not None:
             self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
         else:
             self.downsample = None

@@ -140,8 +210,14 @@ class Bottleneck(nn.Module):

         if proj_layer is not None:
             self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                preact=False,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
         else:
             self.downsample = None

@@ -339,6 +415,8 @@ class ResNetV2(nn.Module):
             stem_type='',
             avg_down=False,
             preact=True,
+            basic=False,
+            bottle_ratio=0.25,
             act_layer=nn.ReLU,
             norm_layer=partial(GroupNormAct, num_groups=32),
             conv_layer=StdConv2d,
@@ -390,7 +468,11 @@ class ResNetV2(nn.Module):
         curr_stride = 4
         dilation = 1
         block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
-        block_fn = PreActBottleneck if preact else Bottleneck
+        if preact:
+            block_fn = PreActBasic if basic else PreActBottleneck
+        else:
+            assert not basic
+            block_fn = Bottleneck
         self.stages = nn.Sequential()
         for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
             out_chs = make_divisible(c * wf)
@@ -404,6 +486,7 @@ class ResNetV2(nn.Module):
                 stride=stride,
                 dilation=dilation,
                 depth=d,
+                bottle_ratio=bottle_ratio,
                 avg_down=avg_down,
                 act_layer=act_layer,
                 conv_layer=conv_layer,
@@ -613,6 +696,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         num_classes=21843,
         custom_load=True),
+    'resnetv2_18.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_18d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
+    'resnetv2_34.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_34d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
     'resnetv2_50.a1h_in1k': _cfg(
         hf_hub_id='timm/',
         interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
@@ -679,6 +770,42 @@ def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2:
         'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)


+@register_model
+def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(2, 2, 2, 2), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_18', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(2, 2, 2, 2), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_18d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_34', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_34d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
 @register_model
 def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2:
     model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
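A minimal smoke test for the new entrypoints, as a sketch (assumes timm is installed with this patch applied; `num_classes=10` is an arbitrary example value). The new default cfgs are tagged `.untrained`, so there are no pretrained weights to load and `pretrained` stays False:

    import torch
    import timm

    # 'resnetv2_18' is registered above via @register_model; build it untrained
    model = timm.create_model('resnetv2_18', num_classes=10)
    model.eval()

    # sanity-check a forward pass at the default 224x224 input size
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 10])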