Mirror of https://github.com/open-mmlab/mmengine.git
Synced: 2025-06-03 21:54:44 +08:00
* Rename data to structure * adjust the way to import module * adjust the way to import module * rename Structure to Data Structures in docs api * rename structure to structures * support using some modules of mmengine without torch * fix circleci config * fix circleci config * fix registry ut * minor fix * move init method from model/utils to model/weight_init.py * move init method from model/utils to model/weight_init.py * move sync_bn to model * move functions depending on torch to dl_utils * format import * fix logging ut * add weight init in model/__init__.py * move get_config and get_model to mmengine/hub * move log_processor.py to mmengine/runner * fix ut * Add TimeCounter in dl_utils/__init__.py
27 lines · 672 B · Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmengine.utils import digit_version
|
|
from mmengine.utils.dl_utils import is_jit_tracing
|
|
|
|
|
|
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
    """Check that ``is_jit_tracing`` distinguishes eager calls from tracing.

    ``foo`` branches on ``is_jit_tracing()``: called eagerly it converts the
    tensor to a Python list, while under ``torch.jit.trace`` it returns the
    tensor unchanged. The traced function should therefore only contain the
    tensor-returning (identity) branch.
    """

    def foo(x):
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)

    # test without trace: the eager call must take the ``tolist`` branch and
    # keep every element of the input.
    eager_out = foo(x)
    assert isinstance(eager_out, list)
    assert len(eager_out) == 3

    # test with trace: the recorded graph is the identity branch, so the
    # traced function returns a tensor equal to its input. It is called with a
    # shape different from the (1,)-shaped example used when tracing, which an
    # identity graph does not depend on.
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    traced_out = traced_foo(x)
    assert isinstance(traced_out, torch.Tensor)
    assert torch.equal(traced_out, x)
|