added model validation

pull/5/head
Thalles 2020-03-10 20:41:59 -03:00
parent 5a29f3c946
commit 907a048076
2 changed files with 18 additions and 12 deletions

View File

@ -93,6 +93,7 @@ if not os.path.exists(model_checkpoints_folder):
os.makedirs(model_checkpoints_folder)
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf
for epoch_counter in range(config['epochs']):
@ -115,18 +116,24 @@ for epoch_counter in range(config['epochs']):
# validation steps
with torch.no_grad():
model.eval()
for (xis, xjs), _ in valid_loader:
valid_loss = 0.0
for counter, ((xis, xjs), _) in enumerate(valid_loader):
if train_gpu:
xis = xis.cuda()
xjs = xjs.cuda()
loss = (step(xis, xjs))
valid_loss += loss.item()
loss = step(xis, xjs)
valid_loss /= counter
if loss < best_valid_loss:
# save the model weights
best_valid_loss = loss
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
train_writer.add_scalar('validation_loss', loss, global_step=epoch_counter)
if valid_loss < best_valid_loss:
# save the model weights
best_valid_loss = valid_loss
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
valid_n_iter += 1
model.train()

View File

@ -4,8 +4,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
np.random.seed(0)
cos1d = torch.nn.CosineSimilarity(dim=1)
cos2d = torch.nn.CosineSimilarity(dim=2)
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
def get_train_validation_data_loaders(train_dataset, config):
@ -25,7 +24,7 @@ def get_train_validation_data_loaders(train_dataset, config):
valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
num_workers=config['num_workers'],
drop_last=False)
drop_last=True)
return train_loader, valid_loader
@ -57,7 +56,7 @@ def _dot_simililarity_dim2(x, y):
def _cosine_simililarity_dim1(x, y):
v = cos1d(x, y)
v = cosine_similarity(x, y)
return v
@ -65,7 +64,7 @@ def _cosine_simililarity_dim2(x, y):
# x shape: (N, 1, C)
# y shape: (1, 2N, C)
# v shape: (N, 2N)
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
return v