mmcv/tests/test_utils/test_trace.py
Haodong Duan ef48a47389
[Improvement] Improve digit_version & use it for version_checking (#1185)
* improve digit_version & use it for version_checking

* more testing for digit_version

* setuptools >= 50 is needed

* fix CI

* add debugging log

* >= to ==

* fix lint

* remove

* add failure case

* replace

* fix

* consider TORCH_VERSION == 'parrots'

* add unittest

* digit_version does not handle version names that contain 'parrots'.
2021-07-23 21:03:33 +08:00


```python
import pytest
import torch

from mmcv.utils import digit_version, is_jit_tracing


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():

    def foo(x):
        # Return the tensor itself while being traced and a plain Python
        # list otherwise, so the two code paths are easy to tell apart.
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)
    # test without trace: the eager call takes the tolist() branch
    assert isinstance(foo(x), list)

    # test with trace: tracing records the branch that returns the tensor
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    assert isinstance(traced_foo(x), torch.Tensor)
```
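The skip condition compares parsed versions rather than raw `torch.__version__` strings. The sketch below illustrates why, using a hypothetical helper `naive_digit_version` that is not mmcv's `digit_version` (the real function, reworked in the commit above, handles more cases). It assumes versions of the form `major.minor.patch`, optionally followed by a `+local` label or a pre-release suffix, and keeps only the leading digits of each component. Like the commit note says of `digit_version`, it does not handle a name such as `'parrots'`; here that simply maps to zeros, so callers need a separate check for it.

```python
import re
from typing import Tuple


def naive_digit_version(version_str: str, length: int = 3) -> Tuple[int, ...]:
    """Hypothetical helper: turn '1.8.1+cu102' into (1, 8, 1).

    Only the leading digits of each dot-separated component are kept, so
    pre-release tags such as 'a0' or 'rc1' are ignored. Not mmcv's API.
    """
    public = version_str.split('+')[0]  # drop local labels like '+cu102'
    parts = []
    for component in public.split('.')[:length]:
        match = re.match(r'\d+', component)
        # Components without leading digits (e.g. 'parrots') become 0.
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '1.6' so that (1, 6) and (1, 6, 0)
    # compare as equal.
    parts.extend([0] * (length - len(parts)))
    return tuple(parts)


# Plain string comparison is lexicographic and gets this ordering wrong:
assert '1.10.0' < '1.6.0'
# Comparing integer tuples gives the intended ordering:
assert naive_digit_version('1.10.0') > naive_digit_version('1.6.0')
```

Tuples of integers compare element by element, which is why a numeric parse is the safer basis for a gate like `< '1.6.0'` once a library reaches a two-digit minor version.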