# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
################################################################################
"""
SVM training using 3-fold cross-validation.

Relevant transfer tasks: Image Classification VOC07 and COCO2014.
"""

from __future__ import division
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function

import multiprocessing as mp
import tqdm
import argparse
import logging
import numpy as np
import os
import pickle
import sys
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_score

import svm_helper

import pdb


def task(cls, cost, opts, features, targets):
    out_file, ap_out_file = svm_helper.get_svm_train_output_files(
        cls, cost, opts.output_path)
    if not (os.path.exists(out_file) and os.path.exists(ap_out_file)):
        clf = LinearSVC(
            C=cost,
            class_weight={
                1: 2,
                -1: 1
            },
            intercept_scaling=1.0,
            verbose=0,
            penalty='l2',
            loss='squared_hinge',
            tol=0.0001,
            dual=True,
            max_iter=2000,
        )
        cls_labels = targets[:, cls].astype(dtype=np.int32, copy=True)
        cls_labels[np.where(cls_labels == 0)] = -1
        ap_scores = cross_val_score(
            clf, features, cls_labels, cv=3, scoring='average_precision')
        clf.fit(features, cls_labels)
        np.save(ap_out_file, np.array([ap_scores.mean()]))
        with open(out_file, 'wb') as fwrite:
            pickle.dump(clf, fwrite)
    return 0


def mp_helper(args):
    return task(*args)


def train_svm(opts):
    assert os.path.exists(opts.data_file), "Data file not found. Abort!"
    if not os.path.exists(opts.output_path):
        os.makedirs(opts.output_path)

    features, targets = svm_helper.load_input_data(opts.data_file,
                                                   opts.targets_data_file)
    # normalize the features: N x 9216 (example shape)
    features = svm_helper.normalize_features(features)

    # parse the cost values for training the SVM on
    costs_list = svm_helper.parse_cost_list(opts.costs_list)

    # classes for which SVM training should be done
    if opts.cls_list:
        cls_list = [int(cls) for cls in opts.cls_list.split(",")]
    else:
        num_classes = targets.shape[1]
        cls_list = range(num_classes)

    num_task = len(cls_list) * len(costs_list)
    args_cls = []
    args_cost = []
    for cls in cls_list:
        for cost in costs_list:
            args_cls.append(cls)
            args_cost.append(cost)
    args_opts = [opts] * num_task
    args_features = [features] * num_task
    args_targets = [targets] * num_task

    pool = mp.Pool(mp.cpu_count())
    for _ in tqdm.tqdm(
            pool.imap_unordered(
                mp_helper,
                zip(args_cls, args_cost, args_opts, args_features,
                    args_targets)),
            total=num_task):
        pass


def main():
    parser = argparse.ArgumentParser(description='SVM model training')
    parser.add_argument(
        '--data_file',
        type=str,
        default=None,
        help="Numpy file containing image features")
    parser.add_argument(
        '--targets_data_file',
        type=str,
        default=None,
        help="Numpy file containing image labels")
    parser.add_argument(
        '--output_path',
        type=str,
        default=None,
        help="path where to save the trained SVM models")
    parser.add_argument(
        '--costs_list',
        type=str,
        default="0.01,0.1",
        help="comma separated string containing list of costs")
    parser.add_argument(
        '--random_seed',
        type=int,
        default=100,
        help="random seed for SVM classifier training")

    parser.add_argument(
        '--cls_list',
        type=str,
        default=None,
        help="comma separated string list of classes to train")
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    opts = parser.parse_args()
    train_svm(opts)


if __name__ == '__main__':
    main()