build new classifier; add squeezenet1_0_fc512

parent ee5bcee63f
commit 3ddf9ce699
@@ -40,6 +40,7 @@ __model_factory = {
     'mobilenetv2': MobileNetV2,
     'shufflenet': ShuffleNet,
     'squeezenet1_0': squeezenet1_0,
+    'squeezenet1_0_fc512': squeezenet1_0_fc512,
     'squeezenet1_1': squeezenet1_1,
     # reid-specific models
     'mudeep': MuDeep,
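For context, a minimal sketch of how a factory entry like this is typically consumed; the build_model helper below is hypothetical, and only __model_factory and the registered keys come from this diff:

    def build_model(name, num_classes, loss, **kwargs):
        # Look up the registered constructor and instantiate it;
        # 'squeezenet1_0_fc512' is the key this commit adds.
        if name not in __model_factory:
            raise KeyError("Unknown model: {}".format(name))
        return __model_factory[name](num_classes=num_classes, loss=loss, **kwargs)

    model = build_model('squeezenet1_0_fc512', num_classes=751, loss={'xent'})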
@@ -13,7 +13,7 @@ import torchvision
 import torch.utils.model_zoo as model_zoo
 
 
-__all__ = ['squeezenet1_0', 'squeezenet1_1']
+__all__ = ['squeezenet1_0', 'squeezenet1_1', 'squeezenet1_0_fc512']
 
 
 model_urls = {
@@ -53,13 +53,15 @@ class SqueezeNet(nn.Module):
     Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
     and <0.5MB model size. arXiv:1602.07360.
     """
-    def __init__(self, num_classes, loss, version=1.0, **kwargs):
+    def __init__(self, num_classes, loss, version=1.0, fc_dims=None, dropout_p=None, **kwargs):
         super(SqueezeNet, self).__init__()
         self.loss = loss
+        self.feature_dim = 512
+
         if version not in [1.0, 1.1]:
             raise ValueError("Unsupported SqueezeNet version {version}:"
                              "1.0 or 1.1 expected".format(version=version))
         self.num_classes = num_classes
 
         if version == 1.0:
             self.features = nn.Sequential(
                 nn.Conv2d(3, 96, kernel_size=7, stride=2),
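For orientation, a minimal sketch (the numbers are illustrative, not from this diff) of how the new constructor arguments behave; self.feature_dim starts at the 512-channel backbone width and is overwritten by _construct_fc_layer (defined in the next hunk) when fc_dims is given:

    net = SqueezeNet(num_classes=751, loss={'xent'}, version=1.0, fc_dims=None)
    assert net.feature_dim == 512   # no fc head; raw backbone features

    net = SqueezeNet(num_classes=751, loss={'xent'}, version=1.0,
                     fc_dims=[256], dropout_p=0.5)
    assert net.feature_dim == 256   # embedding dim follows the last fc layer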
@@ -92,40 +94,74 @@ class SqueezeNet(nn.Module):
                 Fire(384, 64, 256, 256),
                 Fire(512, 64, 256, 256),
             )
-        # Final convolution is initialized differently form the rest
-        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
-        self.classifier = nn.Sequential(
-            nn.Dropout(p=0.5),
-            final_conv,
-            nn.ReLU(inplace=True),
-            nn.AdaptiveAvgPool2d(1)
-        )
+
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        self.fc = self._construct_fc_layer(fc_dims, 512, dropout_p)
+        self.classifier = nn.Linear(self.feature_dim, num_classes)
+
+        self._init_params()
+
+    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
+        """
+        Construct fully connected layer
+
+        - fc_dims (list or tuple): dimensions of fc layers, if None,
+                                   no fc layers are constructed
+        - input_dim (int): input dimension
+        - dropout_p (float): dropout probability, if None, dropout is unused
+        """
+        if fc_dims is None:
+            self.feature_dim = input_dim
+            return None
+
+        assert isinstance(fc_dims, (list, tuple)), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))
+
+        layers = []
+        for dim in fc_dims:
+            layers.append(nn.Linear(input_dim, dim))
+            layers.append(nn.BatchNorm1d(dim))
+            layers.append(nn.ReLU(inplace=True))
+            if dropout_p is not None:
+                layers.append(nn.Dropout(p=dropout_p))
+            input_dim = dim
+
+        self.feature_dim = fc_dims[-1]
+
+        return nn.Sequential(*layers)
+
+    def _init_params(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                if m is final_conv:
-                    init.normal_(m.weight, mean=0.0, std=0.01)
-                else:
-                    init.kaiming_uniform_(m.weight)
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
-                    init.constant_(m.bias, 0)
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
 
     def forward(self, x):
         f = self.features(x)
+        v = self.global_avgpool(f)
+        v = v.view(v.size(0), -1)
+
+        if self.fc is not None:
+            v = self.fc(v)
 
         if not self.training:
-            v = F.adaptive_avg_pool2d(f, 1)
-            v = v.view(v.size(0), -1)
             return v
 
-        y = self.classifier(f)
-        y = y.view(y.size(0), self.num_classes)
+        y = self.classifier(v)
 
         if self.loss == {'xent'}:
             return y
         elif self.loss == {'xent', 'htri'}:
-            v = F.adaptive_avg_pool2d(f, 1)
-            v = v.view(v.size(0), -1)
             return y, v
         else:
             raise KeyError("Unsupported loss: {}".format(self.loss))
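A short usage sketch of the new forward contract (batch size and the standard 224x224 input resolution are assumed): in training mode the model returns logits, or logits plus the pooled embedding when loss is {'xent', 'htri'}; in eval mode it returns only the embedding, now taken after the optional fc head rather than straight off the feature map:

    import torch

    net = SqueezeNet(num_classes=751, loss={'xent', 'htri'}, version=1.0, fc_dims=[512])
    x = torch.rand(8, 3, 224, 224)

    net.train()
    y, v = net(x)        # logits (8, 751) and features (8, 512)

    net.eval()
    with torch.no_grad():
        v = net(x)       # features only, shape (8, 512)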
@@ -145,14 +181,39 @@ def init_pretrained_weights(model, model_url):
 
 
 def squeezenet1_0(num_classes, loss, pretrained=True, **kwargs):
-    model = SqueezeNet(num_classes, loss, version=1.0, **kwargs)
+    model = SqueezeNet(
+        num_classes, loss,
+        version=1.0,
+        fc_dims=None,
+        dropout_p=None,
+        **kwargs
+    )
     if pretrained:
         init_pretrained_weights(model, model_urls['squeezenet1_0'])
     return model
 
 
+def squeezenet1_0_fc512(num_classes, loss, pretrained=True, **kwargs):
+    model = SqueezeNet(
+        num_classes, loss,
+        version=1.0,
+        fc_dims=[512],
+        dropout_p=None,
+        **kwargs
+    )
+    if pretrained:
+        init_pretrained_weights(model, model_urls['squeezenet1_0'])
+    return model
+
+
 def squeezenet1_1(num_classes, loss, pretrained=True, **kwargs):
-    model = SqueezeNet(num_classes, loss, version=1.1, **kwargs)
+    model = SqueezeNet(
+        num_classes, loss,
+        version=1.1,
+        fc_dims=None,
+        dropout_p=None,
+        **kwargs
+    )
     if pretrained:
         init_pretrained_weights(model, model_urls['squeezenet1_1'])
     return model
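Finally, a hedged end-to-end sketch of the new entry point (num_classes and the pretrained flag are illustrative): squeezenet1_0_fc512 is squeezenet1_0 plus a single 512-d Linear-BatchNorm-ReLU embedding head, and both reuse the squeezenet1_0 ImageNet URL for pretraining:

    model = squeezenet1_0_fc512(num_classes=751, loss={'xent'}, pretrained=False)
    print(model.feature_dim)   # 512, set by fc_dims=[512]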