Fixed bug: missing zero_grad() call

pull/5/head
Thalles 2020-02-29 19:51:00 -03:00
parent 88dcdf6d06
commit af77a03f0b
2 changed files with 3 additions and 4 deletions

@@ -34,7 +34,7 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=conf
 # model = Encoder(out_dim=out_dim)
 model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
-train_gpu = torch.cuda.is_available()
+train_gpu = False # torch.cuda.is_available()
 print("Is gpu available:", train_gpu)
 # moves the model parameters to gpu
@@ -55,6 +55,8 @@ n_iter = 0
 for e in range(config['epochs']):
     for step, ((xis, xjs), _) in enumerate(train_loader):
+        optimizer.zero_grad()
         if train_gpu:
             xis = xis.cuda()
             xjs = xjs.cuda()

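For context: PyTorch accumulates gradients across backward() calls, so a training loop that never calls optimizer.zero_grad() applies the sum of all previous gradients at every step, which is the bug this commit fixes. A minimal sketch of the corrected loop shape (the model, loss, and data below are illustrative stand-ins, not the repository's actual objects):

    import torch

    # Hypothetical stand-ins for the model, loss, optimizer, and loader built earlier in the script.
    model = torch.nn.Linear(8, 4)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    loader = [(torch.randn(16, 8), torch.randn(16, 4)) for _ in range(4)]

    for xis, target in loader:
        optimizer.zero_grad()                 # reset accumulated gradients (the line this commit adds)
        loss = criterion(model(xis), target)
        loss.backward()                       # without zero_grad(), these gradients would pile up across steps
        optimizer.step()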
@@ -1,8 +1,5 @@
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
np.random.seed(0)
cos1d = torch.nn.CosineSimilarity(dim=1)
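For reference, torch.nn.CosineSimilarity(dim=1) compares two batches of vectors row by row, which is the similarity measure SimCLR's contrastive loss is built on. A minimal usage sketch (the tensors are random placeholders, not real projections):

    import torch

    cos1d = torch.nn.CosineSimilarity(dim=1)

    # Pretend these are projection vectors for two augmented views of 4 images.
    zis = torch.randn(4, 128)
    zjs = torch.randn(4, 128)

    sims = cos1d(zis, zjs)   # shape (4,): one cosine similarity per matching row
    print(sims)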