mirror of https://github.com/sthalles/SimCLR.git
complete the loss function description from the paper
parent
289f9e54b0
commit
a7af53f845
7
train.py
7
train.py
|
@ -9,7 +9,7 @@ from torchvision import datasets
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import Encoder
|
||||
from model import ResNet18
|
||||
from utils import GaussianBlur
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
@ -35,12 +35,13 @@ data_augment = transforms.Compose([transforms.ToPILImage(),
|
|||
train_dataset = datasets.STL10('data', split='train', download=True, transform=transforms.ToTensor())
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
|
||||
|
||||
model = Encoder(out_dim=out_dim)
|
||||
# model = ResNet18(out_dim=out_dim)
|
||||
# model = Encoder(out_dim=out_dim)
|
||||
model = ResNet18(out_dim=out_dim)
|
||||
print(model)
|
||||
|
||||
train_gpu = torch.cuda.is_available()
|
||||
print("Is gpu available:", train_gpu)
|
||||
|
||||
# moves the model paramemeters to gpu
|
||||
if train_gpu:
|
||||
model.cuda()
|
||||
|
|
Loading…
Reference in New Issue