Update resnetv2.py for multi-weight and HF hub weights
parent b3e816d6d7 · commit da6bdd4560
@@ -36,114 +36,15 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
-    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
+from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \
+    DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # pretrained on imagenet21k, finetuned on imagenet1k
-    'resnetv2_50x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_50x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x1_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_101x3_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x2_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
-    'resnetv2_152x4_bitm': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
-        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True),  # only one at 480x480?
-
-    # trained on imagenet-21k
-    'resnetv2_50x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_50x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x1_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_101x3_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x2_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
-        num_classes=21843, custom_load=True),
-    'resnetv2_152x4_bitm_in21k': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
-        num_classes=21843, custom_load=True),
-
-    'resnetv2_50x1_bit_distilled': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
-        interpolation='bicubic', custom_load=True),
-    'resnetv2_152x2_bit_teacher_384': _cfg(
-        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
-
-    'resnetv2_50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_50d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50t': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth',
-        interpolation='bicubic', crop_pct=0.95),
-    'resnetv2_101d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_152': _cfg(
-        interpolation='bicubic'),
-    'resnetv2_152d': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-
-    'resnetv2_50d_gn': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_gn_ah-c415c11a.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_evob': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50d_evos': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_evos_ah-7c4dd548.pth',
-        interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
-    'resnetv2_50d_frn': _cfg(
-        interpolation='bicubic', first_conv='stem.conv1'),
-}
-
-
-def make_div(v, divisor=8):
-    min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
 class PreActBottleneck(nn.Module):
     """Pre-activation (v2) bottleneck block.
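The hunk above drops the file-local make_div helper; the make_divisible import from timm.layers now handles the same channel rounding wherever widths are scaled by width_factor or bottle_ratio. A minimal sketch of that rounding, mirroring the deleted make_div body rather than quoting timm's exact implementation (round_channels is a stand-in name for illustration):

    def round_channels(v, divisor=8):
        # round to the nearest multiple of divisor, never dropping more than ~10% below v
        new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    # e.g. a hypothetical width factor of 1.5 on 64 stem channels: 64 * 1.5 = 96.0 -> 96
    assert round_channels(64 * 1.5) == 96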
@@ -174,7 +75,7 @@ class PreActBottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)
 
         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -234,7 +135,7 @@ class Bottleneck(nn.Module):
         conv_layer = conv_layer or StdConv2d
         norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
         out_chs = out_chs or in_chs
-        mid_chs = make_div(out_chs * bottle_ratio)
+        mid_chs = make_divisible(out_chs * bottle_ratio)
 
         if proj_layer is not None:
             self.downsample = proj_layer(
@@ -472,7 +373,7 @@ class ResNetV2(nn.Module):
         act_layer = get_act_layer(act_layer)
 
         self.feature_info = []
-        stem_chs = make_div(stem_chs * wf)
+        stem_chs = make_divisible(stem_chs * wf)
         self.stem = create_resnetv2_stem(
             in_chans,
             stem_chs,
@@ -491,7 +392,7 @@ class ResNetV2(nn.Module):
         block_fn = PreActBottleneck if preact else Bottleneck
         self.stages = nn.Sequential()
         for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
-            out_chs = make_div(c * wf)
+            out_chs = make_divisible(c * wf)
             stride = 1 if stage_idx == 0 else 2
             if curr_stride >= output_stride:
                 dilation *= stride
@@ -517,7 +418,12 @@ class ResNetV2(nn.Module):
         self.num_features = prev_chs
         self.norm = norm_layer(self.num_features) if preact else nn.Identity()
         self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=self.drop_rate,
+            use_conv=True,
+        )
 
         self.init_weights(zero_init_last=zero_init_last)
         self.grad_checkpointing = False
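The classifier head keeps the same arguments here, only reformatted one per line; with use_conv=True the final fc is a 1x1 convolution applied after global pooling. A quick shape check, assuming ClassifierHead's default 'avg' pooling and the 2048 features a width-factor-1 ResNetV2-50 ends with:

    import torch
    from timm.layers import ClassifierHead

    head = ClassifierHead(2048, 1000, pool_type='avg', drop_rate=0., use_conv=True)
    x = torch.randn(2, 2048, 7, 7)  # last-stage features for a 224x224 input
    print(head(x).shape)            # expected: torch.Size([2, 1000])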
@@ -551,8 +457,7 @@ class ResNetV2(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
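reset_classifier now delegates to the head's own reset() instead of rebuilding a ClassifierHead by hand, so swapping the classifier after construction stays a one-liner. A hedged usage sketch (the resnetv2_50 entrypoint this file registers is assumed; 10 classes is an arbitrary example):

    import torch
    import timm

    model = timm.create_model('resnetv2_50', pretrained=False)
    model.reset_classifier(num_classes=10)  # re-points the classifier via self.head.reset()
    out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 10)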
@@ -636,112 +541,141 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
 
 def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
     return _create_resnetv2(
-        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
+        variant,
+        pretrained=pretrained,
+        stem_type='fixed',
+        conv_layer=partial(StdConv2d, eps=1e-8),
+        **kwargs,
+    )
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+    'resnetv2_50x1_bit.goog_distilled_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', custom_load=True),
+    'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
+
+    # pretrained on imagenet21k, finetuned on imagenet1k
+    'resnetv2_50x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True),  # only one at 480x480?
+
+    # trained on imagenet-21k
+    'resnetv2_50x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_50x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x1_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_101x3_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x2_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+    'resnetv2_152x4_bit.goog_in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843, custom_load=True),
+
+    'resnetv2_50.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50t.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_101.a1h_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_101d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_152.untrained': _cfg(
+        interpolation='bicubic'),
+    'resnetv2_152d.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+
+    'resnetv2_50d_gn.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_evos.ah_in1k': _cfg(
+        hf_hub_id='timm/',
+        interpolation='bicubic', first_conv='stem.conv1',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnetv2_50d_frn.untrained': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+})
+
+
 @register_model
-def resnetv2_50x1_bitm(pretrained=False, **kwargs):
+def resnetv2_50x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+        'resnetv2_50x1_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
 
 
 @register_model
-def resnetv2_50x3_bitm(pretrained=False, **kwargs):
+def resnetv2_50x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
+        'resnetv2_50x3_bit', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
 
 
 @register_model
-def resnetv2_101x1_bitm(pretrained=False, **kwargs):
+def resnetv2_101x1_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
+        'resnetv2_101x1_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
 
 
 @register_model
-def resnetv2_101x3_bitm(pretrained=False, **kwargs):
+def resnetv2_101x3_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+        'resnetv2_101x3_bit', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
 
 
 @register_model
-def resnetv2_152x2_bitm(pretrained=False, **kwargs):
+def resnetv2_152x2_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+        'resnetv2_152x2_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
 
 
 @register_model
-def resnetv2_152x4_bitm(pretrained=False, **kwargs):
+def resnetv2_152x4_bit(pretrained=False, **kwargs):
     return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
+        'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
 
 
-@register_model
-def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 6, 3], width_factor=3, **kwargs)
-
-
-@register_model
-def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 4, 23, 3], width_factor=3, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-
-
-@register_model
-def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
-    return _create_resnetv2_bit(
-        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
-        layers=[3, 8, 36, 3], width_factor=4, **kwargs)
-
-
-@register_model
-def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
-    """ ResNetV2-50x1-BiT Distilled
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
-    """ ResNetV2-152x2-BiT Teacher
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
-    """ ResNetV2-152xx-BiT Teacher @ 384x384
-    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
-    """
-    return _create_resnetv2_bit(
-        'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
-
-
 @register_model
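This hunk is the heart of the "multi-weight and HF hub" change: default_cfgs keys become 'architecture.tag' pairs, the hosted ones carry hf_hub_id='timm/' so their checkpoints resolve from the Hugging Face Hub, and the entrypoints drop the BiT-specific bitm suffix. A hedged sketch of how such tagged weights are typically selected through timm (weights are downloaded when pretrained=True):

    import timm

    # architecture name alone -> the default weight tag for that model
    model = timm.create_model('resnetv2_50x1_bit', pretrained=True)

    # explicit tags pick a specific weight set from the cfgs above
    in21k = timm.create_model('resnetv2_50x1_bit.goog_in21k', pretrained=True)
    in1k = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=True)

    print(in21k.num_classes, in1k.num_classes)  # 21843 vs. 1000 per the cfgs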
@@ -804,14 +738,6 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
     return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
-@register_model
-def resnetv2_50d_evob(pretrained=False, **kwargs):
-    model_args = dict(
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
-        stem_type='deep', avg_down=True, zero_init_last=True)
-    return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
-
-
 @register_model
 def resnetv2_50d_evos(pretrained=False, **kwargs):
     model_args = dict(
@@ -826,3 +752,22 @@ def resnetv2_50d_frn(pretrained=False, **kwargs):
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
         stem_type='deep', avg_down=True)
     return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+register_model_deprecations(__name__, {
+    'resnetv2_50x1_bitm': 'resnetv2_50x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x3_bitm': 'resnetv2_50x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x1_bitm': 'resnetv2_101x1_bit.goog_in21k_ft_in1k',
+    'resnetv2_101x3_bitm': 'resnetv2_101x3_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x2_bitm': 'resnetv2_152x2_bit.goog_in21k_ft_in1k',
+    'resnetv2_152x4_bitm': 'resnetv2_152x4_bit.goog_in21k_ft_in1k',
+    'resnetv2_50x1_bitm_in21k': 'resnetv2_50x1_bit.goog_in21k',
+    'resnetv2_50x3_bitm_in21k': 'resnetv2_50x3_bit.goog_in21k',
+    'resnetv2_101x1_bitm_in21k': 'resnetv2_101x1_bit.goog_in21k',
+    'resnetv2_101x3_bitm_in21k': 'resnetv2_101x3_bit.goog_in21k',
+    'resnetv2_152x2_bitm_in21k': 'resnetv2_152x2_bit.goog_in21k',
+    'resnetv2_152x4_bitm_in21k': 'resnetv2_152x4_bit.goog_in21k',
+    'resnetv2_50x1_bit_distilled': 'resnetv2_50x1_bit.goog_distilled_in1k',
+    'resnetv2_152x2_bit_teacher': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k',
+    'resnetv2_152x2_bit_teacher_384': 'resnetv2_152x2_bit.goog_teacher_in21k_ft_in1k_384',
+})
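register_model_deprecations keeps the removed bitm/in21k/distilled/teacher names creatable: requesting an old name is expected to warn and resolve to the mapped 'architecture.tag' target. A hedged sketch:

    import timm

    # old-style name, redirected through the deprecation map above (with a warning)
    old = timm.create_model('resnetv2_50x1_bitm', pretrained=False)

    # equivalent new-style name plus explicit weight tag
    new = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=False)

    assert type(old) is type(new)  # both build the same ResNetV2 architecture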