# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmdeploy.codebase import import_codebase
from mmdeploy.codebase.mmseg.deploy import convert_syncbatchnorm
from mmdeploy.utils import Codebase

# Module-level setup executed at import time, before any test runs.
# NOTE(review): presumably this loads/registers the mmseg codebase so that
# mmdeploy's mmseg-specific utilities work in the tests below -- confirm
# against mmdeploy.codebase.import_codebase.
import_codebase(Codebase.MMSEG)


def test_convert_syncbatchnorm():
    """Assert that the model returned by ``convert_syncbatchnorm`` holds
    ``BatchNorm2d`` layers at the positions where the input model had
    ``SyncBatchNorm`` layers."""

    class ExampleModel(nn.Module):
        """Toy network: two Linear -> SyncBatchNorm -> Sigmoid stages."""

        def __init__(self):
            super().__init__()
            layers = [
                nn.Linear(2, 4),
                nn.SyncBatchNorm(4),
                nn.Sigmoid(),
                nn.Linear(4, 6),
                nn.SyncBatchNorm(6),
                nn.Sigmoid(),
            ]
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    converted = convert_syncbatchnorm(ExampleModel())
    # Indices 1 and 4 are the slots the SyncBatchNorm layers occupy in the
    # Sequential container built above.
    for idx in (1, 4):
        assert isinstance(converted.model[idx], nn.BatchNorm2d)