mmdeploy/tests/test_codebase/test_mmseg/test_mmseg_utils.py

31 lines
959 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmdeploy.codebase import import_codebase
from mmdeploy.codebase.mmseg.deploy import convert_syncbatchnorm
from mmdeploy.utils import Codebase
# Load the MMSEG codebase at module import time, before the test below runs.
# NOTE(review): presumably this registers mmseg-specific deploy components
# that the tested utilities rely on — confirm against ``import_codebase``.
import_codebase(Codebase.MMSEG)
def test_convert_syncbatchnorm():
    """Check that ``convert_syncbatchnorm`` swaps SyncBatchNorm for BatchNorm2d.

    A small model interleaving ``nn.Linear``, ``nn.SyncBatchNorm`` and
    ``nn.Sigmoid`` layers is converted. Every ``nn.SyncBatchNorm`` must come
    back as a ``torch.nn.BatchNorm2d`` with the same number of features, and
    the remaining layers must keep their original types.
    """

    class ExampleModel(nn.Module):
        """Toy model whose Sequential indices 1 and 4 are SyncBatchNorm."""

        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(2, 4), nn.SyncBatchNorm(4), nn.Sigmoid(),
                nn.Linear(4, 6), nn.SyncBatchNorm(6), nn.Sigmoid())

        def forward(self, x):
            return self.model(x)

    model = ExampleModel()
    out_model = convert_syncbatchnorm(model)

    # No SyncBatchNorm may survive anywhere in the converted module tree.
    leftovers = [
        name for name, module in out_model.named_modules()
        if isinstance(module, nn.SyncBatchNorm)
    ]
    assert not leftovers, leftovers

    # Each former SyncBatchNorm slot is now a BatchNorm2d of the same width.
    # One assert per layer so a failure points at the offending index.
    for idx, num_features in ((1, 4), (4, 6)):
        layer = out_model.model[idx]
        assert isinstance(layer, torch.nn.BatchNorm2d), (idx, type(layer))
        assert layer.num_features == num_features, (idx, layer.num_features)

    # Layers that were not SyncBatchNorm keep their original types.
    for idx, expected_type in ((0, nn.Linear), (2, nn.Sigmoid),
                               (3, nn.Linear), (5, nn.Sigmoid)):
        layer = out_model.model[idx]
        assert isinstance(layer, expected_type), (idx, type(layer))