mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
fixed bug missing zero_grad()
This commit is contained in:
parent
88dcdf6d06
commit
af77a03f0b
4
train.py
4
train.py
@ -34,7 +34,7 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=conf
|
||||
# model = Encoder(out_dim=out_dim)
|
||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||
|
||||
train_gpu = torch.cuda.is_available()
|
||||
train_gpu = False # torch.cuda.is_available()
|
||||
print("Is gpu available:", train_gpu)
|
||||
|
||||
# moves the model parameters to gpu
|
||||
@ -55,6 +55,8 @@ n_iter = 0
|
||||
for e in range(config['epochs']):
|
||||
for step, ((xis, xjs), _) in enumerate(train_loader):
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if train_gpu:
|
||||
xis = xis.cuda()
|
||||
xjs = xjs.cuda()
|
||||
|
Loading…
x
Reference in New Issue
Block a user