# coding=utf-8
import datetime
import logging
import os
import time

import torch
from torch.nn import L1Loss

def dt():
    """Current timestamp as 'YYYY-mm-dd HH:MM:SS', used in log lines."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

def trainlog(logfilepath, head='%(message)s'):
    """Send INFO-level log records to both `logfilepath` and the console."""
    logging.basicConfig(filename=logfilepath, level=logging.INFO, format=head)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(head))
    logging.getLogger('').addHandler(console)

def train(cfg,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          criterion,
          exp_lr_scheduler,
          data_set,
          data_loader,
          save_dir,
          print_inter=200,
          val_inter=3500
          ):
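    # The loss computation below assumes a multi-head model (a DCL-style
    # fine-grained setup, inferred from the loss terms) whose forward pass
    # returns a list:
    #   outputs[0]: class logits over cfg['numcls'] classes
    #   outputs[1]: swap logits over 2 * cfg['numcls'] classes
    #   outputs[2]: swap-law regression output, matched with an L1 loss
    # Both loaders must yield (inputs, labels, labels_swap, swap_law) tuples.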

    step = 0
    add_loss = L1Loss()
    for epoch in range(start_epoch, epoch_num):
        # train phase
        exp_lr_scheduler.step(epoch)  # advance the LR schedule to this epoch
        model.train(True)  # set model to training mode

        for batch_cnt, data in enumerate(data_loader['train']):

            step += 1
            model.train(True)  # restore training mode after a validation pass
            inputs, labels, labels_swap, swap_law = data
            inputs = inputs.cuda()
            labels = torch.as_tensor(labels).long().cuda()
            labels_swap = torch.as_tensor(labels_swap).long().cuda()
            swap_law = torch.as_tensor(swap_law).float().cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(inputs)
            if isinstance(outputs, list):
                loss = criterion(outputs[0], labels)        # classification head
                loss += criterion(outputs[1], labels_swap)  # swap-detection head
                loss += add_loss(outputs[2], swap_law)      # swap-law regression (L1)
            else:
                # fallback for a single-head model (assumption; the multi-head
                # branch above is the intended path)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % val_inter == 0:
                # val phase
                model.eval()  # set model to evaluate mode

                val_loss = 0.0
                val_corrects1 = 0
                val_corrects2 = 0
                val_corrects3 = 0

                t0 = time.time()

                with torch.no_grad():
                    for batch_cnt_val, data_val in enumerate(data_loader['val']):
                        inputs, labels, labels_swap, swap_law = data_val

                        inputs = inputs.cuda()
                        labels = torch.as_tensor(labels).long().cuda()
                        labels_swap = torch.as_tensor(labels_swap).long().cuda()
                        # duplicate the sample when the batch has size 1
                        if len(inputs) == 1:
                            inputs = torch.cat((inputs, inputs))
                            labels = torch.cat((labels, labels))
                            labels_swap = torch.cat((labels_swap, labels_swap))
                        # forward
                        outputs = model(inputs)

                        if isinstance(outputs, list):
                            # combined score: classification head plus both halves
                            # of the 2*numcls-wide swap head
                            outputs1 = outputs[0] + outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                            # classification head alone
                            outputs2 = outputs[0]
                            # swap head alone
                            outputs3 = outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                        else:
                            # fallback for a single-head model (assumption,
                            # mirroring the train phase)
                            outputs1 = outputs2 = outputs3 = outputs
                        # accumulate the classification-head loss (an assumed
                        # definition of the reported val loss)
                        val_loss += criterion(outputs2, labels).item()
                        _, preds1 = torch.max(outputs1, 1)
                        _, preds2 = torch.max(outputs2, 1)
                        _, preds3 = torch.max(outputs3, 1)

                        val_corrects1 += torch.sum(preds1 == labels).item()
                        val_corrects2 += torch.sum(preds2 == labels).item()
                        val_corrects3 += torch.sum(preds3 == labels).item()


                val_loss /= (batch_cnt_val + 1)  # mean over val batches
                # 0.5 factor: data_set['val'] appears to count each image twice
                # (original + swapped copy), so raw correct counts are halved
                val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
                val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
                val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])

                t1 = time.time()
                since = t1 - t0
                logging.info('--'*30)
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                logging.info('%s epoch[%d]-val-loss: %.4f ||val-acc@1: c&a: %.4f c: %.4f a: %.4f||time: %d'
                             % (dt(), epoch, val_loss, val_acc1, val_acc2, val_acc3, since))

                # save model
                save_path = os.path.join(save_dir,
                                         'weights-%d-%d-[%.4f].pth' % (epoch, batch_cnt, val_acc1))
                torch.save(model.state_dict(), save_path)
                logging.info('saved model to %s' % save_path)
                logging.info('--' * 30)
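
if __name__ == '__main__':
    # Minimal smoke-test wiring, for illustration only. Everything below is
    # hypothetical: DummyDCLModel and the synthetic tensors are placeholders
    # that show how train() is meant to be called, assuming the three-head
    # model interface described above and a CUDA-capable machine.
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class DummyDCLModel(nn.Module):
        """Hypothetical stand-in returning [class logits, swap logits, swap-law]."""
        def __init__(self, numcls):
            super().__init__()
            self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
            self.cls = nn.Linear(32, numcls)
            self.swap = nn.Linear(32, 2 * numcls)
            self.law = nn.Linear(32, 4)

        def forward(self, x):
            feat = self.backbone(x)
            return [self.cls(feat), self.swap(feat), self.law(feat)]

    cfg = {'numcls': 5}
    images = torch.randn(16, 3, 8, 8)
    labels = torch.randint(0, cfg['numcls'], (16,))
    labels_swap = torch.randint(0, 2 * cfg['numcls'], (16,))
    swap_law = torch.randn(16, 4)
    dataset = TensorDataset(images, labels, labels_swap, swap_law)

    data_set = {'train': dataset, 'val': dataset}
    data_loader = {'train': DataLoader(dataset, batch_size=4),
                   'val': DataLoader(dataset, batch_size=4)}

    model = DummyDCLModel(cfg['numcls']).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    trainlog('train.log')
    train(cfg, model, epoch_num=2, start_epoch=0, optimizer=optimizer,
          criterion=nn.CrossEntropyLoss(), exp_lr_scheduler=scheduler,
          data_set=data_set, data_loader=data_loader, save_dir='.',
          print_inter=10, val_inter=4)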