added weight decay for training

Thalles 2020-03-15 13:14:14 -03:00
parent 762a744662
commit f369c5eae1
2 changed files with 40 additions and 27 deletions


@@ -1,13 +1,13 @@
-batch_size: 256
-epochs: 40
+batch_size: 512
+epochs: 80
 eval_every_n_epochs: 1
 fine_tune_from: None
 log_every_n_steps: 50
 weight_decay: 10e-6

 model:
-  out_dim: 128
-  base_model: "resnet50"
+  out_dim: 256
+  base_model: "resnet18"

 dataset:
   s: 1
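The unchanged weight_decay: 10e-6 entry is why the trainer below calls eval(self.config['weight_decay']): under PyYAML's YAML 1.1 rules a scientific-notation literal needs a decimal point, so 10e-6 is loaded as a string rather than a float and has to be converted in Python. A minimal sketch of that behaviour, assuming PyYAML is the loader:

import yaml

# "10e-6" does not match PyYAML's float pattern (no decimal point),
# so it comes back as a plain string.
config = yaml.safe_load("weight_decay: 10e-6")
print(type(config['weight_decay']))    # <class 'str'>
print(float(config['weight_decay']))   # 1e-05; float() is a narrower cast than eval()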


@@ -49,7 +49,6 @@ class SimCLR(object):
             self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
             self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
             self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
-            self.writer.add_scalar('train_loss', loss, global_step=n_iter)

         return loss
@@ -62,6 +61,9 @@ class SimCLR(object):
         optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
+                                                               last_epoch=-1)
+
         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

         # save config file
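For reference, a minimal standalone sketch (with made-up sizes, not this repository's code) of the optimizer/scheduler pairing introduced above: Adam at a fixed 3e-4, plus cosine annealing that is only stepped after a warmup period, as in the training loop further down.

import torch

model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0, last_epoch=-1)

for epoch in range(20):
    # ... one epoch of training batches would run here ...
    if epoch >= 10:        # hold the base LR for the first 10 epochs
        scheduler.step()   # then decay the LR along the cosine curve
    print(epoch, optimizer.param_groups[0]['lr'])

Note that the diff sets T_max to len(train_loader) (batches per epoch) while the scheduler is stepped once per epoch, so the cosine cycle spans roughly len(train_loader) epochs after the warmup.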
@@ -80,35 +82,28 @@ class SimCLR(object):
                 loss = self._step(model, xis, xjs, n_iter)

+                if n_iter % self.config['log_every_n_steps'] == 0:
+                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
+
                 loss.backward()
                 optimizer.step()
                 n_iter += 1

+            # validate the model if requested
             if epoch_counter % self.config['eval_every_n_epochs'] == 0:
+                valid_loss = self._validate(model, valid_loader)
+                if valid_loss < best_valid_loss:
+                    # save the model weights
+                    best_valid_loss = valid_loss
+                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
+
+                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
+                valid_n_iter += 1

-                # validation steps
-                with torch.no_grad():
-                    model.eval()
-
-                    valid_loss = 0.0
-                    for counter, ((xis, xjs), _) in enumerate(valid_loader):
-                        xis = xis.to(self.device)
-                        xjs = xjs.to(self.device)
-
-                        loss = self._step(model, xis, xjs, n_iter)
-                        valid_loss += loss.item()
-                    valid_loss /= counter
-
-                    if valid_loss < best_valid_loss:
-                        # save the model weights
-                        best_valid_loss = valid_loss
-                        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
-
-                    self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
-                    valid_n_iter += 1
-
-                model.train()
+
+            # warmup for the first 10 epochs
+            if epoch_counter >= 10:
+                scheduler.step()
+            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
     def _load_pre_trained_weights(self, model):
         try:

@@ -120,3 +115,21 @@ class SimCLR(object):
             print("Pre-trained weights not found. Training from scratch.")

         return model
+    def _validate(self, model, valid_loader):
+
+        # validation steps
+        with torch.no_grad():
+            model.eval()
+
+            valid_loss = 0.0
+            for counter, ((xis, xjs), _) in enumerate(valid_loader):
+                xis = xis.to(self.device)
+                xjs = xjs.to(self.device)
+
+                loss = self._step(model, xis, xjs, counter)
+                valid_loss += loss.item()
+            valid_loss /= counter
+
+        model.train()
+        return valid_loss
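A small self-contained sketch of the averaging in the new _validate method, with hypothetical per-batch losses standing in for self._step(): enumerate() leaves counter at the last index, so the accumulated loss is divided by n_batches - 1 rather than n_batches.

batch_losses = [0.9, 0.8, 0.7, 0.6]    # pretend per-batch validation losses

valid_loss = 0.0
for counter, loss in enumerate(batch_losses):
    valid_loss += loss
valid_loss /= counter                  # counter ends at 3 for 4 batches

print(valid_loss)                      # 1.0, while the exact mean is 0.75

Dividing by counter + 1 (or len(valid_loader)) would give the exact mean; with a fixed number of validation batches the difference is a constant factor, so the best-checkpoint selection above is unaffected.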