From c138e77bdc6acc217b3cc273163bc7bee5b3486e Mon Sep 17 00:00:00 2001 From: Thalles Date: Mon, 24 Feb 2020 18:34:11 -0300 Subject: [PATCH] added tensorboard support --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index d9f36de..1161d9e 100644 --- a/train.py +++ b/train.py @@ -101,7 +101,7 @@ for e in range(config['epochs']): l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1) l_pos /= temperature - assert l_pos.shape == (batch_size, 1) # [N,1] + # assert l_pos.shape == (batch_size, 1) # [N,1] negatives = torch.cat([zjs, zis], dim=0) @@ -124,8 +124,8 @@ for e in range(config['epochs']): l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1) l_neg /= temperature - assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str( - l_neg.shape) + # assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str( + # l_neg.shape) logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1] loss += criterion(logits, labels)