import numpy as np
import torch
from datasets import load_dataset
from torchtext.data.metrics import bleu_score
from transformers import AutoTokenizer, T5ForConditionalGeneration

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner

# Shared t5-small tokenizer used by collate_fn, post_process and the metric.
# Note: this runs at import time, so importing this module calls
# from_pretrained('t5-small') (which may fetch files on first use).
tokenizer = AutoTokenizer.from_pretrained('t5-small')


class MMT5ForTranslation(BaseModel):
    """Adapt a HuggingFace T5 model to MMEngine's ``BaseModel`` interface.

    The runner calls ``forward`` with the keys produced by ``collate_fn``
    plus a ``mode`` string; this wrapper dispatches on that mode.
    """

    def __init__(self, model):
        """Wrap ``model``.

        Args:
            model: The seq2seq model to train (``main`` passes a
                ``T5ForConditionalGeneration``).
        """
        super().__init__()
        self.model = model

    def forward(self, label, input_ids, attention_mask, mode):
        """Compute the training loss or generate predictions.

        Args:
            label: Target token ids with padding positions set to -100
                (see ``collate_fn``).
            input_ids: Source token ids, padded to the longest sequence in
                the batch.
            attention_mask: Mask from the tokenizer distinguishing real
                source tokens from padding.
            mode: ``'loss'`` for training, ``'predict'`` for evaluation.

        Returns:
            ``{'loss': loss}`` in ``'loss'`` mode, or
            ``(generated_ids, label)`` in ``'predict'`` mode (the tuple that
            ``Accuracy.process`` unpacks).

        Raises:
            ValueError: If ``mode`` is neither ``'loss'`` nor ``'predict'``.
        """
        if mode == 'loss':
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=label)
            return {'loss': output.loss}
        if mode == 'predict':
            # Sources are padded to the longest in the batch; pass the mask
            # explicitly so generation does not rely on inferring it.
            output = self.model.generate(
                input_ids=input_ids, attention_mask=attention_mask)
            return output, label
        # Previously an unknown mode silently returned None.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'loss' or 'predict'.")


def post_process(preds, labels):
    """Convert generated ids and reference ids into lists of words.

    Args:
        preds: Generated token ids.
        labels: Reference token ids in which padding was replaced by -100
            (see ``collate_fn``).

    Returns:
        ``(decoded_preds, decoded_labels)``. Each prediction is a list of
        whitespace-separated words. Each reference is a one-element list
        holding its word list, which is the candidate/references layout
        passed to ``bleu_score``.
    """
    # -100 is not a decodable id: restore padding first so that
    # skip_special_tokens drops it during decoding.
    restored = torch.where(labels != -100, labels, tokenizer.pad_token_id)

    pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(restored, skip_special_tokens=True)

    decoded_preds = list(map(str.split, pred_texts))
    decoded_labels = [[text.split()] for text in label_texts]
    return decoded_preds, decoded_labels


class Accuracy(BaseMetric):
    """Corpus BLEU and mean generated length for the translation task.

    Despite its name this metric does not compute accuracy. The name is
    kept because ``main`` registers it as the ``val_evaluator``.
    """

    def process(self, data_batch, data_samples):
        """Decode one validation batch and store it for corpus scoring.

        Args:
            data_batch: The collated batch. It is unused because predictions
                and labels arrive through ``data_samples``.
            data_samples: ``(generated_ids, label)`` as returned by
                ``MMT5ForTranslation.forward`` in ``'predict'`` mode.
        """
        outputs, labels = data_samples
        decoded_preds, decoded_labels = post_process(outputs, labels)
        # Number of non-pad tokens in each generated sequence (one per row).
        gen_lens = (outputs != tokenizer.pad_token_id).sum(dim=-1).tolist()
        # Keep raw sentences rather than a per-batch score: BLEU must be
        # computed over the whole corpus at once (see compute_metrics).
        self.results.append({
            'preds': decoded_preds,
            'labels': decoded_labels,
            'gen_lens': gen_lens,
        })

    def compute_metrics(self, results):
        """Aggregate all stored batches into corpus-level metrics.

        BLEU is computed once over every decoded sentence instead of
        averaging per-batch BLEU values. BLEU's n-gram counts and brevity
        penalty are not additive across batches, and a per-batch mean would
        also overweight the (possibly smaller) final batch. ``gen_len`` is
        averaged over samples for the same reason.

        Args:
            results: The list of dicts appended by ``process``.

        Returns:
            dict with ``gen_len``, the mean number of non-pad tokens per
            generated sequence (NaN if nothing was processed), and
            ``bleu_score``, the corpus BLEU from torchtext.
        """
        preds = [pred for item in results for pred in item['preds']]
        refs = [ref for item in results for ref in item['labels']]
        gen_lens = [n for item in results for n in item['gen_lens']]
        return dict(
            gen_len=float(np.mean(gen_lens)) if gen_lens else float('nan'),
            bleu_score=bleu_score(preds, refs),
        )


def collate_fn(data):
    """Tokenize a list of opus_books examples into one padded batch.

    Each example's ``translation`` dict supplies the English source, which
    is prefixed with the T5 task instruction, and the French target.

    Returns:
        dict with ``label`` (target ids, with padding set to -100),
        ``input_ids`` and ``attention_mask``. The keys match the parameters
        of ``MMT5ForTranslation.forward``.
    """
    prefix = 'translate English to French: '
    sources = []
    targets = []
    for example in data:
        pair = example['translation']
        sources.append(prefix + pair['en'])
        targets.append(pair['fr'])

    encoded_sources = tokenizer(
        sources, padding='longest', return_tensors='pt')
    encoded_targets = tokenizer(
        targets, padding='longest', return_tensors='pt')

    label = encoded_targets.input_ids
    # Mark padded target positions with -100 so they do not contribute
    # to the loss.
    label[label == tokenizer.pad_token_id] = -100
    return {
        'label': label,
        'input_ids': encoded_sources.input_ids,
        'attention_mask': encoded_sources.attention_mask,
    }


def main():
    """Fine-tune t5-small on opus_books en-fr using an MMEngine Runner.

    The opus_books ``train`` split is divided 80/20 into training and
    validation sets. Validation runs after every epoch for two epochs.
    """
    t5 = T5ForConditionalGeneration.from_pretrained('t5-small')

    splits = load_dataset('opus_books', 'en-fr')['train'].train_test_split(
        test_size=0.2)

    def make_loader(dataset, batch_size, shuffle):
        # Dataloader config dict consumed by the Runner.
        return dict(
            batch_size=batch_size,
            dataset=dataset,
            sampler=dict(type='DefaultSampler', shuffle=shuffle),
            collate_fn=collate_fn)

    runner = Runner(
        model=MMT5ForTranslation(t5),
        train_dataloader=make_loader(splits['train'], 16, True),
        val_dataloader=make_loader(splits['test'], 32, False),
        optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_cfg=dict(),
        work_dir='t5_work_dir',
        val_evaluator=dict(type=Accuracy))

    runner.train()


# Start training only when run as a script, not when imported.
if __name__ == '__main__':
    main()