added tensorboard support

This commit is contained in:
Thalles 2020-02-24 18:34:11 -03:00
parent ec0d482fc7
commit c138e77bdc

View File

@ -101,7 +101,7 @@ for e in range(config['epochs']):
l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
l_pos /= temperature
assert l_pos.shape == (batch_size, 1) # [N,1]
# assert l_pos.shape == (batch_size, 1) # [N,1]
negatives = torch.cat([zjs, zis], dim=0)
@ -124,8 +124,8 @@ for e in range(config['epochs']):
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
l_neg /= temperature
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(
l_neg.shape)
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(
# l_neg.shape)
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
loss += criterion(logits, labels)