# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
from datasets import load_dataset
from transformers import BertForSequenceClassification, BertTokenizer

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMBertForClassify(BaseModel):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, label, input_ids, token_type_ids, attention_mask, mode):
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=label)
        if mode == 'loss':
            return {'loss': output.loss}
        elif mode == 'predict':
            return output.logits, label


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    return args


def collate_fn(data):
    labels = []
    input_ids = []
    token_type_ids = []
    attention_mask = []
    for item in data:
        labels.append(item['label'])
        input_ids.append(torch.tensor(item['input_ids']))
        token_type_ids.append(torch.tensor(item['token_type_ids']))
        attention_mask.append(torch.tensor(item['attention_mask']))

    input_ids = torch.stack(input_ids)
    token_type_ids = torch.stack(token_type_ids)
    attention_mask = torch.stack(attention_mask)
    label = torch.tensor(labels)
    return dict(
        label=label,
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask)


def main():
    args = parse_args()
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_set = load_dataset('imdb', split='train')
    test_set = load_dataset('imdb', split='test')

    train_set = train_set.map(
        lambda x: tokenizer(
            x['text'], truncation=True, padding=True, max_length=128),
        batched=True)
    test_set = test_set.map(
        lambda x: tokenizer(
            x['text'], truncation=True, padding=True, max_length=128),
        batched=True)

    train_loader = dict(
        batch_size=32,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=collate_fn)
    test_loader = dict(
        batch_size=32,
        dataset=test_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=collate_fn)
    runner = Runner(
        model=MMBertForClassify(model),
        train_dataloader=train_loader,
        val_dataloader=test_loader,
        optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_cfg=dict(),
        work_dir='bert_work_dir',
        val_evaluator=dict(type=Accuracy),
        launcher=args.launcher,
    )
    runner.train()


if __name__ == '__main__':
    main()