# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner._flexible_runner import FlexibleRunner


class MMResNet50(BaseModel):
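    """ResNet-50 wrapped in MMEngine's ``BaseModel``.

    ``forward`` switches on ``mode``: it returns ``{'loss': ...}`` when the
    runner trains the model and ``(logits, labels)`` when it validates.
    """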

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):
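    """Top-1 accuracy metric.

    ``process`` stores per-batch correct counts; ``compute_metrics``
    aggregates them over the whole validation set.
    """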

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Training')
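    # ``--local_rank`` is supplied by ``torch.distributed.launch``; torchrun
    # passes the rank via the LOCAL_RANK environment variable instead, so the
    # argument is parsed only for compatibility with both launchers.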
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument('--use-fsdp', action='store_true')
    parser.add_argument('--use-deepspeed', action='store_true')
    parser.add_argument('--use-colossalai', action='store_true')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
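    # Per-channel mean/std of CIFAR-10, shared by the train and val pipelines.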
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
    train_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**norm_cfg)
        ]))
    valid_set = torchvision.datasets.CIFAR10(
        'data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(**norm_cfg)]))
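    # Dataloaders are passed as configs so the runner (and the chosen strategy)
    # can build them per process; ``DefaultSampler`` shards the dataset across
    # ranks in distributed runs and behaves like a plain sampler otherwise.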
    train_dataloader = dict(
        batch_size=128,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))
    val_dataloader = dict(
        batch_size=128,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))

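    # Select a training strategy. Each branch pairs a strategy config with a
    # matching optimizer wrapper.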
    if args.use_deepspeed:
        strategy = dict(
            type='DeepSpeedStrategy',
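            # fp16 training with DeepSpeed's dynamic loss scaling
            # (``loss_scale=0`` lets DeepSpeed adjust the scale automatically).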
            fp16=dict(
                enabled=True,
                fp16_master_weights_and_grads=False,
                loss_scale=0,
                loss_scale_window=500,
                hysteresis=2,
                min_loss_scale=1,
                initial_scale_power=15,
            ),
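            # Cast the first forward input (the images) to half precision
            # before it reaches the model.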
            inputs_to_half=[0],
            # bf16=dict(
            #     enabled=True,
            # ),
            zero_optimization=dict(
                stage=3,
                allgather_partitions=True,
                reduce_scatter=True,
                allgather_bucket_size=50000000,
                reduce_bucket_size=50000000,
                overlap_comm=True,
                contiguous_gradients=True,
                cpu_offload=False),
        )
        optim_wrapper = dict(
            type='DeepSpeedOptimWrapper',
            optimizer=dict(type='AdamW', lr=1e-3))
    elif args.use_fsdp:
        from functools import partial

        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
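        # Wrap every submodule with at least 1e7 parameters in its own FSDP
        # unit.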
        size_based_auto_wrap_policy = partial(
            size_based_auto_wrap_policy, min_num_params=1e7)
        strategy = dict(
            type='FSDPStrategy',
            model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy))
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))
    elif args.use_colossalai:
        from colossalai.tensor.op_wrapper import colo_op_impl

        # ColossalAI overwrites some torch ops with its own custom ops so that
        # they are compatible with `ColoTensor`. However, a backward error is
        # likely if the model contains in-place operations.
        # For example, a `conv` + `bn` + `relu` stack is fine with an in-place
        # `relu`, because the builtin `batch_norm` op can handle it. But if
        # `relu` is in-place while `batch_norm` is a custom op, PyTorch raises
        # an error because it assumes the custom op cannot handle the backward
        # graph modification caused by the in-place op.
        # In this example, the in-place `add_` in ResNet's residual connection
        # could raise such an error, so it is redirected to the out-of-place
        # `torch.add` below.
        colo_op_impl(torch.Tensor.add_)(torch.add)
        strategy = dict(type='ColossalAIStrategy')
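        # HybridAdam is ColossalAI's optimizer; it can update parameters kept
        # on GPU as well as those the strategy offloads to CPU.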
        optim_wrapper = dict(optimizer=dict(type='HybridAdam', lr=1e-3))
    else:
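        # No explicit strategy: FlexibleRunner falls back to its default
        # (single-device or DDP, depending on the launcher), still using AMP
        # mixed precision.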
        strategy = None
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))

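    # FlexibleRunner builds the dataloaders, optimizer wrapper and parameter
    # scheduler from the configs above and runs training through the selected
    # strategy.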
    runner = FlexibleRunner(
        model=MMResNet50(),
        work_dir='./work_dirs',
        strategy=strategy,
        train_dataloader=train_dataloader,
        optim_wrapper=optim_wrapper,
        param_scheduler=dict(type='LinearLR'),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy))
    runner.train()


if __name__ == '__main__':
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp  # noqa: E501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed  # noqa: E501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py
    # python distributed_training_with_flexible_runner.py
    main()