mirror of https://github.com/sthalles/SimCLR.git
added tensorboard support
parent
f8ade33008
commit
3344165c75
|
@ -156,7 +156,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -177,7 +177,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -195,7 +195,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -298,7 +298,7 @@
|
|||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -323,7 +323,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -350,7 +350,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -377,7 +377,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -402,7 +402,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -410,8 +410,8 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"SimCLR feature evaluation\n",
|
||||
"Train score: 0.4858\n",
|
||||
"Test score: 0.37825\n"
|
||||
"Train score: 0.7444\n",
|
||||
"Test score: 0.62625\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
Binary file not shown.
5
train.py
5
train.py
|
@ -62,7 +62,7 @@ for i in range(batch_size):
|
|||
n_iter = 0
|
||||
for e in range(config['epochs']):
|
||||
for step, (batch_x, _) in enumerate(train_loader):
|
||||
# print("Input batch:", batch_x.shape, torch.min(batch_x), torch.max(batch_x))
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
xis = []
|
||||
|
@ -104,7 +104,8 @@ for e in range(config['epochs']):
|
|||
negatives = torch.cat([zjs, zis], dim=0)
|
||||
|
||||
if use_cosine_similarity:
|
||||
l_neg = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, (2 * batch_size), out_dim))
|
||||
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
||||
l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
|
||||
else:
|
||||
l_neg = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||
dims=2)
|
||||
|
|
Loading…
Reference in New Issue