mirror of https://github.com/sthalles/SimCLR.git
fixed bug missing zero_grad()
parent
88dcdf6d06
commit
af77a03f0b
4
train.py
4
train.py
|
@ -34,7 +34,7 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=conf
|
|||
# model = Encoder(out_dim=out_dim)
|
||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||
|
||||
train_gpu = torch.cuda.is_available()
|
||||
train_gpu = False # torch.cuda.is_available()
|
||||
print("Is gpu available:", train_gpu)
|
||||
|
||||
# moves the model parameters to gpu
|
||||
|
@ -55,6 +55,8 @@ n_iter = 0
|
|||
for e in range(config['epochs']):
|
||||
for step, ((xis, xjs), _) in enumerate(train_loader):
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if train_gpu:
|
||||
xis = xis.cuda()
|
||||
xjs = xjs.cuda()
|
||||
|
|
Loading…
Reference in New Issue