fix small bugs

pull/5/head
Thalles 2020-03-12 22:34:21 -03:00
parent f78ee5d069
commit ca32c990ec
4 changed files with 27 additions and 37 deletions

@@ -8,4 +8,5 @@ epochs: 50
 num_workers: 0
 valid_size: 0.05
 eval_every_n_epochs: 2
-continue_training: None
+continue_training: Mar10_21-50-05_thallessilva
+log_every_n_steps: 50
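Note on this config change: the training script (further down in this commit) now checks `if config['continue_training']:` directly instead of `eval(config['continue_training'])`. In YAML, a bare `None` parses as the string "None", which is truthy, so the old value only behaved as "off" because of the `eval()` call; with a plain truthiness check, the key should be left empty (or set to `null`) to disable resuming. A minimal sketch of that parsing behaviour, assuming the config is read with PyYAML (the repo's own loading code is not part of this diff):

import yaml

# Old value: YAML has no keyword `None`, so this is the *string* "None" (truthy).
old_cfg = yaml.safe_load("continue_training: None")
print(repr(old_cfg["continue_training"]), bool(old_cfg["continue_training"]))  # 'None' True

# New value: a run folder name, also truthy, so the checkpoint branch runs.
new_cfg = yaml.safe_load("continue_training: Mar10_21-50-05_thallessilva")
print(bool(new_cfg["continue_training"]))  # True

# To skip resuming with a truthiness check, leave the value empty (parses to None).
off_cfg = yaml.safe_load("continue_training:")
print(off_cfg["continue_training"] is None)  # True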

@@ -34,7 +34,7 @@ class GaussianBlur(object):
         return sample


-def get_data_transform_opes(s, crop_size):
+def get_simclr_data_transform(s, crop_size):
     # get a set of data augmentation transformations as described in the SimCLR paper.
     color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
     data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),

@@ -13,7 +13,7 @@ import torch.nn.functional as F
 import numpy as np
 from models.resnet_simclr import ResNetSimCLR
 from utils import get_similarity_function, get_train_validation_data_loaders
-from data_aug.data_transform import DataTransform, get_data_transform_opes
+from data_aug.data_transform import DataTransform, get_simclr_data_transform

 torch.manual_seed(0)
 np.random.seed(0)
@@ -25,34 +25,34 @@ out_dim = config['out_dim']
 temperature = config['temperature']
 use_cosine_similarity = config['use_cosine_similarity']

-data_augment = get_data_transform_opes(s=config['s'], crop_size=96)
+data_augment = get_simclr_data_transform(s=config['s'], crop_size=96)

 train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment))

-train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, config)
+train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config)

 # model = Encoder(out_dim=out_dim)
 model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)

-if eval(config['continue_training']):
-    model_id = eval(config['continue_training'])
-    checkpoints_folder = os.path.join('./runs', model_id, 'checkpoints')
+if config['continue_training']:
+    checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
     state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
     model.load_state_dict(state_dict)
+    print("Loaded pre-trained model with success.")

 train_gpu = torch.cuda.is_available()
 print("Is gpu available:", train_gpu)

 # moves the model parameters to gpu
 if train_gpu:
-    model.cuda()
+    model = model.cuda()

 criterion = torch.nn.CrossEntropyLoss(reduction='sum')
 optimizer = optim.Adam(model.parameters(), 3e-4)

 train_writer = SummaryWriter()

-_, similarity_func = get_similarity_function(use_cosine_similarity)
+similarity_func = get_similarity_function(use_cosine_similarity)

 megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool)
 labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1,
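The resume branch above now keys directly off the truthiness of `config['continue_training']` and prints a confirmation once the checkpoint is loaded. A self-contained sketch of the same pattern; the stand-in `nn.Linear` model, the run folder name, and the pre-saved checkpoint are illustrative, not taken from the repo:

import os
import torch
import torch.nn as nn

# Create a fake run folder and checkpoint so the snippet runs end to end.
model = nn.Linear(512, 64)  # stand-in for ResNetSimCLR
run_id = 'Mar10_21-50-05_thallessilva'
checkpoints_folder = os.path.join('./runs', run_id, 'checkpoints')
os.makedirs(checkpoints_folder, exist_ok=True)
torch.save(model.state_dict(), os.path.join(checkpoints_folder, 'model.pth'))

config = {'continue_training': run_id}  # an empty string or None here skips resuming

if config['continue_training']:
    checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
    # map_location='cpu' lets the checkpoint load on machines without a GPU
    state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location='cpu')
    model.load_state_dict(state_dict)
    print("Loaded pre-trained model with success.")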
@@ -61,19 +61,21 @@ labels = torch.from_numpy(labels)
 softmax = torch.nn.Softmax(dim=-1)

 if train_gpu:
-    labels.cuda()
+    labels = labels.cuda()


 def step(xis, xjs):
     # get the representations and the projections
     ris, zis = model(xis)  # [N,C]
-    train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
-    train_writer.add_histogram("xi_latent", zis, global_step=n_iter)

     # get the representations and the projections
     rjs, zjs = model(xjs)  # [N,C]
-    train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
-    train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
+
+    if n_iter % config['log_every_n_steps'] == 0:
+        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
+        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
+        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
+        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

     # normalize projection feature vectors
     zis = F.normalize(zis, dim=1)
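Both `.cuda()` call sites in this file now assign the result back. For an `nn.Module`, `.cuda()` moves the parameters in place and returns `self`, so `model = model.cuda()` is mostly cosmetic; for a plain tensor, `.cuda()` returns a GPU copy and leaves the original untouched, so `labels = labels.cuda()` fixes a real bug. A short illustration (only meaningful on a machine with a CUDA device):

import torch
import torch.nn as nn

if torch.cuda.is_available():
    model = nn.Linear(4, 2)
    model.cuda()                              # in place: parameters move to the GPU
    print(next(model.parameters()).device)    # cuda:0

    labels = torch.zeros(8)
    labels.cuda()                             # returns a GPU copy that is discarded
    print(labels.device)                      # still cpu
    labels = labels.cuda()                    # tensors need the reassignment
    print(labels.device)                      # cuda:0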

@@ -7,24 +7,24 @@ np.random.seed(0)
 cosine_similarity = torch.nn.CosineSimilarity(dim=-1)


-def get_train_validation_data_loaders(train_dataset, config):
+def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored):
     # obtain training indices that will be used for validation
     num_train = len(train_dataset)
     indices = list(range(num_train))
     np.random.shuffle(indices)

-    split = int(np.floor(config['valid_size'] * num_train))
+    split = int(np.floor(valid_size * num_train))
     train_idx, valid_idx = indices[split:], indices[:split]

     # define samplers for obtaining training and validation batches
     train_sampler = SubsetRandomSampler(train_idx)
     valid_sampler = SubsetRandomSampler(valid_idx)

-    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=train_sampler,
-                              num_workers=config['num_workers'], drop_last=True, shuffle=False)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
+                              num_workers=num_workers, drop_last=True, shuffle=False)

-    valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
-                              num_workers=config['num_workers'],
-                              drop_last=True)
+    valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler,
+                              num_workers=num_workers, drop_last=True)
     return train_loader, valid_loader
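With the explicit keyword signature, the training script can call `get_train_validation_data_loaders(train_dataset, **config)` and let `**ignored` absorb config keys the function does not use (epochs, out_dim, temperature, and so on). A minimal sketch of that calling convention; the toy dataset and the config dict here are illustrative:

import torch
from torch.utils.data import TensorDataset

from utils import get_train_validation_data_loaders  # the helper defined above

config = {'batch_size': 4, 'num_workers': 0, 'valid_size': 0.05,
          'epochs': 50, 'temperature': 0.5}          # unused keys land in **ignored

toy_dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100))
train_loader, valid_loader = get_train_validation_data_loaders(toy_dataset, **config)
print(len(train_loader), len(valid_loader))          # 23 training batches, 1 validation batch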
@@ -39,14 +39,6 @@ def get_negative_mask(batch_size):
     return negative_mask


-def _dot_simililarity_dim1(x, y):
-    # x shape: (N, 1, C)
-    # y shape: (N, C, 1)
-    # v shape: (N, 1, 1)
-    v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))  #
-    return v

 def _dot_simililarity_dim2(x, y):
     v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
     # x shape: (N, 1, C)
@@ -55,11 +47,6 @@ def _dot_simililarity_dim2(x, y):
     return v


-def _cosine_simililarity_dim1(x, y):
-    v = cosine_similarity(x, y)
-    return v

 def _cosine_simililarity_dim2(x, y):
     # x shape: (N, 1, C)
     # y shape: (1, 2N, C)
@@ -70,6 +57,6 @@ def _cosine_simililarity_dim2(x, y):

 def get_similarity_function(use_cosine_similarity):
     if use_cosine_similarity:
-        return _cosine_simililarity_dim1, _cosine_simililarity_dim2
+        return _cosine_simililarity_dim2
     else:
-        return _dot_simililarity_dim1, _dot_simililarity_dim2
+        return _dot_simililarity_dim2
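`get_similarity_function` now returns a single callable, matching the simplified `similarity_func = get_similarity_function(use_cosine_similarity)` in the training script. For the dot-product variant, whose body appears above, the shapes resolve to an (N, 2N) similarity matrix equivalent to `x @ y.T`; a quick check with arbitrarily chosen N and C:

import torch

def _dot_simililarity_dim2(x, y):
    # x: (N, C), y: (2N, C) -> v: (N, 2N)
    v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
    return v

N, C = 4, 16
x = torch.randn(N, C)        # projections of one augmented view
y = torch.randn(2 * N, C)    # projections of both views concatenated
v = _dot_simililarity_dim2(x, y)
print(v.shape)                          # torch.Size([4, 8])
print(torch.allclose(v, x @ y.T))       # True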