update squeezenet

pull/17/head
KaiyangZhou 2018-04-28 11:11:43 +01:00
parent bf806a987f
commit b28375e870
1 changed files with 56 additions and 25 deletions

View File

@ -39,7 +39,15 @@ class ExpandLayer(nn.Module):
return x
class FireModule(nn.Module):
"""Fire Module"""
"""
Args:
in_channels (int): number of input channels.
s1_channels (int): number of 1-by-1 filters for squeeze layer.
e1_channels (int): number of 1-by-1 filters for expand layer.
e3_channels (int): number of 3-by-3 filters for expand layer.
Number of output channels from FireModule is e1_channels + e3_channels.
"""
def __init__(self, in_channels, s1_channels, e1_channels, e3_channels):
super(FireModule, self).__init__()
self.squeeze = ConvBlock(in_channels, s1_channels, 1)
@ -51,9 +59,16 @@ class FireModule(nn.Module):
return x
class SqueezeNet(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
"""SqueezeNet
Reference:
Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
and <0.5MB model size. arXiv:1602.07360.
"""
def __init__(self, num_classes, loss={'xent'}, bypass=True, **kwargs):
super(SqueezeNet, self).__init__()
self.loss = loss
self.bypass = bypass
self.conv1 = ConvBlock(3, 96, 7, s=2, p=2)
self.fire2 = FireModule(96, 16, 64, 64)
@ -66,27 +81,43 @@ class SqueezeNet(nn.Module):
self.fire9 = FireModule(512, 64, 256, 256)
self.conv10 = ConvBlock(512, 1000, 1)
def forward(self, x):
    """Run the input through the convolutional trunk and pool globally.

    The layers are applied in three stages, each followed by a 3x3
    max-pool with stride 2:

    1. ``conv1``
    2. ``fire2`` -> ``fire3`` -> ``fire4``
    3. ``fire5`` -> ``fire6`` -> ``fire7`` -> ``fire8``

    ``fire9`` and ``conv10`` follow, and the result is average-pooled
    over its full spatial extent.

    Args:
        x: input tensor; presumably a (batch, 3, H, W) image batch,
            since conv1 is built with 3 input channels -- TODO confirm.

    Returns:
        The output of ``conv10`` after average pooling with a kernel
        equal to its spatial size (dimensions 2 onward).
    """
    stages = (
        (self.conv1,),
        (self.fire2, self.fire3, self.fire4),
        (self.fire5, self.fire6, self.fire7, self.fire8),
    )

    out = x
    for stage in stages:
        for layer in stage:
            out = layer(out)
        # Every stage ends with the same spatial down-sampling step.
        out = F.max_pool2d(out, 3, stride=2)

    out = self.conv10(self.fire9(out))
    # Kernel size == remaining spatial size, i.e. global average pooling.
    return F.avg_pool2d(out, out.size()[2:])
self.classifier = nn.Linear(1000, num_classes)
self.feat_dim = 1000
if __name__ == '__main__':
    # Smoke test: build a 10-class SqueezeNet, switch it to eval mode and
    # push one random 256x128 RGB-sized tensor through it without
    # tracking gradients.
    model = SqueezeNet(10)
    model.eval()
    x = torch.rand(1, 3, 256, 128)
    with torch.no_grad():
        y = model(x)
    # Use the call form of print: the bare `print "..."` statement is a
    # SyntaxError on Python 3, while this form works on Python 2 and 3.
    print("output size {}".format(y.size()))
def forward(self, x):
    """Compute features and, in training mode, class scores.

    The trunk is ``conv1`` -> max-pool -> ``fire2``..``fire4`` ->
    max-pool -> ``fire5``..``fire8`` -> max-pool -> ``fire9`` ->
    ``conv10``, followed by average pooling over the full spatial
    extent and flattening to ``(batch, -1)``.

    When ``self.bypass`` is truthy, the input of ``fire3``, ``fire5``,
    ``fire7`` and ``fire9`` is added element-wise to that module's
    output.

    Args:
        x: input tensor; presumably a (batch, 3, H, W) image batch,
            since conv1 is built with 3 input channels -- TODO confirm.

    Returns:
        * not ``self.training``: the flattened feature tensor ``f``.
        * training with ``loss == {'xent'}``: classifier output ``y``.
        * training with ``loss == {'xent', 'htri'}`` or ``{'cent'}``:
          the tuple ``(y, f)``.

    Raises:
        KeyError: in training mode, if ``self.loss`` is none of the
            supported sets above.
    """
    def with_bypass(fire, inp):
        # Apply `fire`; optionally add its input back onto its output.
        out = fire(inp)
        if self.bypass:
            out = out + inp
        return out

    feat = F.max_pool2d(self.conv1(x), 3, stride=2)

    feat = self.fire2(feat)
    feat = with_bypass(self.fire3, feat)
    feat = F.max_pool2d(self.fire4(feat), 3, stride=2)

    feat = with_bypass(self.fire5, feat)
    feat = self.fire6(feat)
    feat = with_bypass(self.fire7, feat)
    feat = F.max_pool2d(self.fire8(feat), 3, stride=2)

    feat = with_bypass(self.fire9, feat)
    feat = self.conv10(feat)

    # Kernel size == remaining spatial size, then flatten per sample.
    pooled = F.avg_pool2d(feat, feat.size()[2:])
    f = pooled.view(feat.size(0), -1)

    if not self.training:
        return f

    y = self.classifier(f)

    if self.loss == {'xent'}:
        return y
    if self.loss in ({'xent', 'htri'}, {'cent'}):
        return y, f
    raise KeyError("Unsupported loss: {}".format(self.loss))