[Feature] Add is_tracing to wrap torch.jit.is_tracing in different versions. (#1187)

* Add `is_tracing` to wrap `torch.jit.is_tracing` in different versions.

* Remame `is_tracing` to `is_jit_tracing`

* Ignore `is_jit_tracing` tests in CI.
This commit is contained in:
Ma Zerun 2021-07-13 14:42:50 +08:00 committed by GitHub
parent c3ddcf9d38
commit 6659c38dd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 2 deletions

View File

@ -48,7 +48,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py
build_without_ops:
runs-on: ubuntu-18.04

View File

@ -45,6 +45,7 @@ else:
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
from .registry import Registry, build_from_cfg
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
@ -63,5 +64,5 @@ else:
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden'
'is_method_overridden', 'is_jit_tracing'
]

21
mmcv/utils/trace.py Normal file
View File

@ -0,0 +1,21 @@
import warnings
from distutils.version import LooseVersion
import torch
def is_jit_tracing() -> bool:
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
on_trace = torch.jit.is_tracing()
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448
if isinstance(on_trace, bool):
return on_trace
else:
return torch._C._is_tracing()
else:
warnings.warn(
'torch.jit.is_tracing is only supported after v1.6.0. '
'Therefore is_tracing returns False automatically. Please '
'set on_trace manually if you are using trace.', UserWarning)
return False

View File

@ -0,0 +1,26 @@
from distutils.version import LooseVersion
import pytest
import torch
from mmcv.utils import is_jit_tracing
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
def foo(x):
if is_jit_tracing():
return x
else:
return x.tolist()
x = torch.rand(3)
# test without trace
assert isinstance(foo(x), list)
# test with trace
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
assert isinstance(traced_foo(x), torch.Tensor)