diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0b90cbf21..754757ece 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,7 +48,7 @@ jobs: - name: Run unittests and generate coverage report run: | pip install -r requirements/test.txt - pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py + pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py build_without_ops: runs-on: ubuntu-18.04 diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 6ca345240..72eaf8cde 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -45,6 +45,7 @@ else: _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) from .registry import Registry, build_from_cfg + from .trace import is_jit_tracing __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', @@ -63,5 +64,5 @@ else: 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', - 'is_method_overridden' + 'is_method_overridden', 'is_jit_tracing' ] diff --git a/mmcv/utils/trace.py b/mmcv/utils/trace.py new file mode 100644 index 000000000..f2a6fceaf --- /dev/null +++ b/mmcv/utils/trace.py @@ -0,0 +1,21 @@ +import warnings +from distutils.version import LooseVersion + +import torch + + +def is_jit_tracing() -> bool: + if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False diff --git a/tests/test_utils/test_trace.py b/tests/test_utils/test_trace.py new file mode 100644 index 000000000..c39b2c3b2 --- /dev/null +++ b/tests/test_utils/test_trace.py @@ -0,0 +1,26 @@ +from distutils.version import LooseVersion + +import pytest +import torch + +from mmcv.utils import is_jit_tracing + + +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion('1.6.0'), + reason='torch.jit.is_tracing is not available before 1.6.0') +def test_is_jit_tracing(): + + def foo(x): + if is_jit_tracing(): + return x + else: + return x.tolist() + + x = torch.rand(3) + # test without trace + assert isinstance(foo(x), list) + + # test with trace + traced_foo = torch.jit.trace(foo, (torch.rand(1), )) + assert isinstance(traced_foo(x), torch.Tensor)