mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added tensorboard support
This commit is contained in:
parent
ec0d482fc7
commit
c138e77bdc
6
train.py
6
train.py
@ -101,7 +101,7 @@ for e in range(config['epochs']):
|
||||
l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
||||
|
||||
l_pos /= temperature
|
||||
assert l_pos.shape == (batch_size, 1) # [N,1]
|
||||
# assert l_pos.shape == (batch_size, 1) # [N,1]
|
||||
|
||||
negatives = torch.cat([zjs, zis], dim=0)
|
||||
|
||||
@ -124,8 +124,8 @@ for e in range(config['epochs']):
|
||||
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
|
||||
l_neg /= temperature
|
||||
|
||||
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(
|
||||
l_neg.shape)
|
||||
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(
|
||||
# l_neg.shape)
|
||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||
loss += criterion(logits, labels)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user