complete the loss function description from the paper

pull/5/head
Thalles 2020-02-20 11:04:09 -03:00 committed by Thalles Silva
parent 289f9e54b0
commit a7af53f845
1 changed files with 4 additions and 3 deletions

View File

@ -9,7 +9,7 @@ from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from model import Encoder
from model import ResNet18
from utils import GaussianBlur
torch.manual_seed(0)
@ -35,12 +35,13 @@ data_augment = transforms.Compose([transforms.ToPILImage(),
train_dataset = datasets.STL10('data', split='train', download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
model = Encoder(out_dim=out_dim)
# model = ResNet18(out_dim=out_dim)
# model = Encoder(out_dim=out_dim)
model = ResNet18(out_dim=out_dim)
print(model)
train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# moves the model paramemeters to gpu
if train_gpu:
model.cuda()