Mirror of https://github.com/sthalles/SimCLR.git (synced 2025-06-03 15:03:00 +08:00)

commit f369c5eae1
parent 762a744662

    added weight decay for training

config.yaml
@@ -1,13 +1,13 @@
-batch_size: 256
-epochs: 40
+batch_size: 512
+epochs: 80
 eval_every_n_epochs: 1
 fine_tune_from: None
 log_every_n_steps: 50
 weight_decay: 10e-6

 model:
-  out_dim: 128
-  base_model: "resnet50"
+  out_dim: 256
+  base_model: "resnet18"

 dataset:
   s: 1
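
A note on how these values parse (an observation about the config, not part of the commit): PyYAML only recognizes a scalar as a float if it contains a decimal point, so "weight_decay: 10e-6" loads as the string '10e-6', which is presumably why simclr.py below runs eval() over self.config['weight_decay']. Note also that 10e-6 equals 1e-5, not 1e-6, and that "fine_tune_from: None" loads as the string 'None', not Python's None. A minimal sketch, assuming the repo loads the file with PyYAML:

import yaml

# PyYAML's float resolver requires a decimal point, so 10e-6 comes back
# as a string rather than a float.
config = yaml.safe_load("weight_decay: 10e-6\nbatch_size: 512\n")
print(type(config['weight_decay']))   # <class 'str'>
print(type(config['batch_size']))     # <class 'int'>

# float() parses it safely (the repo's eval() also works); the value is
# 1e-05, i.e. 10e-6 means 1e-5, not 1e-6.
print(float(config['weight_decay']))  # 1e-05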

simclr.py (59 changed lines)
@@ -49,7 +49,6 @@ class SimCLR(object):
         self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
         self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
         self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
-        self.writer.add_scalar('train_loss', loss, global_step=n_iter)

         return loss

@@ -62,6 +61,9 @@ class SimCLR(object):

         optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))

+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
+                                                               last_epoch=-1)
+
         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

         # save config file
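
For context on the scheduler added here (a sketch, not the commit's code): CosineAnnealingLR counts T_max in scheduler.step() calls, so its meaning depends on whether step() runs per batch or per epoch. The loop below steps it once per epoch, so T_max=len(train_loader) spreads one cosine half-period over that many epochs, tying the decay horizon to dataset size and batch size rather than to the epoch count. A minimal, runnable illustration of the decay itself (get_last_lr() is the current accessor; the commit's get_lr() was the API of its day):

import torch

# Stand-in model and optimizer mirroring the Adam settings above.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=1e-5)

# The learning rate traces half a cosine wave from 3e-4 down to eta_min
# over T_max calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

for step in range(10):
    optimizer.step()  # a real training step would go here
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])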

@@ -80,35 +82,28 @@ class SimCLR(object):

                 loss = self._step(model, xis, xjs, n_iter)

+                if n_iter % self.config['log_every_n_steps'] == 0:
+                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
+
                 loss.backward()
                 optimizer.step()
                 n_iter += 1

+            # validate the model if requested
             if epoch_counter % self.config['eval_every_n_epochs'] == 0:
-                # validation steps
-                with torch.no_grad():
-                    model.eval()
-
-                    valid_loss = 0.0
-                    for counter, ((xis, xjs), _) in enumerate(valid_loader):
-                        xis = xis.to(self.device)
-                        xjs = xjs.to(self.device)
-
-                        loss = self._step(model, xis, xjs, n_iter)
-                        valid_loss += loss.item()
-
-                    valid_loss /= counter
-
-                if valid_loss < best_valid_loss:
-                    # save the model weights
-                    best_valid_loss = valid_loss
-                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
-
-                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
-                valid_n_iter += 1
-
-                model.train()
-
+                valid_loss = self._validate(model, valid_loader)
+                if valid_loss < best_valid_loss:
+                    # save the model weights
+                    best_valid_loss = valid_loss
+                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
+
+                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
+                valid_n_iter += 1
+
+            # warmup for the first 10 epochs
+            if epoch_counter >= 10:
+                scheduler.step()
+            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)

     def _load_pre_trained_weights(self, model):
         try:
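
The warmup gate added at the bottom of this hunk keeps the learning rate flat for the first 10 epochs by simply not stepping the scheduler. A hedged, self-contained sketch of that pattern (T_max here is chosen as the number of post-warmup epochs for clarity of the printout; the commit itself uses len(train_loader)):

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), 3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=0)

for epoch_counter in range(40):
    optimizer.step()  # an epoch of training would go here

    # warmup for the first 10 epochs: the LR stays at 3e-4 because
    # scheduler.step() is never called; afterwards it decays to eta_min.
    if epoch_counter >= 10:
        scheduler.step()
    print(epoch_counter, scheduler.get_last_lr()[0])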

@@ -120,3 +115,21 @@ class SimCLR(object):
             print("Pre-trained weights not found. Training from scratch.")

         return model
+
+    def _validate(self, model, valid_loader):
+
+        # validation steps
+        with torch.no_grad():
+            model.eval()
+
+            valid_loss = 0.0
+            for counter, ((xis, xjs), _) in enumerate(valid_loader):
+                xis = xis.to(self.device)
+                xjs = xjs.to(self.device)
+
+                loss = self._step(model, xis, xjs, counter)
+                valid_loss += loss.item()
+
+            valid_loss /= counter
+        model.train()
+        return valid_loss
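
One detail worth flagging in the new helper (an observation, not a change to the commit): enumerate() leaves counter at len(valid_loader) - 1 after the loop, so "valid_loss /= counter" divides by one less than the number of batches (and divides by zero for a single-batch loader). A hedged sketch of an equivalent pass that divides by the batch count instead; validate and step_fn are hypothetical names, not the repo's API:

import torch

def validate(model, valid_loader, device, step_fn):
    # Run one pass over the validation set and return the mean loss.
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for (xis, xjs), _ in valid_loader:
            xis, xjs = xis.to(device), xjs.to(device)
            valid_loss += step_fn(model, xis, xjs).item()
    model.train()
    # Divide by the true number of batches rather than the last index.
    return valid_loss / len(valid_loader)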