Major refactor

pull/24/head
Thalles Silva 2021-01-17 14:12:17 -03:00
parent e8a690ae4f
commit 2c9536f731
11 changed files with 263 additions and 340 deletions

config.yaml

@@ -1,21 +0,0 @@
batch_size: 512
epochs: 80
eval_every_n_epochs: 1
fine_tune_from: None
log_every_n_steps: 50
weight_decay: 10e-6
fp16_precision: False

model:
  out_dim: 256
  base_model: "resnet18"

dataset:
  s: 1
  input_shape: (96,96,3)
  num_workers: 0
  valid_size: 0.05

loss:
  temperature: 0.5
  use_cosine_similarity: True
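
For orientation, the deleted config.yaml was consumed roughly like this by the pre-refactor code further down in this diff (a minimal sketch; the keys and constructors are the ones visible here):

    import yaml

    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
    # Nested sections were splatted into constructors, e.g.
    #   DataSetWrapper(config['batch_size'], **config['dataset'])
    #   NTXentLoss(device, config['batch_size'], **config['loss'])
    print(config['model']['base_model'], config['loss']['temperature'])  # resnet18 0.5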

data_aug/contrastive_learning_dataset.py

@@ -0,0 +1,37 @@
from torchvision import transforms, datasets

from data_aug.gaussian_blur import GaussianBlur
from data_aug.view_generator import ContrastiveLearningViewGenerator


class ContrastiveLearningDataset:
    def __init__(self, root_folder):
        self.root_folder = root_folder

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a set of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * size)),
                                              transforms.ToTensor()])
        return data_transforms

    def get_dataset(self, name, n_views):
        valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                              transform=ContrastiveLearningViewGenerator(
                                                                  self.get_simclr_pipeline_transform(32),
                                                                  n_views),
                                                              download=True),

                          'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(96),
                                                              n_views),
                                                          download=True)}

        try:
            dataset_fn = valid_datasets[name]
        except KeyError:
            raise ValueError('Invalid dataset option: {}'.format(name))
        return dataset_fn()
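
A minimal usage sketch for the new dataset class (the root folder is illustrative; each sample comes back as a list of n_views augmented tensors):

    from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset

    dataset = ContrastiveLearningDataset('./datasets')       # any writable root folder
    train_dataset = dataset.get_dataset('stl10', n_views=2)  # downloads the STL10 'unlabeled' split

    views, _ = train_dataset[0]
    print(len(views), views[0].shape)                        # 2 torch.Size([3, 96, 96])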

data_aug/dataset_wrapper.py

@@ -1,68 +0,0 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets

np.random.seed(0)


class DataSetWrapper(object):

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        data_augment = self._get_simclr_pipeline_transform()

        train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                                       transform=SimCLRDataTransform(data_augment))

        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        # get a set of data augmentation transformations as described in the SimCLR paper.
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)

        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj
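
The removed wrapper's train/validation split can be illustrated in isolation (a toy sketch, not part of the diff): shuffle the indices once, carve off valid_size of them, and let two SubsetRandomSamplers draw from the same dataset.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SubsetRandomSampler

    dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100))   # toy stand-in dataset
    indices = np.random.permutation(len(dataset))
    split = int(np.floor(0.05 * len(dataset)))                       # valid_size = 0.05
    valid_idx, train_idx = indices[:split], indices[split:]

    train_loader = DataLoader(dataset, batch_size=10, sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(dataset, batch_size=10, sampler=SubsetRandomSampler(valid_idx))
    print(len(train_loader), len(valid_loader))                      # 10 1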

data_aug/gaussian_blur.py

@@ -1,25 +1,48 @@
-import cv2
 import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import transforms

 np.random.seed(0)


 class GaussianBlur(object):
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
+    """blur a single image on CPU"""
+    def __init__(self, kernel_size):
+        radias = kernel_size // 2
+        kernel_size = radias * 2 + 1
+        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.k = kernel_size
+        self.r = radias

-    def __call__(self, sample):
-        sample = np.array(sample)
+        self.blur = nn.Sequential(
+            nn.ReflectionPad2d(radias),
+            self.blur_h,
+            self.blur_v
+        )

-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
+        self.pil_to_tensor = transforms.ToTensor()
+        self.tensor_to_pil = transforms.ToPILImage()

-        if prob < 0.5:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
+    def __call__(self, img):
+        img = self.pil_to_tensor(img).unsqueeze(0)

-        return sample
+        sigma = np.random.uniform(0.1, 2.0)
+        x = np.arange(-self.r, self.r + 1)
+        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
+        x = x / x.sum()
+        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
+
+        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
+        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
+
+        with torch.no_grad():
+            img = self.blur(img)
+            img = img.squeeze()
+
+        img = self.tensor_to_pil(img)
+        return img
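
The new class swaps cv2.GaussianBlur for two depthwise 1-D convolutions (horizontal then vertical), which matches a full 2-D Gaussian because the kernel is separable. A standalone sketch of the same idea, with arbitrary example values:

    import numpy as np
    import torch
    from torch import nn

    sigma, r = 1.0, 4                                     # arbitrary example values
    x = np.arange(-r, r + 1)
    k1d = np.exp(-x**2 / (2 * sigma**2))
    k1d = torch.tensor(k1d / k1d.sum(), dtype=torch.float32)   # normalized 1-D Gaussian

    blur_h = nn.Conv2d(3, 3, kernel_size=(2 * r + 1, 1), bias=False, groups=3)
    blur_v = nn.Conv2d(3, 3, kernel_size=(1, 2 * r + 1), bias=False, groups=3)
    blur_h.weight.data.copy_(k1d.view(1, 1, -1, 1).repeat(3, 1, 1, 1))
    blur_v.weight.data.copy_(k1d.view(1, 1, 1, -1).repeat(3, 1, 1, 1))

    img = torch.rand(1, 3, 96, 96)
    with torch.no_grad():
        out = blur_v(blur_h(nn.ReflectionPad2d(r)(img)))
    print(out.shape)   # torch.Size([1, 3, 96, 96]); reflection padding preserves the spatial size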

data_aug/view_generator.py

@@ -0,0 +1,14 @@
import numpy as np

np.random.seed(0)


class ContrastiveLearningViewGenerator(object):
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transform(x) for i in range(self.n_views)]
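
A quick sketch of what the view generator produces (the base transform here is a stand-in; any callable works):

    from PIL import Image
    from torchvision import transforms
    from data_aug.view_generator import ContrastiveLearningViewGenerator

    base = transforms.Compose([transforms.RandomResizedCrop(32), transforms.ToTensor()])
    gen = ContrastiveLearningViewGenerator(base, n_views=2)

    img = Image.new('RGB', (96, 96))      # dummy image, just for shape checking
    views = gen(img)                      # two independently augmented views of the same image
    print(len(views), views[0].shape)     # 2 torch.Size([3, 32, 32])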

exceptions/exceptions.py

@@ -0,0 +1,6 @@
class BaseSimCLRException(Exception):
    """Base exception"""


class InvalidBackboneError(BaseSimCLRException):
    """Raised when the choice of backbone Convnet is invalid."""

loss/nt_xent.py

@@ -1,65 +0,0 @@
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)
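
The removed loss module can be exercised on its own; a small sketch (batch size and feature dimension are arbitrary) showing the call pattern it expected, one batch of projections per view:

    import torch
    import torch.nn.functional as F
    from loss.nt_xent import NTXentLoss   # as it existed before this commit

    batch_size, dim = 4, 16
    criterion = NTXentLoss(device='cpu', batch_size=batch_size,
                           temperature=0.5, use_cosine_similarity=True)

    zis = F.normalize(torch.randn(batch_size, dim), dim=1)   # projections of view 1
    zjs = F.normalize(torch.randn(batch_size, dim), dim=1)   # projections of view 2
    loss = criterion(zis, zjs)                               # scalar NT-Xent loss, averaged over 2N samples
    print(loss.item())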

models/baseline_encoder.py

@@ -1,43 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class Encoder(nn.Module):
    def __init__(self, out_dim=64):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(2, 2)

        # projection MLP
        self.l1 = nn.Linear(64, 64)
        self.l2 = nn.Linear(64, out_dim)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv4(x)
        x = F.relu(x)
        x = self.pool(x)

        h = torch.mean(x, dim=[2, 3])

        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)

        return h, x
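
A throwaway shape check for the removed baseline encoder (the import is omitted since only the Encoder class above is needed): it returned both the pooled representation h and the projection z.

    import torch

    encoder = Encoder(out_dim=64)        # the Encoder class defined above
    x = torch.randn(8, 3, 96, 96)        # dummy STL10-sized batch
    h, z = encoder(x)
    print(h.shape, z.shape)              # torch.Size([8, 64]) torch.Size([8, 64])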

models/resnet_simclr.py

@@ -1,37 +1,30 @@
 import torch.nn as nn
-import torch.nn.functional as F
 import torchvision.models as models
+
+from exceptions.exceptions import InvalidBackboneError


 class ResNetSimCLR(nn.Module):

     def __init__(self, base_model, out_dim):
         super(ResNetSimCLR, self).__init__()
-        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
-                            "resnet50": models.resnet50(pretrained=False)}
+        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
+                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

-        resnet = self._get_basemodel(base_model)
-        num_ftrs = resnet.fc.in_features
+        self.backbone = self._get_basemodel(base_model)
+        dim_mlp = self.backbone.fc.in_features

-        self.features = nn.Sequential(*list(resnet.children())[:-1])
-
-        # projection MLP
-        self.l1 = nn.Linear(num_ftrs, num_ftrs)
-        self.l2 = nn.Linear(num_ftrs, out_dim)
+        # add mlp projection head
+        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

     def _get_basemodel(self, model_name):
         try:
             model = self.resnet_dict[model_name]
-            print("Feature extractor:", model_name)
+        except KeyError:
+            raise InvalidBackboneError(
+                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
+        else:
             return model
-        except:
-            raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")

     def forward(self, x):
-        h = self.features(x)
-        h = h.squeeze()
-
-        x = self.l1(h)
-        x = F.relu(x)
-        x = self.l2(x)
-        return h, x
+        return self.backbone(x)
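
A sketch of what the refactored wrapper now returns (dimensions follow from the code above; the input resolution is arbitrary): the backbone's fc layer becomes a two-layer MLP projection head, and forward returns the projection directly instead of an (h, x) tuple.

    import torch
    from models.resnet_simclr import ResNetSimCLR

    model = ResNetSimCLR(base_model='resnet18', out_dim=128)
    # backbone.fc is now Linear(512, 512) -> ReLU() -> Linear(512, 128) for resnet18
    z = model(torch.randn(4, 3, 96, 96))
    print(z.shape)                       # torch.Size([4, 128])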

run.py

@@ -1,14 +1,90 @@
+import argparse
+import torch
+from torchvision import models
+from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
+from models.resnet_simclr import ResNetSimCLR
 from simclr import SimCLR
-import yaml
-from data_aug.dataset_wrapper import DataSetWrapper
+import torch.backends.cudnn as cudnn
+
+model_names = sorted(name for name in models.__dict__
+                     if name.islower() and not name.startswith("__")
+                     and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch SimCLR')
+parser.add_argument('-data', metavar='DIR', default='./datasets',
+                    help='path to dataset')
+parser.add_argument('-dataset-name', default='stl10',
+                    help='dataset name', choices=['stl10', 'cifar10'])
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+                    choices=model_names,
+                    help='model architecture: ' +
+                         ' | '.join(model_names) +
+                         ' (default: resnet18)')
+parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
+                    help='number of data loading workers (default: 12)')
+parser.add_argument('--epochs', default=200, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('--seed', default=None, type=int,
+                    help='seed for initializing training. ')
+parser.add_argument('--disable-cuda', action='store_true',
+                    help='Disable CUDA')
+parser.add_argument('--fp16_precision', default=False, type=bool,
+                    help='Whether or not to use 16-bit precision GPU training.')
+parser.add_argument('--out_dim', default=128, type=int,
+                    help='feature dimension (default: 128)')
+parser.add_argument('--log-every-n-steps', default=100, type=int,
+                    help='Log every n steps')
+parser.add_argument('--temperature', default=0.07, type=float,
+                    help='softmax temperature (default: 0.07)')
+parser.add_argument('--n-views', default=2, type=int, metavar='N',
+                    help='Number of views for contrastive learning training.')
+parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')


 def main():
-    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
-    dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
+    args = parser.parse_args()

-    simclr = SimCLR(dataset, config)
-    simclr.train()
+    # check if gpu training is available
+    if not args.disable_cuda and torch.cuda.is_available():
+        args.device = torch.device('cuda')
+        cudnn.deterministic = True
+        cudnn.benchmark = True
+    else:
+        args.device = torch.device('cpu')
+        args.gpu_index = -1
+
+    dataset = ContrastiveLearningDataset(args.data)
+
+    train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True, drop_last=True)
+
+    model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
+
+    optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
+
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
+                                                           last_epoch=-1)
+
+    # It's a no-op if the 'gpu_index' argument is a negative integer or None.
+    with torch.cuda.device(args.gpu_index):
+        simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
+        simclr.train(train_loader)


 if __name__ == "__main__":
     main()
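
The new entry point is driven entirely by argparse (a typical invocation would be: python run.py -data ./datasets -dataset-name stl10 --epochs 200). A small sketch for inspecting the defaults without launching training:

    from run import parser               # run.py defines the parser at module level

    args = parser.parse_args([])         # empty argv, so only the defaults
    print(args.arch, args.dataset_name, args.batch_size, args.lr, args.temperature, args.n_views)
    # resnet18 stl10 256 0.0003 0.07 2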

simclr.py

@@ -1,11 +1,14 @@
-import torch
-from models.resnet_simclr import ResNetSimCLR
-from torch.utils.tensorboard import SummaryWriter
-import torch.nn.functional as F
-from loss.nt_xent import NTXentLoss
 import os
 import shutil
 import sys
+import yaml
+import torch
+from torch.utils.tensorboard import SummaryWriter
+import torch.nn.functional as F
+import logging
+from tqdm import tqdm
+
+torch.manual_seed(0)

 apex_support = False
 try:
@@ -17,135 +20,103 @@ except:
     print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
     apex_support = False

-import numpy as np
-
-torch.manual_seed(0)
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')


-def _save_config_file(model_checkpoints_folder):
+def _save_config_file(model_checkpoints_folder, args):
     if not os.path.exists(model_checkpoints_folder):
         os.makedirs(model_checkpoints_folder)
-        shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
+        with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
+            yaml.dump(args, outfile, default_flow_style=False)


 class SimCLR(object):

-    def __init__(self, dataset, config):
-        self.config = config
-        self.device = self._get_device()
+    def __init__(self, *args, **kwargs):
+        self.args = kwargs['args']
+        self.model = kwargs['model'].to(self.args.device)
+        self.optimizer = kwargs['optimizer']
+        self.scheduler = kwargs['scheduler']
         self.writer = SummaryWriter()
-        self.dataset = dataset
-        self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
+        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
+        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

-    def _get_device(self):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        print("Running on:", device)
-        return device
+    def info_nce_loss(self, features):
+        batch_targets = torch.arange(self.args.batch_size, dtype=torch.long).to(self.args.device)
+        batch_targets = torch.cat(self.args.n_views * [batch_targets])

-    def _step(self, model, xis, xjs, n_iter):
+        features = F.normalize(features, dim=1)

-        # get the representations and the projections
-        ris, zis = model(xis)  # [N,C]
+        similarity_matrix = torch.matmul(features, features.T)
+        # assert similarity_matrix.shape == (
+        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)

-        # get the representations and the projections
-        rjs, zjs = model(xjs)  # [N,C]
+        mask = torch.eye(len(batch_targets)).to(self.args.device)
+        similarities = similarity_matrix[~mask.bool()].view(similarity_matrix.shape[0], -1)
+        similarities = similarities / self.args.temperature
+        return similarities, batch_targets
-        # normalize projection feature vectors
-        zis = F.normalize(zis, dim=1)
-        zjs = F.normalize(zjs, dim=1)
+    def train(self, train_loader):

-        loss = self.nt_xent_criterion(zis, zjs)
-        return loss
-
-    def train(self):
-
-        train_loader, valid_loader = self.dataset.get_data_loaders()
-
-        model = ResNetSimCLR(**self.config["model"]).to(self.device)
-        model = self._load_pre_trained_weights(model)
-
-        optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))
-
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
-                                                               last_epoch=-1)
-
-        if apex_support and self.config['fp16_precision']:
-            model, optimizer = amp.initialize(model, optimizer,
-                                              opt_level='O2',
-                                              keep_batchnorm_fp32=True)
+        if apex_support and self.args.fp16_precision:
+            logging.debug("Using apex for fp16 precision training.")
+            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
+                                                        opt_level='O2',
+                                                        keep_batchnorm_fp32=True)

         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

         # save config file
-        _save_config_file(model_checkpoints_folder)
+        _save_config_file(model_checkpoints_folder, self.args)

         n_iter = 0
-        valid_n_iter = 0
-        best_valid_loss = np.inf
+        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
+        logging.info(f"Training with gpu: {self.args.disable_cuda}.")

-        for epoch_counter in range(self.config['epochs']):
-            for (xis, xjs), _ in train_loader:
-                optimizer.zero_grad()
+        for epoch_counter in range(self.args.epochs):
+            for images, _ in tqdm(train_loader):
+                images = torch.cat(images, dim=0)

-                xis = xis.to(self.device)
-                xjs = xjs.to(self.device)
+                images = images.to(self.args.device)

-                loss = self._step(model, xis, xjs, n_iter)
+                features = self.model(images)
+                logits, labels = self.info_nce_loss(features)
+                loss = self.criterion(logits, labels)

-                if n_iter % self.config['log_every_n_steps'] == 0:
-                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
-
-                if apex_support and self.config['fp16_precision']:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                self.optimizer.zero_grad()
+                if apex_support and self.args.fp16_precision:
+                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                         scaled_loss.backward()
                 else:
                     loss.backward()

-                optimizer.step()
+                self.optimizer.step()
+
+                if n_iter % self.args.log_every_n_steps == 0:
+                    predictions = torch.argmax(logits, dim=1)
+                    acc = 100 * (predictions == labels).float().mean()
+                    self.writer.add_scalar('loss', loss, global_step=n_iter)
+                    self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
+                    self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
+
                 n_iter += 1

-            # validate the model if requested
-            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
-                valid_loss = self._validate(model, valid_loader)
-                if valid_loss < best_valid_loss:
-                    # save the model weights
-                    best_valid_loss = valid_loss
-                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
-
-                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
-                valid_n_iter += 1
-
             # warmup for the first 10 epochs
             if epoch_counter >= 10:
-                scheduler.step()
-            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
-
-    def _load_pre_trained_weights(self, model):
-        try:
-            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
-            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
-            model.load_state_dict(state_dict)
-            print("Loaded pre-trained model with success.")
-        except FileNotFoundError:
-            print("Pre-trained weights not found. Training from scratch.")
-
-        return model
-
-    def _validate(self, model, valid_loader):
-
-        # validation steps
-        with torch.no_grad():
-            model.eval()
-
-            valid_loss = 0.0
-            counter = 0
-            for (xis, xjs), _ in valid_loader:
-                xis = xis.to(self.device)
-                xjs = xjs.to(self.device)
-
-                loss = self._step(model, xis, xjs, counter)
-                valid_loss += loss.item()
-                counter += 1
-            valid_loss /= counter
-        model.train()
-        return valid_loss
+                self.scheduler.step()
+            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")
+
+        logging.info("Training has finished.")
+        # save model checkpoints
+        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
+        save_checkpoint({
+            'epoch': self.args.epochs,
+            'arch': self.args.arch,
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
+        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")