# Copyright (c) Alibaba, Inc. and its affiliates.
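"""Evaluate pre-extracted features with a weighted k-NN classifier.

Feature shards are located as ``<feature_dir>/<prefix>feat1.npy`` with
matching ``<prefix>label.npy`` label files (this naming is inferred from the
glob pattern and the ``'feat1' -> 'label'`` replacement below). Shards are
concatenated, L2-normalized and classified by a temperature-weighted k-NN
vote; top-1/top-5 accuracy is printed for each value of ``--nb_knn``.
"""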
import argparse
from glob import glob

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

parser = argparse.ArgumentParser(
    description='Weighted k-NN evaluation on pre-extracted features.')
parser.add_argument(
    'feature_dir',
    type=str,
    help='feature root dir',
    nargs='?',
    default='work_dirs')
parser.add_argument(
    'train_path',
    type=str,
    help='glob prefix used to match the train feature npy files',
    nargs='?',
    default='train_*')
parser.add_argument(
    'val_path',
    type=str,
    help='glob prefix used to match the val feature npy files',
    nargs='?',
    default='val_*')
parser.add_argument(
    '--nb_knn',
    default=[10, 20, 100, 200],
    nargs='+',
    type=int,
    help='Numbers of nearest neighbors to use; 20 usually works best.')
parser.add_argument(
    '--temperature',
    default=0.07,
    type=float,
    help='Temperature used in the voting coefficient')


def knn_classifier(train_features,
                   train_labels,
                   test_features,
                   test_labels,
                   k,
                   T,
                   num_classes=1000):
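    """Weighted k-NN classification: each test sample votes among its k
    nearest training samples (by dot-product similarity), with each vote
    weighted by exp(similarity / T). Features are assumed to be
    L2-normalized, so the dot product is a cosine similarity. Returns
    (top1, top5) accuracy in percent; num_classes defaults to 1000
    (ImageNet).
    """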

    top1, top5, total = 0.0, 0.0, 0
    train_features = train_features.t()
    num_test_images, num_chunks = test_labels.shape[0], 100
    # process the test set in chunks to bound GPU memory; guard against
    # num_test_images < num_chunks, which would give a zero step size
    imgs_per_chunk = max(num_test_images // num_chunks, 1)
    # reusable buffer for one-hot encodings of the retrieved neighbor labels
    retrieval_one_hot = torch.zeros(k, num_classes).cuda()
    for idx in range(0, num_test_images, imgs_per_chunk):
        # features/labels for the current chunk of test images
        end_idx = min(idx + imgs_per_chunk, num_test_images)
        features = test_features[idx:end_idx, :]
        targets = test_labels[idx:end_idx]
        batch_size = targets.shape[0]

        # calculate the dot product and compute top-k neighbors
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)

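        # one-hot encode the labels of the k retrieved neighbors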
        retrieved_neighbors = torch.gather(candidates, 1, indices)
        retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
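        # weight each neighbor's vote by exp(similarity / T)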
        distances_transform = distances.clone().div_(T).exp_()
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
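        # rank classes by their accumulated weighted votes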
        _, predictions = probs.sort(1, True)

        # find the predictions that match the target
        correct = predictions.eq(targets.data.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        # top-5 is only meaningful when k >= 5
        top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item()
        total += targets.size(0)
    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total
    return top1, top5


if __name__ == '__main__':
    args = parser.parse_args()

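    # each extraction shard is expected as <prefix>feat1.npy with a matching
    # <prefix>label.npy (see the 'feat1' -> 'label' replacement below)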
    train_list = glob('%s/%sfeat1.npy' % (args.feature_dir, args.train_path))
    val_list = glob('%s/%sfeat1.npy' % (args.feature_dir, args.val_path))

    train_features = []
    train_labels = []
    val_features = []
    val_labels = []

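    # load and concatenate all training feature/label shards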
    for i in tqdm(train_list):
        label_npy = i.replace('feat1', 'label')
        train_features.append(np.load(i))
        train_labels.append(np.load(label_npy))

    train_features = torch.tensor(np.vstack(train_features)).cuda()
    train_labels = torch.tensor(np.hstack(train_labels)).long().cuda()
    print(train_features.shape)
    print(train_labels.shape)

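    # load and concatenate all validation feature/label shards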
    for i in tqdm(val_list):
        label_npy = i.replace('feat1', 'label')
        val_features.append(np.load(i))
        val_labels.append(np.load(label_npy))

    val_features = torch.tensor(np.vstack(val_features)).cuda()
    val_labels = torch.tensor(np.hstack(val_labels)).long().cuda()

    print(val_features.shape)
    print(val_labels.shape)

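    # L2-normalize so the dot products in knn_classifier are cosine similarities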
    train_features = nn.functional.normalize(train_features, dim=1, p=2).cuda()
    val_features = nn.functional.normalize(val_features, dim=1, p=2).cuda()

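    # report (top-1, top-5) accuracy for each choice of k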
    for k in args.nb_knn:
        top1, top5 = knn_classifier(train_features, train_labels,
                                    val_features, val_labels, k,
                                    args.temperature)
        print('%d-NN: top1 = %.2f, top5 = %.2f' % (k, top1, top5))