Add some models, remove a model, tweak some models
parent 31055466fc
commit e0cfeb7d8e

@@ -14,21 +14,17 @@ import torch.nn as nn
 import torch.nn.functional as F


-def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
+def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
     """Selectable global pooling function with dynamic input kernel size
     """
     if pool_type == 'avgmax':
-        x_avg = F.avg_pool2d(
-            x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
-        x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
+        x_avg = F.adaptive_avg_pool2d(x, output_size)
+        x_max = F.adaptive_max_pool2d(x, output_size)
         x = 0.5 * (x_avg + x_max)
     elif pool_type == 'max':
-        x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
+        x = F.adaptive_max_pool2d(x, output_size)
     else:
         if pool_type != 'avg':
             print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
-        x = F.avg_pool2d(
-            x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
+        x = F.adaptive_avg_pool2d(x, output_size)
     return x
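Note: the hunk above trades fixed `kernel_size=(H, W)` pooling for the `adaptive_*` calls, which collapse any input resolution to `output_size` and make the `padding`/`count_include_pad` plumbing unnecessary. A minimal standalone sketch of the new 'avgmax' branch:

```python
import torch
import torch.nn.functional as F

# Any spatial size collapses to output_size (1x1 here); the old form had to
# pass kernel_size=(x.size(2), x.size(3)) and recompute it for every input.
x = torch.randn(2, 256, 7, 7)
x_avg = F.adaptive_avg_pool2d(x, 1)
x_max = F.adaptive_max_pool2d(x, 1)
pooled = 0.5 * (x_avg + x_max)   # the 'avgmax' branch above
print(pooled.shape)              # torch.Size([2, 256, 1, 1])
```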

@@ -8,6 +8,7 @@ import torch.nn.functional as F
 import torch.utils.model_zoo as model_zoo
 from collections import OrderedDict
 from .adaptive_avgmax_pool import *
+import re

 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


@@ -20,6 +21,19 @@ model_urls = {
 }


+def _filter_pretrained(state_dict):
+    pattern = re.compile(
+        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+
+    for key in list(state_dict.keys()):
+        res = pattern.match(key)
+        if res:
+            new_key = res.group(1) + res.group(2)
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+    return state_dict
+
+
 def densenet121(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
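Note: `_filter_pretrained` exists because the `_DenseLayer` modules are renamed from dotted names ('norm.1') to flat names ('norm1') further down in this diff; the regex rewrites checkpoint keys to match. A standalone sketch of the remap:

```python
import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

# A key as it appears in the torchvision DenseNet checkpoints (dotted child names):
key = 'features.denseblock1.denselayer1.norm.1.weight'
res = pattern.match(key)
print(res.group(1) + res.group(2))
# -> features.denseblock1.denselayer1.norm1.weight
```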
@@ -29,7 +43,8 @@ def densenet121(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
+        state_dict = model_zoo.load_url(model_urls['densenet121'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model


@@ -42,7 +57,8 @@ def densenet169(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
+        state_dict = model_zoo.load_url(model_urls['densenet169'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model


@@ -55,7 +71,8 @@ def densenet201(pretrained=False, **kwargs):
     """
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
+        state_dict = model_zoo.load_url(model_urls['densenet201'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model


@@ -69,20 +86,21 @@ def densenet161(pretrained=False, **kwargs):
     print(kwargs)
     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
+        state_dict = model_zoo.load_url(model_urls['densenet161'])
+        model.load_state_dict(_filter_pretrained(state_dict))
     return model


 class _DenseLayer(nn.Sequential):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
         super(_DenseLayer, self).__init__()
-        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu.1', nn.ReLU(inplace=True)),
-        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                         growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu.2', nn.ReLU(inplace=True)),
-        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+        self.add_module('relu2', nn.ReLU(inplace=True)),
+        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                         kernel_size=3, stride=1, padding=1, bias=False)),
         self.drop_rate = drop_rate


@@ -172,12 +190,12 @@ class DenseNet(nn.Module):
         self.classifier = None

     def forward_features(self, x, pool=True):
-        features = self.features(x)
-        out = F.relu(features, inplace=True)
+        x = self.features(x)
+        x = F.relu(x, inplace=True)
         if pool:
-            out = adaptive_avgmax_pool2d(out, self.global_pool)
-            out = x.view(out.size(0), -1)
-        return out
+            x = adaptive_avgmax_pool2d(x, self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x

     def forward(self, x):
         return self.classifier(self.forward_features(x, pool=True))
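Note: `forward_features` now returns either pooled-and-flattened features or the raw 4D feature map, depending on `pool`. A hedged usage sketch (module path and the `num_classes` kwarg are assumed from the imports and constructors above; densenet121's final width is 1024):

```python
import torch
from models.densenet import densenet121  # path assumed from this repo's layout

model = densenet121(num_classes=10)
feat_map = model.forward_features(torch.randn(1, 3, 224, 224), pool=False)
print(feat_map.shape)   # (1, 1024, 7, 7): unpooled 4D features
vec = model.forward_features(torch.randn(1, 3, 224, 224), pool=True)
print(vec.shape)        # (1, 1024) after adaptive_avgmax_pool2d + flatten
```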

@@ -6,13 +6,13 @@ import os

 from .inception_v4 import inception_v4
 from .inception_resnet_v2 import inception_resnet_v2
-from .wrn50_2 import wrn50_2
-from .my_densenet import densenet161, densenet121, densenet169, densenet201
-from .my_resnet import resnet18, resnet34, resnet50, resnet101, resnet152
+from .densenet import densenet161, densenet121, densenet169, densenet201
+from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
 from .fbresnet200 import fbresnet200
+from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
-from .senet import se_resnet18, se_resnet34, se_resnet50, se_resnet101, se_resnet152,\
-    se_resnext50_32x4d, se_resnext101_32x4d
+from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152,\
+    seresnext50_32x4d, seresnext101_32x4d
 from .resnext import resnext50, resnext101, resnext152


 model_config_dict = {
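Note: `my_densenet`/`my_resnet` are renamed to `densenet`/`resnet`, `wrn50_2` is dropped, `dpn` is new, and the `se_resnet*` entry points become `seresnet*`, so downstream imports change accordingly. A before/after sketch (top-level package name `models` assumed):

```python
# before this commit
# from models.my_densenet import densenet121
# from models.senet import se_resnet50

# after this commit
from models.densenet import densenet121
from models.senet import seresnet50
from models.dpn import dpn92
```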
@@ -99,15 +99,29 @@ def create_model(
     elif model_name == 'inception_resnet_v2':
         model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'inception_v4':
-        model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'wrn50':
-        model = wrn50_2(num_classes=num_classes, pretrained=pretrained, **kwargs)
+        model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'fbresnet200':
-        model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
+        model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet18':
-        model = se_resnet18(num_classes=num_classes, pretrained=pretrained)
+        model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet34':
-        model = se_resnet34(num_classes=num_classes, pretrained=pretrained)
+        model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet50':
+        model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet101':
+        model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnet152':
+        model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnext50_32x4d':
+        model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'seresnext101_32x4d':
+        model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext50':
+        model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext101':
+        model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext152':
+        model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
     else:
         assert False and "Invalid model"

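Note: with this hunk the factory resolves all of the new seresnet/seresnext/resnext names. A hedged call sketch (`create_model`'s full signature and export path are not shown in this diff, so both are assumptions):

```python
from models import create_model  # import path assumed

model = create_model('seresnext50_32x4d', num_classes=1000, pretrained=False)
```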

models/senet.py

@@ -9,102 +9,22 @@ import math
 import torch.nn as nn
 from torch.utils import model_zoo

-__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
-           'se_resnext50_32x4d', 'se_resnext101_32x4d']
+__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
+           'seresnext50_32x4d', 'seresnext101_32x4d']

-pretrained_config = {
-    'senet154': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet18': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet34': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet50': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet101': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnet152': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnext50_32x4d': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-    'se_resnext101_32x4d': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 224, 224],
-            'input_range': [0, 1],
-            'mean': [0.485, 0.456, 0.406],
-            'std': [0.229, 0.224, 0.225],
-            'num_classes': 1000
-        }
-    },
-}
+model_urls = {
+    'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
+    'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
+    'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
+    'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
+    'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
+    'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
+    'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
+    'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
+}


-def _weight_init(m, n='', ll=''):
+def _weight_init(m):
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     elif isinstance(m, nn.BatchNorm2d):
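Note: the unused `n`/`ll` parameters are dropped because `_weight_init` is applied per-module via `self.modules()` (see the constructor hunk below). A self-contained sketch; the `BatchNorm2d` branch body is cut off in this hunk, so its two init lines here are an assumption:

```python
import torch.nn as nn

def _weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1.)  # assumed: branch body not shown above
        nn.init.constant_(m.bias, 0.)

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for m in net.modules():
    _weight_init(m)
```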

@@ -138,6 +58,7 @@ class Bottleneck(nn.Module):
     """
     Base class for bottlenecks that implements `forward()` method.
     """
+
     def forward(self, x):
         residual = x

@@ -236,7 +157,7 @@ class SEResNeXtBottleneck(Bottleneck):

 class SEResNetBlock(nn.Module):
     expansion = 1

     def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
         super(SEResNetBlock, self).__init__()
         self.conv1 = nn.Conv2d(

@@ -273,7 +194,7 @@ class SEResNetBlock(nn.Module):
 class SENet(nn.Module):

     def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
-                 inch=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
+                 inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
                  downsample_padding=1, num_classes=1000):
         """
         Parameters

@@ -320,9 +241,10 @@ class SENet(nn.Module):
         """
         super(SENet, self).__init__()
         self.inplanes = inplanes
+        self.num_classes = num_classes
         if input_3x3:
             layer0_modules = [
-                ('conv1', nn.Conv2d(inch, 64, 3, stride=2, padding=1, bias=False)),
+                ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
                 ('bn1', nn.BatchNorm2d(64)),
                 ('relu1', nn.ReLU(inplace=True)),
                 ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),

@@ -335,7 +257,7 @@ class SENet(nn.Module):
         else:
             layer0_modules = [
                 ('conv1', nn.Conv2d(
-                    inch, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
+                    inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                 ('bn1', nn.BatchNorm2d(inplanes)),
                 ('relu1', nn.ReLU(inplace=True)),
             ]

@@ -384,7 +306,8 @@ class SENet(nn.Module):
         )
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
-        self.last_linear = nn.Linear(512 * block.expansion, num_classes)
+        self.num_features = 512 * block.expansion
+        self.last_linear = nn.Linear(self.num_features, num_classes)

         for m in self.modules():
             _weight_init(m)

@@ -408,19 +331,31 @@ class SENet(nn.Module):

         return nn.Sequential(*layers)

-    def forward_features(self, x):
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes):
+        self.num_classes = num_classes
+        del self.last_linear
+        if num_classes:
+            self.last_linear = nn.Linear(self.num_features, num_classes)
+        else:
+            self.last_linear = None
+
+    def forward_features(self, x, pool=True):
         x = self.layer0(x)
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
+        if pool:
+            x = self.avg_pool(x)
+            x = x.view(x.size(0), -1)
         return x

     def logits(self, x):
         x = self.avg_pool(x)
         if self.dropout is not None:
             x = self.dropout(x)
         x = x.view(x.size(0), -1)
         x = self.last_linear(x)
         return x
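Note: these additions give SENet the same head-swap/feature-extraction interface as the other models in this repo. A hedged sketch (`seresnet50` is defined below; `pretrained=None` skips the download, and 2048 assumes the usual bottleneck expansion of 4):

```python
import torch
from models.senet import seresnet50  # path assumed from this repo's layout

model = seresnet50(num_classes=1000, pretrained=None)
model.reset_classifier(10)  # new 10-way last_linear sized from self.num_features
x = torch.randn(1, 3, 224, 224)
print(model.forward_features(x, pool=True).shape)  # (1, 2048) = 512 * expansion
```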

@@ -430,99 +365,89 @@ class SENet(nn.Module):
         return x


-def initialize_pretrained_model(model, num_classes, config):
-    assert num_classes == config['num_classes'], \
-        'num_classes should be {}, but is {}'.format(
-            config['num_classes'], num_classes)
-    model.load_state_dict(model_zoo.load_url(config['url']))
-    model.input_space = config['input_space']
-    model.input_size = config['input_size']
-    model.input_range = config['input_range']
-    model.mean = config['mean']
-    model.std = config['std']
+def _load_pretrained(model, url, inchans=3):
+    state_dict = model_zoo.load_url(url)
+    if inchans == 1:
+        conv1_weight = state_dict['conv1.weight']
+        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+    elif inchans != 3:
+        assert False, "Invalid inchans for pretrained weights"
+    model.load_state_dict(state_dict)
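Note: `_load_pretrained` replaces the config-dict machinery with a plain URL lookup and adds a grayscale path: for `inchans=1` the pretrained RGB conv1 kernel is summed across its input-channel dim, preserving the filter's response to overall intensity. A standalone sketch of that conversion:

```python
import torch

conv1_weight = torch.randn(64, 3, 7, 7)       # stand-in for the checkpoint tensor
gray = conv1_weight.sum(dim=1, keepdim=True)  # -> (64, 1, 7, 7), as in the hunk above
print(gray.shape)
```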

-def senet154(num_classes=1000, pretrained='imagenet'):
+def senet154(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                   dropout_p=0.2, num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['senet154'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['senet154'], inchans)
     return model


-def se_resnet18(num_classes=1000, pretrained='imagenet'):
+def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet18'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet18'], inchans)
     return model


-def se_resnet34(num_classes=1000, pretrained='imagenet'):
+def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet34'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet34'], inchans)
     return model


-def se_resnet50(num_classes=1000, pretrained='imagenet'):
+def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet50'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet50'], inchans)
     return model


-def se_resnet101(num_classes=1000, pretrained='imagenet'):
+def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet101'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet101'], inchans)
     return model


-def se_resnet152(num_classes=1000, pretrained='imagenet'):
+def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnet152'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnet152'], inchans)
     return model


-def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
+def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnext50_32x4d'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
     return model


-def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
+def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet'):
     model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                   dropout_p=None, inplanes=64, input_3x3=False,
                   downsample_kernel_size=1, downsample_padding=0,
                   num_classes=num_classes)
     if pretrained:
-        config = pretrained_config['se_resnext101_32x4d'][pretrained]
-        initialize_pretrained_model(model, num_classes, config)
+        _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
     return model

models/wrn50_2.py (deleted; filename assumed from the removed `from .wrn50_2 import wrn50_2` import above)

@@ -1,393 +0,0 @@
-""" Pytorch Wide-Resnet-50-2
-Sourced by running https://github.com/clcarwin/convert_torch_to_pytorch (MIT) on
-https://github.com/szagoruyko/wide-residual-networks/blob/master/pretrained/README.md
-License of above is, as of yet, unclear.
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
-from functools import reduce
-from collections import OrderedDict
-from .adaptive_avgmax_pool import *
-
-model_urls = {
-    'wrn50_2': 'https://www.dropbox.com/s/fe7rj3okz9rctn0/wrn50_2-d98ded61.pth?dl=1',
-}
-
-
-class LambdaBase(nn.Sequential):
-    def __init__(self, fn, *args):
-        super(LambdaBase, self).__init__(*args)
-        self.lambda_func = fn
-
-    def forward_prepare(self, input):
-        output = []
-        for module in self._modules.values():
-            output.append(module(input))
-        return output if output else input
-
-
-class Lambda(LambdaBase):
-    def forward(self, input):
-        return self.lambda_func(self.forward_prepare(input))
-
-
-class LambdaMap(LambdaBase):
-    def forward(self, input):
-        return list(map(self.lambda_func, self.forward_prepare(input)))
-
-
-class LambdaReduce(LambdaBase):
-    def forward(self, input):
-        return reduce(self.lambda_func, self.forward_prepare(input))
-
-
-def wrn_50_2_features(activation_fn=nn.ReLU()):
-    features = nn.Sequential(  # Sequential,
-        nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
-        nn.BatchNorm2d(64),
-        activation_fn,
-        nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
-        nn.Sequential(  # Sequential,
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                    ),
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                    ),
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(128),
-                        activation_fn,
-                        nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-        ),
-        nn.Sequential(  # Sequential,
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                    ),
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                    ),
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(256),
-                        activation_fn,
-                        nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-        ),
-        nn.Sequential(  # Sequential,
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(512),
-                        activation_fn,
-                        nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-        ),
-        nn.Sequential(  # Sequential,
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(2048),
-                    ),
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(2048),
-                    ),
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(2048),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-            nn.Sequential(  # Sequential,
-                LambdaMap(lambda x: x,  # ConcatTable,
-                    nn.Sequential(  # Sequential,
-                        nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
-                        nn.BatchNorm2d(1024),
-                        activation_fn,
-                        nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
-                        nn.BatchNorm2d(2048),
-                    ),
-                    Lambda(lambda x: x),  # Identity,
-                ),
-                LambdaReduce(lambda x, y: x + y),  # CAddTable,
-                activation_fn,
-            ),
-        ),
-    )
-    return features
-
-
-class Wrn50_2(nn.Module):
-    def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
-        super(Wrn50_2, self).__init__()
-        self.drop_rate = drop_rate
-        self.num_classes = num_classes
-        self.num_features = 2048
-        self.global_pool = global_pool
-        self.features = wrn_50_2_features(activation_fn=activation_fn)
-        self.fc = nn.Linear(2048, num_classes)
-
-    def get_classifier(self):
-        return self.fc
-
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.num_classes = num_classes
-        self.global_pool = global_pool
-        self.fc = nn.Linear(2048, num_classes)
-
-    def forward_features(self, x, pool=True):
-        x = self.features(x)
-        if pool:
-            x = adaptive_avgmax_pool2d(x, self.global_pool)
-            x = x.view(x.size(0), -1)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x, pool=True)
-        if self.drop_rate > 0:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
-        x = self.fc(x)
-        return x
-
-
-def wrn50_2(pretrained=False, num_classes=1000, **kwargs):
-    model = Wrn50_2(num_classes=num_classes, **kwargs)
-    if pretrained:
-        # Remap pretrained weights to match our class module with features + fc
-        pretrained_weights = model_zoo.load_url(model_urls['wrn50_2'])
-        feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys())
-        remapped_weights = OrderedDict()
-        for k in feature_keys:
-            remapped_weights['features.' + k] = pretrained_weights[k]
-        remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
-        remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
-        model.load_state_dict(remapped_weights)
-    return model
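Note: the deleted factory's checkpoint remap is worth recording: the torch7-converted state dict has flat numeric keys ('0.weight' ... '10.1.bias'), and everything except the '10.1.' classifier tensors was re-rooted under 'features.'. A standalone sketch with dummy values:

```python
from collections import OrderedDict

pretrained_weights = {
    '0.weight': 'conv1 tensor', '1.weight': 'bn1 tensor',
    '10.1.weight': 'fc weight', '10.1.bias': 'fc bias',
}
remapped_weights = OrderedDict()
for k in filter(lambda k: '10.1.' not in k, pretrained_weights.keys()):
    remapped_weights['features.' + k] = pretrained_weights[k]
remapped_weights['fc.weight'] = pretrained_weights['10.1.weight']
remapped_weights['fc.bias'] = pretrained_weights['10.1.bias']
print(list(remapped_weights))
# ['features.0.weight', 'features.1.weight', 'fc.weight', 'fc.bias']
```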