mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added tensorboard support
This commit is contained in:
parent
3344165c75
commit
0d7ff950cb
@ -420,8 +420,8 @@
|
||||
"print(\"Train score:\", clf.score(X_train_feature, y_train))\n",
|
||||
"print(\"Test score:\", clf.score(X_test_feature, y_test))\n",
|
||||
"# SimCLR feature evaluation\n",
|
||||
"# Train score: 0.5298\n",
|
||||
"# Test score: 0.52075"
|
||||
"# Train score: 0.7444\n",
|
||||
"# Test score: 0.62625"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -3,4 +3,4 @@ out_dim: 64
|
||||
s: 1
|
||||
temperature: 0.5
|
||||
use_cosine_similarity: True
|
||||
epochs: 20
|
||||
epochs: 30
|
Binary file not shown.
31
train.py
31
train.py
@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from yaml import load
|
||||
import yaml
|
||||
|
||||
print(torch.__version__)
|
||||
import torch.optim as optim
|
||||
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
||||
from model import Encoder, ResNet18
|
||||
from utils import GaussianBlur
|
||||
|
||||
config = load(open("config.ymal", "r"))
|
||||
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
|
||||
|
||||
batch_size = config['batch_size']
|
||||
out_dim = config['out_dim']
|
||||
@ -44,7 +44,7 @@ print("Is gpu available:", train_gpu)
|
||||
if train_gpu:
|
||||
model.cuda()
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
|
||||
optimizer = optim.Adam(model.parameters(), 3e-4)
|
||||
|
||||
train_writer = SummaryWriter()
|
||||
@ -106,22 +106,27 @@ for e in range(config['epochs']):
|
||||
if use_cosine_similarity:
|
||||
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
||||
l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
|
||||
l_neg_2 = similarity_dim2(zjs.view(batch_size, 1, out_dim), negatives)
|
||||
else:
|
||||
l_neg = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||
dims=2)
|
||||
l_neg_1 = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||
dims=2)
|
||||
l_neg_2 = torch.tensordot(zjs.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||
dims=2)
|
||||
|
||||
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
|
||||
l_neg /= temperature
|
||||
|
||||
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
||||
|
||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||
labels = torch.zeros(batch_size, dtype=torch.long)
|
||||
|
||||
if train_gpu:
|
||||
labels = labels.cuda()
|
||||
|
||||
loss = criterion(logits, labels)
|
||||
loss = 0
|
||||
for l_neg in [l_neg_1, l_neg_2]:
|
||||
l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
|
||||
l_neg /= temperature
|
||||
|
||||
# assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||
loss += criterion(logits, labels)
|
||||
|
||||
loss = loss / (2 * batch_size)
|
||||
train_writer.add_scalar('loss', loss, global_step=n_iter)
|
||||
|
||||
loss.backward()
|
||||
|
Loading…
x
Reference in New Issue
Block a user