Add Facebook Research Semi-Supervised and Semi-Weakly Supervised ResNet model weights.
parent
a9eb484835
commit
b93fcf0708
59
sotabench.py
59
sotabench.py
|
@ -167,6 +167,65 @@ model_list = [
|
|||
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d (288x288 Mean-Max Pooling)', '1805.00932',
|
||||
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8),
|
||||
|
||||
## Facebook SSL weights
|
||||
_entry('ssl_resnet18', 'ResNet-18', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnet50', 'ResNet-50', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
|
||||
_entry('ssl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
|
||||
## Facebook SWSL weights
|
||||
_entry('swsl_resnet18', 'ResNet-18', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnet50', 'ResNet-50', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546'),
|
||||
|
||||
_entry('swsl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288),
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
|
||||
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
|
||||
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
|
||||
|
||||
## DLA official impl weights (to remove if sotabench added to source)
|
||||
_entry('dla34', 'DLA-34', '1707.06484'),
|
||||
_entry('dla46_c', 'DLA-46-C', '1707.06484'),
|
||||
|
|
|
@ -57,15 +57,17 @@ def resume_checkpoint(model, checkpoint_path):
|
|||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
|
||||
if 'url' not in default_cfg or not default_cfg['url']:
|
||||
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None):
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||
logging.warning("Pretrained model URL is invalid, using random initialization.")
|
||||
return
|
||||
|
||||
state_dict = model_zoo.load_url(default_cfg['url'], progress=False)
|
||||
state_dict = model_zoo.load_url(cfg['url'], progress=False)
|
||||
|
||||
if in_chans == 1:
|
||||
conv1_name = default_cfg['first_conv']
|
||||
conv1_name = cfg['first_conv']
|
||||
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
||||
conv1_weight = state_dict[conv1_name + '.weight']
|
||||
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
|
||||
|
@ -73,14 +75,14 @@ def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=
|
|||
assert False, "Invalid in_chans for pretrained weights"
|
||||
|
||||
strict = True
|
||||
classifier_name = default_cfg['classifier']
|
||||
if num_classes == 1000 and default_cfg['num_classes'] == 1001:
|
||||
classifier_name = cfg['classifier']
|
||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||
# special case for imagenet trained models with extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
|
||||
elif num_classes != default_cfg['num_classes']:
|
||||
elif num_classes != cfg['num_classes']:
|
||||
# completely discard fully connected for all other differences between pretrained and created model
|
||||
del state_dict[classifier_name + '.weight']
|
||||
del state_dict[classifier_name + '.bias']
|
||||
|
|
|
@ -67,6 +67,30 @@ default_cfgs = {
|
|||
'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
|
||||
'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
|
||||
'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
|
||||
'ssl_resnet18': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'),
|
||||
'ssl_resnet50': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'),
|
||||
'ssl_resnext50_32x4d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'),
|
||||
'ssl_resnext101_32x4d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'),
|
||||
'ssl_resnext101_32x8d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'),
|
||||
'ssl_resnext101_32x16d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'),
|
||||
'swsl_resnet18': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'),
|
||||
'swsl_resnet50': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'),
|
||||
'swsl_resnext50_32x4d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'),
|
||||
'swsl_resnext101_32x4d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'),
|
||||
'swsl_resnext101_32x8d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
|
||||
'swsl_resnext101_32x16d': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -621,80 +645,218 @@ def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
|
|||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
def ig_resnext101_32x8d(pretrained=True, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
|
||||
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights
|
||||
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
|
||||
in_chans (int): number of input planes (default: 3 for pretrained / color)
|
||||
"""
|
||||
default_cfg = default_cfgs['ig_resnext101_32x8d']
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
|
||||
num_classes=1000, in_chans=3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
|
||||
model.default_cfg = default_cfgs['ig_resnext101_32x8d']
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
def ig_resnext101_32x16d(pretrained=True, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
|
||||
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights
|
||||
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
|
||||
in_chans (int): number of input planes (default: 3 for pretrained / color)
|
||||
"""
|
||||
default_cfg = default_cfgs['ig_resnext101_32x16d']
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16,
|
||||
num_classes=1000, in_chans=3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
|
||||
model.default_cfg = default_cfgs['ig_resnext101_32x16d']
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
def ig_resnext101_32x32d(pretrained=True, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
|
||||
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights
|
||||
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
|
||||
in_chans (int): number of input planes (default: 3 for pretrained / color)
|
||||
"""
|
||||
default_cfg = default_cfgs['ig_resnext101_32x32d']
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32,
|
||||
num_classes=1000, in_chans=3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
|
||||
model.default_cfg = default_cfgs['ig_resnext101_32x32d']
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
|
||||
def ig_resnext101_32x48d(pretrained=True, **kwargs):
|
||||
"""Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data
|
||||
and finetuned on ImageNet from Figure 5 in
|
||||
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
|
||||
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
|
||||
Args:
|
||||
pretrained (bool): load pretrained weights
|
||||
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
|
||||
in_chans (int): number of input planes (default: 3 for pretrained / color)
|
||||
"""
|
||||
default_cfg = default_cfgs['ig_resnext101_32x48d']
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48,
|
||||
num_classes=1000, in_chans=3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
|
||||
model.default_cfg = default_cfgs['ig_resnext101_32x48d']
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnet18(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnet18']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnet50(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnet50']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnext50_32x4d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnext50_32x4d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnext101_32x4d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnext101_32x4d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnext101_32x8d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnext101_32x8d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ssl_resnext101_32x16d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
|
||||
model.default_cfg = default_cfgs['ssl_resnext101_32x16d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnet18(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised Resnet-18 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnet18']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnet50(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnet50']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnext50_32x4d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnext50_32x4d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnext101_32x4d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnext101_32x4d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnext101_32x8d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnext101_32x8d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swsl_resnext101_32x16d(pretrained=True, **kwargs):
|
||||
"""Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised
|
||||
image dataset and finetuned on ImageNet.
|
||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
|
||||
model.default_cfg = default_cfgs['swsl_resnext101_32x16d']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue