add resnext50_32x4d_fc512

pull/119/head
KaiyangZhou 2018-11-22 12:41:22 +00:00
parent b065f130d0
commit d7755cfd16
2 changed files with 20 additions and 1 deletions

View File

@ -25,6 +25,7 @@ __model_factory = {
'resnet50': resnet50,
'resnet50_fc512': resnet50_fc512,
'resnext50_32x4d': resnext50_32x4d,
'resnext50_32x4d_fc512': resnext50_32x4d_fc512,
'resnext101_32x4d': resnext101_32x4d,
'se_resnet50': se_resnet50,
'se_resnet50_fc512': se_resnet50_fc512,

View File

@ -10,7 +10,7 @@ import torchvision
import torch.utils.model_zoo as model_zoo
__all__ = ['resnext50_32x4d', 'resnext101_32x4d']
__all__ = ['resnext50_32x4d', 'resnext50_32x4d_fc512', 'resnext101_32x4d']
model_urls = {
@ -219,6 +219,24 @@ def resnext50_32x4d(num_classes, loss, pretrained='imagenet', **kwargs):
return model
def resnext50_32x4d_fc512(num_classes, loss, pretrained='imagenet', **kwargs):
model = ResNeXt(
num_classes=num_classes,
loss=loss,
block=ResNeXtBottleneck,
layers=[3, 4, 6, 3],
groups=32,
base_width=4,
last_stride=1,
fc_dims=[512],
dropout_p=None,
**kwargs
)
if pretrained == 'imagenet':
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
return model
def resnext101_32x4d(num_classes, loss, pretrained='imagenet', **kwargs):
model = ResNeXt(
num_classes=num_classes,