update squeezenet
parent
bf806a987f
commit
b28375e870
|
@ -39,7 +39,15 @@ class ExpandLayer(nn.Module):
|
|||
return x
|
||||
|
||||
class FireModule(nn.Module):
|
||||
"""Fire Module"""
|
||||
"""
|
||||
Args:
|
||||
in_channels (int): number of input channels.
|
||||
s1_channels (int): number of 1-by-1 filters for squeeze layer.
|
||||
e1_channels (int): number of 1-by-1 filters for expand layer.
|
||||
e3_channels (int): number of 3-by-3 filters for expand layer.
|
||||
|
||||
Number of output channels from FireModule is e1_channels + e3_channels.
|
||||
"""
|
||||
def __init__(self, in_channels, s1_channels, e1_channels, e3_channels):
|
||||
super(FireModule, self).__init__()
|
||||
self.squeeze = ConvBlock(in_channels, s1_channels, 1)
|
||||
|
@ -51,9 +59,16 @@ class FireModule(nn.Module):
|
|||
return x
|
||||
|
||||
class SqueezeNet(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
"""SqueezeNet
|
||||
|
||||
Reference:
|
||||
Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
|
||||
and< 0.5 MB model size. arXiv:1602.07360.
|
||||
"""
|
||||
def __init__(self, num_classes, loss={'xent'}, bypass=True, **kwargs):
|
||||
super(SqueezeNet, self).__init__()
|
||||
self.loss = loss
|
||||
self.bypass = bypass
|
||||
|
||||
self.conv1 = ConvBlock(3, 96, 7, s=2, p=2)
|
||||
self.fire2 = FireModule(96, 16, 64, 64)
|
||||
|
@ -66,27 +81,43 @@ class SqueezeNet(nn.Module):
|
|||
self.fire9 = FireModule(512, 64, 256, 256)
|
||||
self.conv10 = ConvBlock(512, 1000, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.max_pool2d(x, 3, stride=2)
|
||||
x = self.fire2(x)
|
||||
x = self.fire3(x)
|
||||
x = self.fire4(x)
|
||||
x = F.max_pool2d(x, 3, stride=2)
|
||||
x = self.fire5(x)
|
||||
x = self.fire6(x)
|
||||
x = self.fire7(x)
|
||||
x = self.fire8(x)
|
||||
x = F.max_pool2d(x, 3, stride=2)
|
||||
x = self.fire9(x)
|
||||
x = self.conv10(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
return x
|
||||
self.classifier = nn.Linear(1000, num_classes)
|
||||
self.feat_dim = 1000
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = SqueezeNet(10)
|
||||
model.eval()
|
||||
x = torch.rand(1, 3, 256, 128)
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
print "output size {}".format(y.size())
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x1 = F.max_pool2d(x1, 3, stride=2)
|
||||
x2 = self.fire2(x1)
|
||||
x3 = self.fire3(x2)
|
||||
if self.bypass:
|
||||
x3 = x3 + x2
|
||||
x4 = self.fire4(x3)
|
||||
x4 = F.max_pool2d(x4, 3, stride=2)
|
||||
x5 = self.fire5(x4)
|
||||
if self.bypass:
|
||||
x5 = x5 + x4
|
||||
x6 = self.fire6(x5)
|
||||
x7 = self.fire7(x6)
|
||||
if self.bypass:
|
||||
x7 = x7 + x6
|
||||
x8 = self.fire8(x7)
|
||||
x8 = F.max_pool2d(x8, 3, stride=2)
|
||||
x9 = self.fire9(x8)
|
||||
if self.bypass:
|
||||
x9 = x9 + x8
|
||||
x10 = self.conv10(x9)
|
||||
f = F.avg_pool2d(x10, x10.size()[2:]).view(x10.size(0), -1)
|
||||
|
||||
if not self.training:
|
||||
return f
|
||||
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
elif self.loss == {'cent'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
Loading…
Reference in New Issue