Update resnetv2.py for multi-weight and HF hub weights

pull/1736/head
Ross Wightman 2023-03-22 15:38:04 -07:00
parent b3e816d6d7
commit da6bdd4560
1 changed file with 144 additions and 199 deletions

@@ -36,114 +36,15 @@ import torch
 import torch.nn as nn
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
-    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
+from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
+    DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 __all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-        **kwargs
-    }
-default_cfgs = {
-    # pretrained on imagenet21k, finetuned on imagenet1k
-    'resnetv2_50x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_50x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x2_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x4_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
-        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True), # only one at 480x480?
-    # trained on imagenet-21k
-    'resnetv2_50x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_50x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x2_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x4_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_50x1_bit_distilled': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher_384': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
-    'resnetv2_50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_50d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50t': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_101d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_152': _cfg(
-        interpolation='bicubic'),
-    'resnetv2_152d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50d_gn': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_gn_ah-c415c11a.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_evob': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50d_evos': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_evos_ah-7c4dd548.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_frn': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-}
-def make_div(v, divisor=8):
-    min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
@@ -174,7 +75,7 @@ class PreActBottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)
         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -234,7 +135,7 @@ class Bottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)
         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -472,7 +373,7 @@ class ResNetV2(nn.Module):
         act_layer = get_act_layer(act_layer)
         self.feature_info = []
-        stem_chs = make_div(stem_chs * wf)
+        stem_chs = make_divisible(stem_chs * wf)
         self.stem = create_resnetv2_stem(
             in_chans,
             stem_chs,
@@ -491,7 +392,7 @@ class ResNetV2(nn.Module):
         block_fn = PreActBottleneck if preact else Bottleneck
         self.stages = nn.Sequential()
         for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
-            out_chs = make_div(c * wf)
+            out_chs = make_divisible(c * wf)
             stride = 1 if stage_idx == 0 else 2
             if curr_stride >= output_stride:
                 dilation *= stride
@@ -517,7 +418,12 @@ class ResNetV2(nn.Module):
         self.num_features = prev_chs
         self.norm = norm_layer(self.num_features) if preact else nn.Identity()
         self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=self.drop_rate,
+            use_conv=True,
+        )
         self.init_weights(zero_init_last=zero_init_last)
         self.grad_checkpointing = False
@@ -551,8 +457,7 @@ class ResNetV2(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+        self.head.reset(num_classes, global_pool)
     def forward_features(self, x):
         x = self.stem(x)
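
reset_classifier() now delegates to the existing ClassifierHead via head.reset() rather than constructing a replacement head. A short usage sketch, assuming standard timm classifier-reset semantics (worth verifying against the installed release):

import timm

model = timm.create_model('resnetv2_50', pretrained=False)
model.reset_classifier(10)        # swap in a 10-class classifier
model.reset_classifier(0, 'avg')  # num_classes=0 -> forward() returns pooled features
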
@@ -636,112 +541,141 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
 def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
     return _create_resnetv2(
-        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
+        variant,
+        pretrained=pretrained,
+        stem_type='fixed',
+        conv_layer=partial(StdConv2d, eps=1e-8),
+        **kwargs,
+    )
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+default_cfgs = generate_default_cfgs({
+    # Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+    'resnetv2_50x1_bit.goog_distilled_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
+    # pretrained on imagenet21k, finetuned on imagenet1k
+    'resnetv2_50x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True), # only one at 480x480?
+    # trained on imagenet-21k
+    'resnetv2_50x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_50.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50t.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_101.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_101d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_152.untrained': _cfg(
+        interpolation='bicubic'),
+    'resnetv2_152d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50d_gn.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_evos.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_frn.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+})
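
Each key above follows the new architecture.tag multi-weight scheme, and hf_hub_id='timm/' resolves to the matching timm/<model_name> repository on the Hugging Face Hub ('.untrained' tags carry no hosted weights). A hedged sketch of how callers pick among the tags, assuming the 0.9-era timm API:

import timm

# a full name selects one specific pretrained cfg...
model = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=True)

# ...a bare architecture name should fall back to its default (first) tag
model = timm.create_model('resnetv2_101', pretrained=True)  # -> resnetv2_101.a1h_in1k

# enumerate registered weight tags matching a pattern
print(timm.list_pretrained('resnetv2_*bit*'))
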
 @register_model
-def resnetv2_50x1_bitm(pretrained=False, **kwargs):
+def resnetv2_50x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+        'resnetv2_50x1_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
 @register_model
-def resnetv2_50x3_bitm(pretrained=False, **kwargs):
+def resnetv2_50x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
+        'resnetv2_50x3_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
 @register_model
-def resnetv2_101x1_bitm(pretrained=False, **kwargs):
+def resnetv2_101x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
+        'resnetv2_101x1_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
 @register_model
-def resnetv2_101x3_bitm(pretrained=False, **kwargs):
+def resnetv2_101x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+        'resnetv2_101x3_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
 @register_model
-def resnetv2_152x2_bitm(pretrained=False, **kwargs):
+def resnetv2_152x2_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+        'resnetv2_152x2_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
 @register_model
-def resnetv2_152x4_bitm(pretrained=False, **kwargs):
+def resnetv2_152x4_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
-@register_model
-def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-@register_model
-def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=3, **kwargs)
-@register_model
-def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=1, **kwargs)
-@register_model
-def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=3, **kwargs)
-@register_model
-def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-@register_model
-def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=4, **kwargs)
-@register_model
-def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
-    """ ResNetV2-50x1-BiT Distilled
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-@register_model
-def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
-    """ ResNetV2-152x2-BiT Teacher
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-@register_model
-def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
-    """ ResNetV2-152x2-BiT Teacher @ 384x384
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+        'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
 @register_model
@@ -804,14 +738,6 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
     return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
-def resnetv2_50d_evob(pretrained=False, **kwargs):
-    model_args = dict(
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
-        stem_type='deep', avg_down=True, zero_init_last=True)
-    return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
-@register_model
 def resnetv2_50d_evos(pretrained=False, **kwargs):
     model_args = dict(
@@ -826,3 +752,22 @@ def resnetv2_50d_frn(pretrained=False, **kwargs):
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
         stem_type='deep', avg_down=True)
     return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))
+register_model_deprecations(__name__, {
+    'resnetv2_50x1_bitm': 'resnetv2_50x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x3_bitm': 'resnetv2_50x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x1_bitm': 'resnetv2_101x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x3_bitm': 'resnetv2_101x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x2_bitm': 'resnetv2_152x2_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x4_bitm': 'resnetv2_152x4_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x1_bitm_in21k': 'resnetv2_50x1_bit.goog_in21k',
+    'resnetv2_50x3_bitm_in21k': 'resnetv2_50x3_bit.goog_in21k',
+    'resnetv2_101x1_bitm_in21k': 'resnetv2_101x1_bit.goog_in21k',
+    'resnetv2_101x3_bitm_in21k': 'resnetv2_101x3_bit.goog_in21k',
+    'resnetv2_152x2_bitm_in21k': 'resnetv2_152x2_bit.goog_in21k',
+    'resnetv2_152x4_bitm_in21k': 'resnetv2_152x4_bit.goog_in21k',
+    'resnetv2_50x1_bit_distilled': 'resnetv2_50x1_bit.goog_distilled_in1k',
+    'resnetv2_152x2_bit_teacher': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k',
+    'resnetv2_152x2_bit_teacher_384': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384',
+})
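
With the deprecation map registered, the old entrypoint names keep working: requesting a *_bitm* model should emit a deprecation warning and route to the mapped architecture.tag, so existing scripts and checkpoints are not broken. A small sketch of the intended behavior (exact warning text may differ by release):

import timm

# old name: warns, then builds the mapped replacement
model = timm.create_model('resnetv2_50x1_bitm', pretrained=True)

# equivalent explicit form going forward
model = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=True)
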