mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* Rename data to structure * Adjust how modules are imported * Adjust how modules are imported * Rename Structure to Data Structures in the API docs * Rename structure to structures * Support using some modules of MMEngine without torch * Fix CircleCI config * Fix CircleCI config * Fix registry unit test * Minor fix * Move init methods from model/utils to model/weight_init.py * Move init methods from model/utils to model/weight_init.py * Move sync_bn to model * Move functions that depend on torch to dl_utils * Format imports * Fix logging unit test * Add weight init to model/__init__.py * Move get_config and get_model to mmengine/hub * Move log_processor.py to mmengine/runner * Fix unit tests * Add TimeCounter to dl_utils/__init__.py
21 lines
652 B
Python
21 lines
652 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmengine.model import revert_sync_batchnorm
|
|
|
|
|
|
@pytest.mark.skipif(
    torch.__version__ == 'parrots', reason='not supported in parrots now')
def test_revert_syncbn():
    """``revert_sync_batchnorm`` makes a SyncBatchNorm model runnable on CPU.

    The model combines a plain convolution with ``nn.SyncBatchNorm``. After
    reverting, no ``nn.SyncBatchNorm`` module may remain, and a CPU forward
    pass must produce the expected output shape.
    """
    conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))
    x = torch.randn(1, 3, 10, 10)

    # Older PyTorch releases reject SyncBatchNorm on CPU tensors with a
    # ValueError. Newer releases may instead fall back to ordinary batch norm
    # when torch.distributed is not initialized (behaviour depends on the
    # installed version), so the call is allowed to succeed here. Any other
    # exception type still fails the test.
    try:
        conv(x)
    except ValueError:
        pass

    conv = revert_sync_batchnorm(conv)
    # The whole point of reverting: no SyncBatchNorm layer may survive.
    assert not any(
        isinstance(module, nn.SyncBatchNorm) for module in conv.modules())

    y = conv(x)
    # Conv2d with kernel size 2 (stride 1, no padding) maps 10x10 to 9x9.
    assert y.shape == (1, 8, 9, 9)
|