mirror of https://github.com/sthalles/SimCLR.git
added tensorboard support
parent
f8ade33008
commit
3344165c75
|
@ -156,7 +156,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 16,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -177,7 +177,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 17,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -195,7 +195,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 18,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -298,7 +298,7 @@
|
||||||
"<All keys matched successfully>"
|
"<All keys matched successfully>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 11,
|
"execution_count": 18,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
|
@ -323,7 +323,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 19,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -350,7 +350,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 20,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -377,7 +377,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 21,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -402,7 +402,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 22,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -410,8 +410,8 @@
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"SimCLR feature evaluation\n",
|
"SimCLR feature evaluation\n",
|
||||||
"Train score: 0.4858\n",
|
"Train score: 0.7444\n",
|
||||||
"Test score: 0.37825\n"
|
"Test score: 0.62625\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
Binary file not shown.
5
train.py
5
train.py
|
@ -62,7 +62,7 @@ for i in range(batch_size):
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
for e in range(config['epochs']):
|
for e in range(config['epochs']):
|
||||||
for step, (batch_x, _) in enumerate(train_loader):
|
for step, (batch_x, _) in enumerate(train_loader):
|
||||||
# print("Input batch:", batch_x.shape, torch.min(batch_x), torch.max(batch_x))
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
xis = []
|
xis = []
|
||||||
|
@ -104,7 +104,8 @@ for e in range(config['epochs']):
|
||||||
negatives = torch.cat([zjs, zis], dim=0)
|
negatives = torch.cat([zjs, zis], dim=0)
|
||||||
|
|
||||||
if use_cosine_similarity:
|
if use_cosine_similarity:
|
||||||
l_neg = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, (2 * batch_size), out_dim))
|
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
||||||
|
l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
|
||||||
else:
|
else:
|
||||||
l_neg = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
l_neg = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||||
dims=2)
|
dims=2)
|
||||||
|
|
Loading…
Reference in New Issue