diff --git a/train.py b/train.py index d9f36de..1161d9e 100644 --- a/train.py +++ b/train.py @@ -101,7 +101,7 @@ for e in range(config['epochs']): l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1) l_pos /= temperature - assert l_pos.shape == (batch_size, 1) # [N,1] + # assert l_pos.shape == (batch_size, 1) # [N,1] negatives = torch.cat([zjs, zis], dim=0) @@ -124,8 +124,8 @@ for e in range(config['epochs']): l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1) l_neg /= temperature - assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str( - l_neg.shape) + # assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str( + # l_neg.shape) logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1] loss += criterion(logits, labels)