mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 06:42:40 +08:00
fixed bug missing zero_grad()
This commit is contained in:
parent
88dcdf6d06
commit
af77a03f0b
4
train.py
4
train.py
@ -34,7 +34,7 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=conf
|
|||||||
# model = Encoder(out_dim=out_dim)
|
# model = Encoder(out_dim=out_dim)
|
||||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||||
|
|
||||||
train_gpu = torch.cuda.is_available()
|
train_gpu = False # torch.cuda.is_available()
|
||||||
print("Is gpu available:", train_gpu)
|
print("Is gpu available:", train_gpu)
|
||||||
|
|
||||||
# moves the model parameters to gpu
|
# moves the model parameters to gpu
|
||||||
@ -55,6 +55,8 @@ n_iter = 0
|
|||||||
for e in range(config['epochs']):
|
for e in range(config['epochs']):
|
||||||
for step, ((xis, xjs), _) in enumerate(train_loader):
|
for step, ((xis, xjs), _) in enumerate(train_loader):
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if train_gpu:
|
if train_gpu:
|
||||||
xis = xis.cuda()
|
xis = xis.cuda()
|
||||||
xjs = xjs.cuda()
|
xjs = xjs.cuda()
|
||||||
|
3
utils.py
3
utils.py
@ -1,8 +1,5 @@
|
|||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as transforms
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
cos1d = torch.nn.CosineSimilarity(dim=1)
|
cos1d = torch.nn.CosineSimilarity(dim=1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user