# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import mmcv
import torch
import torch.nn as nn
from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
                           is_module_wrapper)

from mmseg import digit_version
from mmseg.utils import build_ddp, build_dp


def mock(*args, **kwargs):
    pass


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)

    def forward(self, x):
        return self.conv(x)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_dp():
    model = Model()
    assert not is_module_wrapper(model)

    mmdp = build_dp(model, 'cpu')
    assert isinstance(mmdp, MMDataParallel)

    if torch.cuda.is_available():
        mmdp = build_dp(model, 'cuda')
        assert isinstance(mmdp, MMDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mludp = build_dp(model, 'mlu')
            assert isinstance(mludp, MLUDataParallel)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_ddp():
    model = Model()
    assert not is_module_wrapper(model)

    if torch.cuda.is_available():
        mmddp = build_ddp(
            model, 'cuda', device_ids=[0], process_group=MagicMock())
        assert isinstance(mmddp, MMDistributedDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDistributedDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mluddp = build_ddp(
                model, 'mlu', device_ids=[0], process_group=MagicMock())
            assert isinstance(mluddp, MLUDistributedDataParallel)