diff --git a/config.yaml b/config.yaml index d8d9365..df64b41 100644 --- a/config.yaml +++ b/config.yaml @@ -8,4 +8,5 @@ epochs: 50 num_workers: 0 valid_size: 0.05 eval_every_n_epochs: 2 -continue_training: None \ No newline at end of file +continue_training: Mar10_21-50-05_thallessilva +log_every_n_steps: 50 diff --git a/data_aug/data_transform.py b/data_aug/data_transform.py index 7a49d09..aded7ef 100644 --- a/data_aug/data_transform.py +++ b/data_aug/data_transform.py @@ -34,7 +34,7 @@ class GaussianBlur(object): return sample -def get_data_transform_opes(s, crop_size): +def get_simclr_data_transform(s, crop_size): # get a set of data augmentation transformations as described in the SimCLR paper. color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size), diff --git a/train.py b/train.py index c11f252..c9d0894 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ import torch.nn.functional as F import numpy as np from models.resnet_simclr import ResNetSimCLR from utils import get_similarity_function, get_train_validation_data_loaders -from data_aug.data_transform import DataTransform, get_data_transform_opes +from data_aug.data_transform import DataTransform, get_simclr_data_transform torch.manual_seed(0) np.random.seed(0) @@ -25,34 +25,34 @@ out_dim = config['out_dim'] temperature = config['temperature'] use_cosine_similarity = config['use_cosine_similarity'] -data_augment = get_data_transform_opes(s=config['s'], crop_size=96) +data_augment = get_simclr_data_transform(s=config['s'], crop_size=96) train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment)) -train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, config) +train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config) # model = Encoder(out_dim=out_dim) model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim) -if eval(config['continue_training']): - model_id = eval(config['continue_training']) - checkpoints_folder = os.path.join('./runs', model_id, 'checkpoints') +if config['continue_training']: + checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints') state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth')) model.load_state_dict(state_dict) + print("Loaded pre-trained model with success.") train_gpu = torch.cuda.is_available() print("Is gpu available:", train_gpu) # moves the model parameters to gpu if train_gpu: - model.cuda() + model = model.cuda() criterion = torch.nn.CrossEntropyLoss(reduction='sum') optimizer = optim.Adam(model.parameters(), 3e-4) train_writer = SummaryWriter() -_, similarity_func = get_similarity_function(use_cosine_similarity) +similarity_func = get_similarity_function(use_cosine_similarity) megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool) labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1, @@ -61,19 +61,21 @@ labels = torch.from_numpy(labels) softmax = torch.nn.Softmax(dim=-1) if train_gpu: - labels.cuda() + labels = labels.cuda() def step(xis, xjs): # get the representations and the projections ris, zis = model(xis) # [N,C] - train_writer.add_histogram("xi_repr", ris, global_step=n_iter) - train_writer.add_histogram("xi_latent", zis, global_step=n_iter) # get the representations and the projections rjs, zjs = model(xjs) # [N,C] - train_writer.add_histogram("xj_repr", rjs, global_step=n_iter) - train_writer.add_histogram("xj_latent", zjs, global_step=n_iter) + + if n_iter % config['log_every_n_steps'] == 0: + train_writer.add_histogram("xi_repr", ris, global_step=n_iter) + train_writer.add_histogram("xi_latent", zis, global_step=n_iter) + train_writer.add_histogram("xj_repr", rjs, global_step=n_iter) + train_writer.add_histogram("xj_latent", zjs, global_step=n_iter) # normalize projection feature vectors zis = F.normalize(zis, dim=1) diff --git a/utils.py b/utils.py index 36d4f89..7759e1d 100644 --- a/utils.py +++ b/utils.py @@ -7,24 +7,24 @@ np.random.seed(0) cosine_similarity = torch.nn.CosineSimilarity(dim=-1) -def get_train_validation_data_loaders(train_dataset, config): +def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored): # obtain training indices that will be used for validation num_train = len(train_dataset) indices = list(range(num_train)) np.random.shuffle(indices) - split = int(np.floor(config['valid_size'] * num_train)) + + split = int(np.floor(valid_size * num_train)) train_idx, valid_idx = indices[split:], indices[:split] # define samplers for obtaining training and validation batches train_sampler = SubsetRandomSampler(train_idx) valid_sampler = SubsetRandomSampler(valid_idx) - train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=train_sampler, - num_workers=config['num_workers'], drop_last=True, shuffle=False) + train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, + num_workers=num_workers, drop_last=True, shuffle=False) - valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler, - num_workers=config['num_workers'], - drop_last=True) + valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, + num_workers=num_workers, drop_last=True) return train_loader, valid_loader @@ -39,14 +39,6 @@ def get_negative_mask(batch_size): return negative_mask -def _dot_simililarity_dim1(x, y): - # x shape: (N, 1, C) - # y shape: (N, C, 1) - # v shape: (N, 1, 1) - v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2)) # - return v - - def _dot_simililarity_dim2(x, y): v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) # x shape: (N, 1, C) @@ -55,11 +47,6 @@ def _dot_simililarity_dim2(x, y): return v -def _cosine_simililarity_dim1(x, y): - v = cosine_similarity(x, y) - return v - - def _cosine_simililarity_dim2(x, y): # x shape: (N, 1, C) # y shape: (1, 2N, C) @@ -70,6 +57,6 @@ def _cosine_simililarity_dim2(x, y): def get_similarity_function(use_cosine_similarity): if use_cosine_similarity: - return _cosine_simililarity_dim1, _cosine_simililarity_dim2 + return _cosine_simililarity_dim2 else: - return _dot_simililarity_dim1, _dot_simililarity_dim2 + return _dot_simililarity_dim2