import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

from model import Encoder, ResNet18
from utils import GaussianBlur

print(torch.__version__)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

batch_size = config['batch_size']
out_dim = config['out_dim']
s = config['s']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']

# `s` scales the colour-jitter strength; (brightness, contrast, saturation, hue)
# are set to (0.8s, 0.8s, 0.8s, 0.2s) as in the SimCLR augmentation recipe.
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

data_augment = transforms.Compose([transforms.ToPILImage(),
                                   transforms.RandomResizedCrop(96),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomApply([color_jitter], p=0.8),
                                   transforms.RandomGrayscale(p=0.2),
                                   GaussianBlur(),
                                   transforms.ToTensor()])
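# Note: ToPILImage leads the pipeline because the DataLoader yields tensors (the
# dataset's own transform is just ToTensor). RandomResizedCrop(96) matches STL10's
# native 96x96 resolution, and GaussianBlur is this repo's transform from utils.py.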

train_dataset = datasets.STL10('data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
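# drop_last=True keeps every batch at exactly `batch_size` samples, which the
# fixed-size negative_mask constructed below depends on.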

# model = Encoder(out_dim=out_dim)
model = ResNet18()
print(model)

train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# move the model parameters to the GPU
if train_gpu:
    model.cuda()

# CrossEntropyLoss over logits [l_pos | l_neg] with class-0 targets implements
# the NT-Xent objective; reduction='sum' lets the loss be averaged over the
# 2N augmented samples at the end of each step.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=3e-4)

train_writer = SummaryWriter()

if use_cosine_similarity:
    similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
    similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
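
# CosineSimilarity(dim=1) compares matched rows of two [N, D] tensors (the
# positive pairs); CosineSimilarity(dim=2) broadcasts an [N, 1, D] anchor batch
# against the [1, 2N, D] bank of candidates.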

# Mask to remove positive examples from the batch of negative samples
negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)
for i in range(batch_size):
    negative_mask[i, i] = 0
    negative_mask[i, i + batch_size] = 0
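
# For each row i, columns i and i + batch_size of the [N, 2N] similarity matrix
# hold the positive pair and the anchor's self-similarity (in one order or the
# other), so both are excluded from the negatives.
if train_gpu:
    # keep the boolean mask indexing in the training loop on-device
    negative_mask = negative_mask.cuda()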

n_iter = 0
for e in range(config['epochs']):
    for step, (batch_x, _) in enumerate(train_loader):

        optimizer.zero_grad()

        # Draw two independent augmentations of every image in the batch; the
        # pair (xis[k], xjs[k]) is the positive pair for image k.
        xis = []
        xjs = []
        for k in range(len(batch_x)):
            xis.append(data_augment(batch_x[k]))
            xjs.append(data_augment(batch_x[k]))

        xis = torch.stack(xis)
        xjs = torch.stack(xjs)

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        ris, zis = model(xis)  # [N,C]
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)

        rjs, zjs = model(xjs)  # [N,C]
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

        # normalize projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        # assert zis.shape == (batch_size, out_dim), "Shape not expected: " + str(zis.shape)
        # assert zjs.shape == (batch_size, out_dim), "Shape not expected: " + str(zjs.shape)
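        # After normalization a plain dot product equals cosine similarity, so the
        # two similarity branches below compute the same scores.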

        # positive pairs: similarity between the two views of each image -> [N, 1]
        if use_cosine_similarity:
            l_pos = similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size, 1)
        else:
            l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)

        l_pos /= temperature
        # assert l_pos.shape == (batch_size, 1)  # [N,1]

        negatives = torch.cat([zjs, zis], dim=0)
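        # zjs occupy columns 0..N-1 and zis columns N..2N-1 of the similarity
        # matrices, matching the layout negative_mask was built for.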

        if use_cosine_similarity:
            negatives = negatives.view(1, (2 * batch_size), out_dim)
            l_neg_1 = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives)
            l_neg_2 = similarity_dim2(zjs.view(batch_size, 1, out_dim), negatives)
        else:
            l_neg_1 = torch.tensordot(zis.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
                                      dims=2)
            l_neg_2 = torch.tensordot(zjs.view(batch_size, 1, out_dim), negatives.T.view(1, out_dim, (2 * batch_size)),
                                      dims=2)
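        # Either branch yields an [N, 2N] matrix of similarities between each
        # anchor and all 2N embeddings in `negatives`.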

        # The positive logit is placed in column 0 by torch.cat([l_pos, l_neg], dim=1)
        # below, so the target class for every row is 0.
        labels = torch.zeros(batch_size, dtype=torch.long)
        if train_gpu:
            labels = labels.cuda()

        loss = 0
        for l_neg in [l_neg_1, l_neg_2]:
            l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
            l_neg /= temperature

            # assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
            logits = torch.cat([l_pos, l_neg], dim=1)  # [N,K+1]
            loss += criterion(logits, labels)

        # reduction='sum' accumulated N per-sample losses for each view, so divide
        # by 2N to get the mean NT-Xent loss over the batch.
        loss = loss / (2 * batch_size)
        train_writer.add_scalar('loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()
        n_iter += 1
        # print("Step {}, Loss {}".format(step, loss))

os.makedirs('./model', exist_ok=True)  # make sure the checkpoint directory exists
torch.save(model.state_dict(), './model/checkpoint.pth')