mirror of https://github.com/sthalles/SimCLR.git
Major refactor
parent
e8a690ae4f
commit
2c9536f731
21
config.yaml
21
config.yaml
|
@ -1,21 +0,0 @@
|
|||
batch_size: 512
|
||||
epochs: 80
|
||||
eval_every_n_epochs: 1
|
||||
fine_tune_from: None
|
||||
log_every_n_steps: 50
|
||||
weight_decay: 10e-6
|
||||
fp16_precision: False
|
||||
|
||||
model:
|
||||
out_dim: 256
|
||||
base_model: "resnet18"
|
||||
|
||||
dataset:
|
||||
s: 1
|
||||
input_shape: (96,96,3)
|
||||
num_workers: 0
|
||||
valid_size: 0.05
|
||||
|
||||
loss:
|
||||
temperature: 0.5
|
||||
use_cosine_similarity: True
|
|
@ -0,0 +1,37 @@
|
|||
from torchvision.transforms import transforms
|
||||
from data_aug.gaussian_blur import GaussianBlur
|
||||
from torchvision import transforms, datasets
|
||||
from data_aug.view_generator import ContrastiveLearningViewGenerator
|
||||
|
||||
|
||||
class ContrastiveLearningDataset:
|
||||
def __init__(self, root_folder):
|
||||
self.root_folder = root_folder
|
||||
|
||||
@staticmethod
|
||||
def get_simclr_pipeline_transform(size, s=1):
|
||||
"""Return a set of data augmentation transformations as described in the SimCLR paper."""
|
||||
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
|
||||
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomApply([color_jitter], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
GaussianBlur(kernel_size=int(0.1 * size)),
|
||||
transforms.ToTensor()])
|
||||
return data_transforms
|
||||
|
||||
def get_dataset(self, name, n_views):
|
||||
valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
|
||||
transform=ContrastiveLearningViewGenerator(
|
||||
self.get_simclr_pipeline_transform(32),
|
||||
n_views),
|
||||
download=True),
|
||||
|
||||
'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
|
||||
transform=ContrastiveLearningViewGenerator(
|
||||
self.get_simclr_pipeline_transform(96),
|
||||
n_views),
|
||||
download=True)}
|
||||
|
||||
dataset = valid_datasets.get(name, 'Invalid dataset option.')()
|
||||
return dataset
|
|
@ -1,68 +0,0 @@
|
|||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import SubsetRandomSampler
|
||||
import torchvision.transforms as transforms
|
||||
from data_aug.gaussian_blur import GaussianBlur
|
||||
from torchvision import datasets
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class DataSetWrapper(object):
|
||||
|
||||
def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.valid_size = valid_size
|
||||
self.s = s
|
||||
self.input_shape = eval(input_shape)
|
||||
|
||||
def get_data_loaders(self):
|
||||
data_augment = self._get_simclr_pipeline_transform()
|
||||
|
||||
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
|
||||
transform=SimCLRDataTransform(data_augment))
|
||||
|
||||
train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
|
||||
return train_loader, valid_loader
|
||||
|
||||
def _get_simclr_pipeline_transform(self):
|
||||
# get a set of data augmentation transformations as described in the SimCLR paper.
|
||||
color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
|
||||
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomApply([color_jitter], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
|
||||
transforms.ToTensor()])
|
||||
return data_transforms
|
||||
|
||||
def get_train_validation_data_loaders(self, train_dataset):
|
||||
# obtain training indices that will be used for validation
|
||||
num_train = len(train_dataset)
|
||||
indices = list(range(num_train))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
split = int(np.floor(self.valid_size * num_train))
|
||||
train_idx, valid_idx = indices[split:], indices[:split]
|
||||
|
||||
# define samplers for obtaining training and validation batches
|
||||
train_sampler = SubsetRandomSampler(train_idx)
|
||||
valid_sampler = SubsetRandomSampler(valid_idx)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
|
||||
num_workers=self.num_workers, drop_last=True, shuffle=False)
|
||||
|
||||
valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
|
||||
num_workers=self.num_workers, drop_last=True)
|
||||
return train_loader, valid_loader
|
||||
|
||||
|
||||
class SimCLRDataTransform(object):
|
||||
def __init__(self, transform):
|
||||
self.transform = transform
|
||||
|
||||
def __call__(self, sample):
|
||||
xi = self.transform(sample)
|
||||
xj = self.transform(sample)
|
||||
return xi, xj
|
|
@ -1,25 +1,48 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
# Implements Gaussian blur as described in the SimCLR paper
|
||||
def __init__(self, kernel_size, min=0.1, max=2.0):
|
||||
self.min = min
|
||||
self.max = max
|
||||
# kernel size is set to be 10% of the image height/width
|
||||
self.kernel_size = kernel_size
|
||||
"""blur a single image on CPU"""
|
||||
def __init__(self, kernel_size):
|
||||
radias = kernel_size // 2
|
||||
kernel_size = radias * 2 + 1
|
||||
self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
|
||||
stride=1, padding=0, bias=False, groups=3)
|
||||
self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
|
||||
stride=1, padding=0, bias=False, groups=3)
|
||||
self.k = kernel_size
|
||||
self.r = radias
|
||||
|
||||
def __call__(self, sample):
|
||||
sample = np.array(sample)
|
||||
self.blur = nn.Sequential(
|
||||
nn.ReflectionPad2d(radias),
|
||||
self.blur_h,
|
||||
self.blur_v
|
||||
)
|
||||
|
||||
# blur the image with a 50% chance
|
||||
prob = np.random.random_sample()
|
||||
self.pil_to_tensor = transforms.ToTensor()
|
||||
self.tensor_to_pil = transforms.ToPILImage()
|
||||
|
||||
if prob < 0.5:
|
||||
sigma = (self.max - self.min) * np.random.random_sample() + self.min
|
||||
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
|
||||
def __call__(self, img):
|
||||
img = self.pil_to_tensor(img).unsqueeze(0)
|
||||
|
||||
return sample
|
||||
sigma = np.random.uniform(0.1, 2.0)
|
||||
x = np.arange(-self.r, self.r + 1)
|
||||
x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
|
||||
x = x / x.sum()
|
||||
x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
|
||||
|
||||
self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
|
||||
self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
|
||||
|
||||
with torch.no_grad():
|
||||
img = self.blur(img)
|
||||
img = img.squeeze()
|
||||
|
||||
img = self.tensor_to_pil(img)
|
||||
|
||||
return img
|
|
@ -0,0 +1,14 @@
|
|||
import numpy as np
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class ContrastiveLearningViewGenerator(object):
|
||||
"""Take two random crops of one image as the query and key."""
|
||||
|
||||
def __init__(self, base_transform, n_views=2):
|
||||
self.base_transform = base_transform
|
||||
self.n_views = n_views
|
||||
|
||||
def __call__(self, x):
|
||||
return [self.base_transform(x) for i in range(self.n_views)]
|
|
@ -0,0 +1,6 @@
|
|||
class BaseSimCLRException(Exception):
|
||||
"""Base exception"""
|
||||
|
||||
|
||||
class InvalidBackboneError(BaseSimCLRException):
|
||||
"""Raised when the choice of backbone Convnet is invalid."""
|
|
@ -1,65 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NTXentLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, device, batch_size, temperature, use_cosine_similarity):
|
||||
super(NTXentLoss, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.temperature = temperature
|
||||
self.device = device
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
|
||||
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def _get_similarity_function(self, use_cosine_similarity):
|
||||
if use_cosine_similarity:
|
||||
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
||||
return self._cosine_simililarity
|
||||
else:
|
||||
return self._dot_simililarity
|
||||
|
||||
def _get_correlated_mask(self):
|
||||
diag = np.eye(2 * self.batch_size)
|
||||
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
|
||||
l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
|
||||
mask = torch.from_numpy((diag + l1 + l2))
|
||||
mask = (1 - mask).type(torch.bool)
|
||||
return mask.to(self.device)
|
||||
|
||||
@staticmethod
|
||||
def _dot_simililarity(x, y):
|
||||
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
||||
# x shape: (N, 1, C)
|
||||
# y shape: (1, C, 2N)
|
||||
# v shape: (N, 2N)
|
||||
return v
|
||||
|
||||
def _cosine_simililarity(self, x, y):
|
||||
# x shape: (N, 1, C)
|
||||
# y shape: (1, 2N, C)
|
||||
# v shape: (N, 2N)
|
||||
v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
||||
return v
|
||||
|
||||
def forward(self, zis, zjs):
|
||||
representations = torch.cat([zjs, zis], dim=0)
|
||||
|
||||
similarity_matrix = self.similarity_function(representations, representations)
|
||||
|
||||
# filter out the scores from the positive samples
|
||||
l_pos = torch.diag(similarity_matrix, self.batch_size)
|
||||
r_pos = torch.diag(similarity_matrix, -self.batch_size)
|
||||
positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
|
||||
|
||||
negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
|
||||
|
||||
logits = torch.cat((positives, negatives), dim=1)
|
||||
logits /= self.temperature
|
||||
|
||||
labels = torch.zeros(2 * self.batch_size).to(self.device).long()
|
||||
loss = self.criterion(logits, labels)
|
||||
|
||||
return loss / (2 * self.batch_size)
|
|
@ -1,43 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, out_dim=64):
|
||||
super(Encoder, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
|
||||
# projection MLP
|
||||
self.l1 = nn.Linear(64, 64)
|
||||
self.l2 = nn.Linear(64, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(x)
|
||||
x = self.pool(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = F.relu(x)
|
||||
x = self.pool(x)
|
||||
|
||||
x = self.conv3(x)
|
||||
x = F.relu(x)
|
||||
x = self.pool(x)
|
||||
|
||||
x = self.conv4(x)
|
||||
x = F.relu(x)
|
||||
x = self.pool(x)
|
||||
|
||||
h = torch.mean(x, dim=[2, 3])
|
||||
|
||||
x = self.l1(h)
|
||||
x = F.relu(x)
|
||||
x = self.l2(x)
|
||||
|
||||
return h, x
|
|
@ -1,37 +1,30 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
from exceptions.exceptions import InvalidBackboneError
|
||||
|
||||
|
||||
class ResNetSimCLR(nn.Module):
|
||||
|
||||
def __init__(self, base_model, out_dim):
|
||||
super(ResNetSimCLR, self).__init__()
|
||||
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
|
||||
"resnet50": models.resnet50(pretrained=False)}
|
||||
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
|
||||
"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}
|
||||
|
||||
resnet = self._get_basemodel(base_model)
|
||||
num_ftrs = resnet.fc.in_features
|
||||
self.backbone = self._get_basemodel(base_model)
|
||||
dim_mlp = self.backbone.fc.in_features
|
||||
|
||||
self.features = nn.Sequential(*list(resnet.children())[:-1])
|
||||
|
||||
# projection MLP
|
||||
self.l1 = nn.Linear(num_ftrs, num_ftrs)
|
||||
self.l2 = nn.Linear(num_ftrs, out_dim)
|
||||
# add mlp projection head
|
||||
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
|
||||
|
||||
def _get_basemodel(self, model_name):
|
||||
try:
|
||||
model = self.resnet_dict[model_name]
|
||||
print("Feature extractor:", model_name)
|
||||
except KeyError:
|
||||
raise InvalidBackboneError(
|
||||
"Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
|
||||
else:
|
||||
return model
|
||||
except:
|
||||
raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
|
||||
|
||||
def forward(self, x):
|
||||
h = self.features(x)
|
||||
h = h.squeeze()
|
||||
|
||||
x = self.l1(h)
|
||||
x = F.relu(x)
|
||||
x = self.l2(x)
|
||||
return h, x
|
||||
return self.backbone(x)
|
||||
|
|
88
run.py
88
run.py
|
@ -1,14 +1,90 @@
|
|||
import argparse
|
||||
import torch
|
||||
from torchvision import models
|
||||
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
|
||||
from models.resnet_simclr import ResNetSimCLR
|
||||
from simclr import SimCLR
|
||||
import yaml
|
||||
from data_aug.dataset_wrapper import DataSetWrapper
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
model_names = sorted(name for name in models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
and callable(models.__dict__[name]))
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch SimCLR')
|
||||
parser.add_argument('-data', metavar='DIR', default='./datasets',
|
||||
help='path to dataset')
|
||||
parser.add_argument('-dataset-name', default='stl10',
|
||||
help='dataset name', choices=['stl10', 'cifar10'])
|
||||
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
|
||||
choices=model_names,
|
||||
help='model architecture: ' +
|
||||
' | '.join(model_names) +
|
||||
' (default: resnet50)')
|
||||
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 32)')
|
||||
parser.add_argument('--epochs', default=200, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N',
|
||||
help='mini-batch size (default: 256), this is the total '
|
||||
'batch size of all GPUs on the current node when '
|
||||
'using Data Parallel or Distributed Data Parallel')
|
||||
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
|
||||
metavar='LR', help='initial learning rate', dest='lr')
|
||||
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-4)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--seed', default=None, type=int,
|
||||
help='seed for initializing training. ')
|
||||
parser.add_argument('--disable-cuda', action='store_true',
|
||||
help='Disable CUDA')
|
||||
parser.add_argument('--fp16_precision', default=False, type=bool,
|
||||
help='Whether or not to use 16-bit precision GPU training.')
|
||||
|
||||
parser.add_argument('--out_dim', default=128, type=int,
|
||||
help='feature dimension (default: 128)')
|
||||
parser.add_argument('--log-every-n-steps', default=100, type=int,
|
||||
help='Log every n steps')
|
||||
parser.add_argument('--temperature', default=0.07, type=float,
|
||||
help='softmax temperature (default: 0.07)')
|
||||
parser.add_argument('--n-views', default=2, type=int, metavar='N',
|
||||
help='Number of views for contrastive learning training.')
|
||||
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
|
||||
|
||||
|
||||
def main():
|
||||
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
|
||||
dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
|
||||
args = parser.parse_args()
|
||||
|
||||
simclr = SimCLR(dataset, config)
|
||||
simclr.train()
|
||||
# check if gpu training is available
|
||||
if not args.disable_cuda and torch.cuda.is_available():
|
||||
args.device = torch.device('cuda')
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = True
|
||||
else:
|
||||
args.device = torch.device('cpu')
|
||||
args.gpu_index = -1
|
||||
|
||||
dataset = ContrastiveLearningDataset(args.data)
|
||||
|
||||
train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, shuffle=True,
|
||||
num_workers=args.workers, pin_memory=True, drop_last=True)
|
||||
|
||||
model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
|
||||
last_epoch=-1)
|
||||
|
||||
# It’s a no-op if the 'gpu_index' argument is a negative integer or None.
|
||||
with torch.cuda.device(args.gpu_index):
|
||||
simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
|
||||
simclr.train(train_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
175
simclr.py
175
simclr.py
|
@ -1,11 +1,14 @@
|
|||
import torch
|
||||
from models.resnet_simclr import ResNetSimCLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
from loss.nt_xent import NTXentLoss
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import yaml
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
apex_support = False
|
||||
try:
|
||||
|
@ -17,135 +20,103 @@ except:
|
|||
print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
|
||||
apex_support = False
|
||||
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, 'model_best.pth.tar')
|
||||
|
||||
|
||||
def _save_config_file(model_checkpoints_folder):
|
||||
def _save_config_file(model_checkpoints_folder, args):
|
||||
if not os.path.exists(model_checkpoints_folder):
|
||||
os.makedirs(model_checkpoints_folder)
|
||||
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
|
||||
with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
|
||||
yaml.dump(args, outfile, default_flow_style=False)
|
||||
|
||||
|
||||
class SimCLR(object):
|
||||
|
||||
def __init__(self, dataset, config):
|
||||
self.config = config
|
||||
self.device = self._get_device()
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = kwargs['args']
|
||||
self.model = kwargs['model'].to(self.args.device)
|
||||
self.optimizer = kwargs['optimizer']
|
||||
self.scheduler = kwargs['scheduler']
|
||||
self.writer = SummaryWriter()
|
||||
self.dataset = dataset
|
||||
self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
|
||||
logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
|
||||
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
|
||||
|
||||
def _get_device(self):
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print("Running on:", device)
|
||||
return device
|
||||
def info_nce_loss(self, features):
|
||||
batch_targets = torch.arange(self.args.batch_size, dtype=torch.long).to(self.args.device)
|
||||
batch_targets = torch.cat(self.args.n_views * [batch_targets])
|
||||
|
||||
def _step(self, model, xis, xjs, n_iter):
|
||||
features = F.normalize(features, dim=1)
|
||||
|
||||
# get the representations and the projections
|
||||
ris, zis = model(xis) # [N,C]
|
||||
similarity_matrix = torch.matmul(features, features.T)
|
||||
# assert similarity_matrix.shape == (
|
||||
# self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
|
||||
|
||||
# get the representations and the projections
|
||||
rjs, zjs = model(xjs) # [N,C]
|
||||
mask = torch.eye(len(batch_targets)).to(self.args.device)
|
||||
similarities = similarity_matrix[~mask.bool()].view(similarity_matrix.shape[0], -1)
|
||||
similarities = similarities / self.args.temperature
|
||||
return similarities, batch_targets
|
||||
|
||||
# normalize projection feature vectors
|
||||
zis = F.normalize(zis, dim=1)
|
||||
zjs = F.normalize(zjs, dim=1)
|
||||
def train(self, train_loader):
|
||||
|
||||
loss = self.nt_xent_criterion(zis, zjs)
|
||||
return loss
|
||||
|
||||
def train(self):
|
||||
|
||||
train_loader, valid_loader = self.dataset.get_data_loaders()
|
||||
|
||||
model = ResNetSimCLR(**self.config["model"]).to(self.device)
|
||||
model = self._load_pre_trained_weights(model)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
|
||||
last_epoch=-1)
|
||||
|
||||
if apex_support and self.config['fp16_precision']:
|
||||
model, optimizer = amp.initialize(model, optimizer,
|
||||
opt_level='O2',
|
||||
keep_batchnorm_fp32=True)
|
||||
if apex_support and self.args.fp16_precision:
|
||||
logging.debug("Using apex for fp16 precision training.")
|
||||
self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
|
||||
opt_level='O2',
|
||||
keep_batchnorm_fp32=True)
|
||||
|
||||
model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
|
||||
|
||||
# save config file
|
||||
_save_config_file(model_checkpoints_folder)
|
||||
_save_config_file(model_checkpoints_folder, self.args)
|
||||
|
||||
n_iter = 0
|
||||
valid_n_iter = 0
|
||||
best_valid_loss = np.inf
|
||||
logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
|
||||
logging.info(f"Training with gpu: {self.args.disable_cuda}.")
|
||||
|
||||
for epoch_counter in range(self.config['epochs']):
|
||||
for (xis, xjs), _ in train_loader:
|
||||
optimizer.zero_grad()
|
||||
for epoch_counter in range(self.args.epochs):
|
||||
for images, _ in tqdm(train_loader):
|
||||
images = torch.cat(images, dim=0)
|
||||
|
||||
xis = xis.to(self.device)
|
||||
xjs = xjs.to(self.device)
|
||||
images = images.to(self.args.device)
|
||||
|
||||
loss = self._step(model, xis, xjs, n_iter)
|
||||
features = self.model(images)
|
||||
logits, labels = self.info_nce_loss(features)
|
||||
loss = self.criterion(logits, labels)
|
||||
|
||||
if n_iter % self.config['log_every_n_steps'] == 0:
|
||||
self.writer.add_scalar('train_loss', loss, global_step=n_iter)
|
||||
|
||||
if apex_support and self.config['fp16_precision']:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
self.optimizer.zero_grad()
|
||||
if apex_support and self.args.fp16_precision:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
self.optimizer.step()
|
||||
|
||||
if n_iter % self.args.log_every_n_steps == 0:
|
||||
predictions = torch.argmax(logits, dim=1)
|
||||
acc = 100 * (predictions == labels).float().mean()
|
||||
self.writer.add_scalar('loss', loss, global_step=n_iter)
|
||||
self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
|
||||
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
|
||||
|
||||
n_iter += 1
|
||||
|
||||
# validate the model if requested
|
||||
if epoch_counter % self.config['eval_every_n_epochs'] == 0:
|
||||
valid_loss = self._validate(model, valid_loader)
|
||||
if valid_loss < best_valid_loss:
|
||||
# save the model weights
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
||||
|
||||
self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
|
||||
valid_n_iter += 1
|
||||
|
||||
# warmup for the first 10 epochs
|
||||
if epoch_counter >= 10:
|
||||
scheduler.step()
|
||||
self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
|
||||
self.scheduler.step()
|
||||
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")
|
||||
|
||||
def _load_pre_trained_weights(self, model):
|
||||
try:
|
||||
checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
|
||||
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
|
||||
model.load_state_dict(state_dict)
|
||||
print("Loaded pre-trained model with success.")
|
||||
except FileNotFoundError:
|
||||
print("Pre-trained weights not found. Training from scratch.")
|
||||
|
||||
return model
|
||||
|
||||
def _validate(self, model, valid_loader):
|
||||
|
||||
# validation steps
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
|
||||
valid_loss = 0.0
|
||||
counter = 0
|
||||
for (xis, xjs), _ in valid_loader:
|
||||
xis = xis.to(self.device)
|
||||
xjs = xjs.to(self.device)
|
||||
|
||||
loss = self._step(model, xis, xjs, counter)
|
||||
valid_loss += loss.item()
|
||||
counter += 1
|
||||
valid_loss /= counter
|
||||
model.train()
|
||||
return valid_loss
|
||||
logging.info("Training has finished.")
|
||||
# save model checkpoints
|
||||
checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
|
||||
save_checkpoint({
|
||||
'epoch': self.args.epochs,
|
||||
'arch': self.args.arch,
|
||||
'state_dict': self.model.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
}, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
|
||||
logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")
|
||||
|
|
Loading…
Reference in New Issue