Update resnetv2.py for multi-weight and HF hub weights
parent b3e816d6d7
commit da6bdd4560
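The diff below converts the resnetv2 pretrained configs to the multi-weight format and points the BiT weights at the Hugging Face hub. A minimal usage sketch against the resulting registry (model and tag names are taken from the cfgs in this diff; exact availability depends on the installed timm version):

import timm
import torch

# new multi-weight naming: '<architecture>.<pretrained tag>'
model = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=True)
model.eval()

# the cfg carries hf_hub_id='timm/', so the weights resolve from the HF hub;
# this tag is configured for 448x448 inputs and 1000 classes
with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))
print(out.shape)  # torch.Size([1, 1000])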
@@ -36,114 +36,15 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
-    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
+from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
+    DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # pretrained on imagenet21k, finetuned on imagenet1k
-    'resnetv2_50x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_50x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x2_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x4_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
-        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True),  # only one at 480x480?
-
-    # trained on imagenet-21k
-    'resnetv2_50x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_50x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x2_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x4_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
-        num_classes=21843, custom_load=True),
-
-    'resnetv2_50x1_bit_distilled': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher_384': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
-
-    'resnetv2_50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_50d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50t': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_101d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_152': _cfg(
-        interpolation='bicubic'),
-    'resnetv2_152d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-
-    'resnetv2_50d_gn': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_gn_ah-c415c11a.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_evob': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50d_evos': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_evos_ah-7c4dd548.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_frn': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-}
-
-
-def make_div(v, divisor=8):
-    min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v


 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
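Note on the hunk above: the local make_div helper is removed in favour of make_divisible from timm.layers, which performs the same rounding with its default arguments (divisor=8 and a 10% round-down guard). A small comparison sketch, reproducing the removed helper purely for illustration:

from timm.layers import make_divisible

def make_div(v, divisor=8):
    # the removed local helper, copied from the diff above
    min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# scaled channel counts round to a multiple of 8 either way
for v in (16, 64 * 0.25, 60.0, 100, 2048 * 1.5):
    assert make_div(v) == make_divisible(v)
print(make_divisible(100))  # 104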
@@ -174,7 +75,7 @@ class PreActBottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)

         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -234,7 +135,7 @@ class Bottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)

         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -472,7 +373,7 @@ class ResNetV2(nn.Module):
         act_layer = get_act_layer(act_layer)

         self.feature_info = []
-        stem_chs = make_div(stem_chs * wf)
+        stem_chs = make_divisible(stem_chs * wf)
         self.stem = create_resnetv2_stem(
             in_chans,
             stem_chs,
@@ -491,7 +392,7 @@ class ResNetV2(nn.Module):
         block_fn = PreActBottleneck if preact else Bottleneck
         self.stages = nn.Sequential()
         for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
-            out_chs = make_div(c * wf)
+            out_chs = make_divisible(c * wf)
             stride = 1 if stage_idx == 0 else 2
             if curr_stride >= output_stride:
                 dilation *= stride
@@ -517,7 +418,12 @@ class ResNetV2(nn.Module):
         self.num_features = prev_chs
         self.norm = norm_layer(self.num_features) if preact else nn.Identity()
         self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=self.drop_rate,
+            use_conv=True,
+        )

         self.init_weights(zero_init_last=zero_init_last)
         self.grad_checkpointing = False
@@ -551,8 +457,7 @@ class ResNetV2(nn.Module):

     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)
@@ -636,112 +541,141 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):


 def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
     return _create_resnetv2(
-        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
+        variant,
+        pretrained=pretrained,
+        stem_type='fixed',
+        conv_layer=partial(StdConv2d, eps=1e-8),
+        **kwargs,
+    )
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+    'resnetv2_50x1_bit.goog_distilled_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
+
+    # pretrained on imagenet21k, finetuned on imagenet1k
+    'resnetv2_50x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True),  # only one at 480x480?
+
+    # trained on imagenet-21k
+    'resnetv2_50x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+
+    'resnetv2_50.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50t.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_101.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_101d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_152.untrained': _cfg(
+        interpolation='bicubic'),
+    'resnetv2_152d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+
+    'resnetv2_50d_gn.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_evos.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_frn.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+})


 @register_model
-def resnetv2_50x1_bitm(pretrained=False, **kwargs):
+def resnetv2_50x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+        'resnetv2_50x1_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)


 @register_model
-def resnetv2_50x3_bitm(pretrained=False, **kwargs):
+def resnetv2_50x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
+        'resnetv2_50x3_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)


 @register_model
-def resnetv2_101x1_bitm(pretrained=False, **kwargs):
+def resnetv2_101x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
+        'resnetv2_101x1_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)


 @register_model
-def resnetv2_101x3_bitm(pretrained=False, **kwargs):
+def resnetv2_101x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+        'resnetv2_101x3_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)


 @register_model
-def resnetv2_152x2_bitm(pretrained=False, **kwargs):
+def resnetv2_152x2_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+        'resnetv2_152x2_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)


 @register_model
-def resnetv2_152x4_bitm(pretrained=False, **kwargs):
+def resnetv2_152x4_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
-
-
-@register_model
-def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=3, **kwargs)
-
-
-@register_model
-def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=3, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-
-
-@register_model
-def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=4, **kwargs)
-
-
-@register_model
-def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
-    """ ResNetV2-50x1-BiT Distilled
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
-    """ ResNetV2-152x2-BiT Teacher
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
-    """ ResNetV2-152xx-BiT Teacher @ 384x384
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+        'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)


 @register_model
@@ -804,14 +738,6 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
     return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
-def resnetv2_50d_evob(pretrained=False, **kwargs):
-    model_args = dict(
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
-        stem_type='deep', avg_down=True, zero_init_last=True)
-    return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
-
-
-@register_model
 def resnetv2_50d_evos(pretrained=False, **kwargs):
     model_args = dict(
@@ -826,3 +752,22 @@ def resnetv2_50d_frn(pretrained=False, **kwargs):
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
         stem_type='deep', avg_down=True)
     return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+register_model_deprecations(__name__, {
+    'resnetv2_50x1_bitm': 'resnetv2_50x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x3_bitm': 'resnetv2_50x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x1_bitm': 'resnetv2_101x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x3_bitm': 'resnetv2_101x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x2_bitm': 'resnetv2_152x2_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x4_bitm': 'resnetv2_152x4_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x1_bitm_in21k': 'resnetv2_50x1_bit.goog_in21k',
+    'resnetv2_50x3_bitm_in21k': 'resnetv2_50x3_bit.goog_in21k',
+    'resnetv2_101x1_bitm_in21k': 'resnetv2_101x1_bit.goog_in21k',
+    'resnetv2_101x3_bitm_in21k': 'resnetv2_101x3_bit.goog_in21k',
+    'resnetv2_152x2_bitm_in21k': 'resnetv2_152x2_bit.goog_in21k',
+    'resnetv2_152x4_bitm_in21k': 'resnetv2_152x4_bit.goog_in21k',
+    'resnetv2_50x1_bit_distilled': 'resnetv2_50x1_bit.goog_distilled_in1k',
+    'resnetv2_152x2_bit_teacher': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k',
+    'resnetv2_152x2_bit_teacher_384': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384',
+})
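The register_model_deprecations mapping above keeps the old entrypoint names working while steering them to the new tagged weights. A hedged sketch of what downstream callers can expect (the deprecation warning itself comes from timm's registry machinery, not this file):

import timm

# old name: still resolves, mapped to 'resnetv2_50x1_bit.goog_in21k' with a deprecation warning
legacy = timm.create_model('resnetv2_50x1_bitm_in21k', pretrained=True)

# new, equivalent spelling: architecture 'resnetv2_50x1_bit' with the 'goog_in21k' pretrained tag
current = timm.create_model('resnetv2_50x1_bit.goog_in21k', pretrained=True)

assert legacy.num_classes == current.num_classes == 21843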