mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added tensorboard support
This commit is contained in:
parent
3344165c75
commit
0d7ff950cb
@ -420,8 +420,8 @@
|
|||||||
"print(\"Train score:\", clf.score(X_train_feature, y_train))\n",
|
"print(\"Train score:\", clf.score(X_train_feature, y_train))\n",
|
||||||
"print(\"Test score:\", clf.score(X_test_feature, y_test))\n",
|
"print(\"Test score:\", clf.score(X_test_feature, y_test))\n",
|
||||||
"# SimCLR feature evaluation\n",
|
"# SimCLR feature evaluation\n",
|
||||||
"# Train score: 0.5298\n",
|
"# Train score: 0.7444\n",
|
||||||
"# Test score: 0.52075"
|
"# Test score: 0.62625"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -3,4 +3,4 @@ out_dim: 64
|
|||||||
s: 1
|
s: 1
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
use_cosine_similarity: True
|
use_cosine_similarity: True
|
||||||
epochs: 20
|
epochs: 30
|
Binary file not shown.
25
train.py
25
train.py
@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from yaml import load
|
import yaml
|
||||||
|
|
||||||
print(torch.__version__)
|
print(torch.__version__)
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
|||||||
from model import Encoder, ResNet18
|
from model import Encoder, ResNet18
|
||||||
from utils import GaussianBlur
|
from utils import GaussianBlur
|
||||||
|
|
||||||
config = load(open("config.ymal", "r"))
|
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
|
||||||
|
|
||||||
batch_size = config['batch_size']
|
batch_size = config['batch_size']
|
||||||
out_dim = config['out_dim']
|
out_dim = config['out_dim']
|
||||||
@ -44,7 +44,7 @@ print("Is gpu available:", train_gpu)
|
|||||||
if train_gpu:
|
if train_gpu:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
|
||||||
optimizer = optim.Adam(model.parameters(), 3e-4)
|
optimizer = optim.Adam(model.parameters(), 3e-4)
|
||||||
|
|
||||||
train_writer = SummaryWriter()
|
train_writer = SummaryWriter()
|
||||||
@ -106,22 +106,27 @@ for e in range(config['epochs']):
|
|||||||
if use_cosine_similarity:
|
if use_cosine_similarity:
|
||||||
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
||||||
l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
|
l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
|
||||||
|
l_neg_2 = similarity_dim2(zjs.view(batch_size, 1, out_dim), negatives)
|
||||||
else:
|
else:
|
||||||
l_neg = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
l_neg_1 = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||||
|
dims=2)
|
||||||
|
l_neg_2 = torch.tensordot(zjs.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||||
dims=2)
|
dims=2)
|
||||||
|
|
||||||
|
labels = torch.zeros(batch_size, dtype=torch.long)
|
||||||
|
if train_gpu:
|
||||||
|
labels = labels.cuda()
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
for l_neg in [l_neg_1, l_neg_2]:
|
||||||
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
|
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
|
||||||
l_neg /= temperature
|
l_neg /= temperature
|
||||||
|
|
||||||
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
||||||
|
|
||||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||||
labels = torch.zeros(batch_size, dtype=torch.long)
|
loss += criterion(logits, labels)
|
||||||
|
|
||||||
if train_gpu:
|
loss = loss / (2 * batch_size)
|
||||||
labels = labels.cuda()
|
|
||||||
|
|
||||||
loss = criterion(logits, labels)
|
|
||||||
train_writer.add_scalar('loss', loss, global_step=n_iter)
|
train_writer.add_scalar('loss', loss, global_step=n_iter)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user