mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added model validation
This commit is contained in:
parent
5a29f3c946
commit
907a048076
21
train.py
21
train.py
@ -93,6 +93,7 @@ if not os.path.exists(model_checkpoints_folder):
|
|||||||
os.makedirs(model_checkpoints_folder)
|
os.makedirs(model_checkpoints_folder)
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
|
valid_n_iter = 0
|
||||||
best_valid_loss = np.inf
|
best_valid_loss = np.inf
|
||||||
|
|
||||||
for epoch_counter in range(config['epochs']):
|
for epoch_counter in range(config['epochs']):
|
||||||
@ -115,18 +116,24 @@ for epoch_counter in range(config['epochs']):
|
|||||||
# validation steps
|
# validation steps
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
for (xis, xjs), _ in valid_loader:
|
|
||||||
|
valid_loss = 0.0
|
||||||
|
for counter, ((xis, xjs), _) in enumerate(valid_loader):
|
||||||
|
|
||||||
if train_gpu:
|
if train_gpu:
|
||||||
xis = xis.cuda()
|
xis = xis.cuda()
|
||||||
xjs = xjs.cuda()
|
xjs = xjs.cuda()
|
||||||
|
loss = (step(xis, xjs))
|
||||||
|
valid_loss += loss.item()
|
||||||
|
|
||||||
loss = step(xis, xjs)
|
valid_loss /= counter
|
||||||
|
|
||||||
if loss < best_valid_loss:
|
if valid_loss < best_valid_loss:
|
||||||
# save the model weights
|
# save the model weights
|
||||||
best_valid_loss = loss
|
best_valid_loss = valid_loss
|
||||||
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
||||||
train_writer.add_scalar('validation_loss', loss, global_step=epoch_counter)
|
|
||||||
|
train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
|
||||||
|
valid_n_iter += 1
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
9
utils.py
9
utils.py
@ -4,8 +4,7 @@ from torch.utils.data import DataLoader
|
|||||||
from torch.utils.data.sampler import SubsetRandomSampler
|
from torch.utils.data.sampler import SubsetRandomSampler
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
cos1d = torch.nn.CosineSimilarity(dim=1)
|
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
||||||
cos2d = torch.nn.CosineSimilarity(dim=2)
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_validation_data_loaders(train_dataset, config):
|
def get_train_validation_data_loaders(train_dataset, config):
|
||||||
@ -25,7 +24,7 @@ def get_train_validation_data_loaders(train_dataset, config):
|
|||||||
|
|
||||||
valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
|
valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
|
||||||
num_workers=config['num_workers'],
|
num_workers=config['num_workers'],
|
||||||
drop_last=False)
|
drop_last=True)
|
||||||
return train_loader, valid_loader
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +56,7 @@ def _dot_simililarity_dim2(x, y):
|
|||||||
|
|
||||||
|
|
||||||
def _cosine_simililarity_dim1(x, y):
|
def _cosine_simililarity_dim1(x, y):
|
||||||
v = cos1d(x, y)
|
v = cosine_similarity(x, y)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
@ -65,7 +64,7 @@ def _cosine_simililarity_dim2(x, y):
|
|||||||
# x shape: (N, 1, C)
|
# x shape: (N, 1, C)
|
||||||
# y shape: (1, 2N, C)
|
# y shape: (1, 2N, C)
|
||||||
# v shape: (N, 2N)
|
# v shape: (N, 2N)
|
||||||
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
|
v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user