mirror of https://github.com/sthalles/SimCLR.git
added model validation
parent
5a29f3c946
commit
907a048076
21
train.py
21
train.py
|
@ -93,6 +93,7 @@ if not os.path.exists(model_checkpoints_folder):
|
|||
os.makedirs(model_checkpoints_folder)
|
||||
|
||||
n_iter = 0
|
||||
valid_n_iter = 0
|
||||
best_valid_loss = np.inf
|
||||
|
||||
for epoch_counter in range(config['epochs']):
|
||||
|
@ -115,18 +116,24 @@ for epoch_counter in range(config['epochs']):
|
|||
# validation steps
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for (xis, xjs), _ in valid_loader:
|
||||
|
||||
valid_loss = 0.0
|
||||
for counter, ((xis, xjs), _) in enumerate(valid_loader):
|
||||
|
||||
if train_gpu:
|
||||
xis = xis.cuda()
|
||||
xjs = xjs.cuda()
|
||||
loss = (step(xis, xjs))
|
||||
valid_loss += loss.item()
|
||||
|
||||
loss = step(xis, xjs)
|
||||
valid_loss /= counter
|
||||
|
||||
if loss < best_valid_loss:
|
||||
# save the model weights
|
||||
best_valid_loss = loss
|
||||
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
||||
train_writer.add_scalar('validation_loss', loss, global_step=epoch_counter)
|
||||
if valid_loss < best_valid_loss:
|
||||
# save the model weights
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
||||
|
||||
train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
|
||||
valid_n_iter += 1
|
||||
|
||||
model.train()
|
||||
|
|
9
utils.py
9
utils.py
|
@ -4,8 +4,7 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.data.sampler import SubsetRandomSampler
|
||||
|
||||
np.random.seed(0)
|
||||
cos1d = torch.nn.CosineSimilarity(dim=1)
|
||||
cos2d = torch.nn.CosineSimilarity(dim=2)
|
||||
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
||||
|
||||
|
||||
def get_train_validation_data_loaders(train_dataset, config):
|
||||
|
@ -25,7 +24,7 @@ def get_train_validation_data_loaders(train_dataset, config):
|
|||
|
||||
valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
|
||||
num_workers=config['num_workers'],
|
||||
drop_last=False)
|
||||
drop_last=True)
|
||||
return train_loader, valid_loader
|
||||
|
||||
|
||||
|
@ -57,7 +56,7 @@ def _dot_simililarity_dim2(x, y):
|
|||
|
||||
|
||||
def _cosine_simililarity_dim1(x, y):
|
||||
v = cos1d(x, y)
|
||||
v = cosine_similarity(x, y)
|
||||
return v
|
||||
|
||||
|
||||
|
@ -65,7 +64,7 @@ def _cosine_simililarity_dim2(x, y):
|
|||
# x shape: (N, 1, C)
|
||||
# y shape: (1, 2N, C)
|
||||
# v shape: (N, 2N)
|
||||
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
|
||||
v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
||||
return v
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue