Fix BatchNorm for ResNetV2 non GN models, add more ResNetV2 model defs for future experimentation, fix zero_init of last residual for pre-act.
parent
02aaa785b9
commit
e8045e712f
|
@ -38,7 +38,8 @@ from functools import partial
|
|||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
|
||||
from .registry import register_model
|
||||
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
|
||||
from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
|
||||
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -107,6 +108,16 @@ default_cfgs = {
|
|||
interpolation='bicubic'),
|
||||
'resnetv2_50d': _cfg(
|
||||
interpolation='bicubic', first_conv='stem.conv1'),
|
||||
'resnetv2_50t': _cfg(
|
||||
interpolation='bicubic', first_conv='stem.conv1'),
|
||||
'resnetv2_101': _cfg(
|
||||
interpolation='bicubic'),
|
||||
'resnetv2_101d': _cfg(
|
||||
interpolation='bicubic', first_conv='stem.conv1'),
|
||||
'resnetv2_152': _cfg(
|
||||
interpolation='bicubic'),
|
||||
'resnetv2_152d': _cfg(
|
||||
interpolation='bicubic', first_conv='stem.conv1'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -152,8 +163,8 @@ class PreActBottleneck(nn.Module):
|
|||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
def zero_init_last_bn(self):
|
||||
nn.init.zeros_(self.norm3.weight)
|
||||
def zero_init_last(self):
|
||||
nn.init.zeros_(self.conv3.weight)
|
||||
|
||||
def forward(self, x):
|
||||
x_preact = self.norm1(x)
|
||||
|
@ -201,7 +212,7 @@ class Bottleneck(nn.Module):
|
|||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.act3 = act_layer(inplace=True)
|
||||
|
||||
def zero_init_last_bn(self):
|
||||
def zero_init_last(self):
|
||||
nn.init.zeros_(self.norm3.weight)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -284,17 +295,20 @@ def create_resnetv2_stem(
|
|||
in_chs, out_chs=64, stem_type='', preact=True,
|
||||
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
||||
stem = OrderedDict()
|
||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
|
||||
|
||||
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
||||
if 'deep' in stem_type:
|
||||
if any([s in stem_type for s in ('deep', 'tiered')]):
|
||||
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
mid_chs = out_chs // 2
|
||||
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
||||
stem['norm1'] = norm_layer(mid_chs)
|
||||
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
||||
stem['norm2'] = norm_layer(mid_chs)
|
||||
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
||||
if 'tiered' in stem_type:
|
||||
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
|
||||
else:
|
||||
stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets
|
||||
stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
|
||||
stem['norm1'] = norm_layer(stem_chs[0])
|
||||
stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
|
||||
stem['norm2'] = norm_layer(stem_chs[1])
|
||||
stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
|
||||
if not preact:
|
||||
stem['norm3'] = norm_layer(out_chs)
|
||||
else:
|
||||
|
@ -326,7 +340,7 @@ class ResNetV2(nn.Module):
|
|||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
||||
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
|
||||
drop_rate=0., drop_path_rate=0., zero_init_last=True):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
|
@ -364,10 +378,10 @@ class ResNetV2(nn.Module):
|
|||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
self.init_weights(zero_init_last_bn=zero_init_last_bn)
|
||||
self.init_weights(zero_init_last=zero_init_last)
|
||||
|
||||
def init_weights(self, zero_init_last_bn=True):
|
||||
named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
|
||||
def init_weights(self, zero_init_last=True):
|
||||
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
||||
|
||||
@torch.jit.ignore()
|
||||
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
||||
|
@ -393,7 +407,7 @@ class ResNetV2(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
|
||||
def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
|
||||
if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
|
|||
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
|
||||
module.zero_init_last_bn()
|
||||
elif zero_init_last and hasattr(module, 'zero_init_last'):
|
||||
module.zero_init_last()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
|
|||
def resnetv2_50(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50d(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50d', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
|
||||
stem_type='deep', avg_down=True, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50t(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50t', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
|
||||
stem_type='tiered', avg_down=True, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101d(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101d', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
|
||||
stem_type='deep', avg_down=True, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152d(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152d', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
|
||||
stem_type='deep', avg_down=True, **kwargs)
|
||||
|
||||
|
||||
# @register_model
|
||||
# def resnetv2_50ebd(pretrained=False, **kwargs):
|
||||
# # FIXME for testing w/ TPU + PyTorch XLA
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50d', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
|
||||
# stem_type='deep', avg_down=True, **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_50esd(pretrained=False, **kwargs):
|
||||
# # FIXME for testing w/ TPU + PyTorch XLA
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50d', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
|
||||
# stem_type='deep', avg_down=True, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue