mmdeploy/tests/test_codebase/test_mmseg/test_mmseg_utils.py
Yifan Zhou 4149228716
[Enhancement]: Import codebase only when it is required (#266)
* Add import codebase

* lint

* Fix import order

* typo

* Fix partition

* docstring

* lint
2021-12-10 11:34:22 +08:00

31 lines
959 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmdeploy.codebase import import_codebase
from mmdeploy.codebase.mmseg.deploy import convert_syncbatchnorm
from mmdeploy.utils import Codebase
# Module-level call made once at import time, before any test in this file runs.
# NOTE(review): presumably this loads/registers the MMSeg codebase so the
# mmseg-specific helpers under test are usable -- confirm against
# mmdeploy.codebase.import_codebase.
import_codebase(Codebase.MMSEG)
def test_convert_syncbatchnorm():
    """Check that ``convert_syncbatchnorm`` swaps SyncBatchNorm for BatchNorm2d.

    A small sequential network containing two ``nn.SyncBatchNorm`` layers
    (at positions 1 and 4) is passed through ``convert_syncbatchnorm``; the
    modules found at those positions afterwards must be ``BatchNorm2d``.
    """

    class ExampleModel(nn.Module):
        """Toy network: two Linear -> SyncBatchNorm -> Sigmoid stages."""

        def __init__(self):
            super().__init__()
            layers = [
                nn.Linear(2, 4),
                nn.SyncBatchNorm(4),
                nn.Sigmoid(),
                nn.Linear(4, 6),
                nn.SyncBatchNorm(6),
                nn.Sigmoid(),
            ]
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    converted = convert_syncbatchnorm(ExampleModel())

    # Indices of the layers that were SyncBatchNorm in the original model.
    for position in (1, 4):
        assert isinstance(converted.model[position], nn.BatchNorm2d)