update model

pull/17/head
KaiyangZhou 2018-04-26 17:59:33 +01:00
parent 31a88a7675
commit ac38a1c141
2 changed files with 16 additions and 2 deletions

View File

@ -10,7 +10,19 @@ Shorthands for loss:
- TripletLoss: htri
- CenterLoss: cent
"""
__all__ = ['CrossEntropyLabelSmooth', 'TripletLoss', 'CenterLoss']
__all__ = ['DeepSupervision', 'CrossEntropyLabelSmooth', 'TripletLoss', 'CenterLoss']
def DeepSupervision(criterion, xs, y):
"""
Args:
criterion: loss function
xs: tuple of inputs
y: ground truth
"""
loss = 0.
for x in xs:
loss += criterion(x, y)
return loss
class CrossEntropyLabelSmooth(nn.Module):
"""Cross entropy loss with label smoothing regularizer.

View File

@ -188,10 +188,11 @@ class HACNN(nn.Module):
feat_dim (int): feature dimension for each branch
learn_region (bool): whether to learn region features (i.e. local branch)
"""
def __init__(self, num_classes, loss={'xent'}, nchannels=[128, 256, 384], feat_dim=512, learn_region=True, **kwargs):
def __init__(self, num_classes, loss={'xent'}, nchannels=[128, 256, 384], feat_dim=512, learn_region=True, use_gpu=True, **kwargs):
super(HACNN, self).__init__()
self.loss = loss
self.learn_region = learn_region
self.use_gpu = use_gpu
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
@ -267,6 +268,7 @@ class HACNN(nn.Module):
theta = torch.zeros(theta_i.size(0), 2, 3)
theta[:,:,:2] = scale_factors
theta[:,:,-1] = theta_i
if self.use_gpu: theta = theta.cuda()
return theta
def forward(self, x):