mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add is_tracing
to wrap torch.jit.is_tracing
in different versions. (#1187)
* Add `is_tracing` to wrap `torch.jit.is_tracing` in different versions. * Remame `is_tracing` to `is_jit_tracing` * Ignore `is_jit_tracing` tests in CI.
This commit is contained in:
parent
c3ddcf9d38
commit
6659c38dd5
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
pip install -r requirements/test.txt
|
||||
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
|
||||
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py
|
||||
|
||||
build_without_ops:
|
||||
runs-on: ubuntu-18.04
|
||||
|
@ -45,6 +45,7 @@ else:
|
||||
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
|
||||
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .trace import is_jit_tracing
|
||||
__all__ = [
|
||||
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
||||
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
|
||||
@ -63,5 +64,5 @@ else:
|
||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script',
|
||||
'is_method_overridden'
|
||||
'is_method_overridden', 'is_jit_tracing'
|
||||
]
|
||||
|
21
mmcv/utils/trace.py
Normal file
21
mmcv/utils/trace.py
Normal file
@ -0,0 +1,21 @@
|
||||
import warnings
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_jit_tracing() -> bool:
|
||||
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
|
||||
on_trace = torch.jit.is_tracing()
|
||||
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
||||
# Refers to https://github.com/pytorch/pytorch/issues/42448
|
||||
if isinstance(on_trace, bool):
|
||||
return on_trace
|
||||
else:
|
||||
return torch._C._is_tracing()
|
||||
else:
|
||||
warnings.warn(
|
||||
'torch.jit.is_tracing is only supported after v1.6.0. '
|
||||
'Therefore is_tracing returns False automatically. Please '
|
||||
'set on_trace manually if you are using trace.', UserWarning)
|
||||
return False
|
26
tests/test_utils/test_trace.py
Normal file
26
tests/test_utils/test_trace.py
Normal file
@ -0,0 +1,26 @@
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import is_jit_tracing
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
|
||||
reason='torch.jit.is_tracing is not available before 1.6.0')
|
||||
def test_is_jit_tracing():
|
||||
|
||||
def foo(x):
|
||||
if is_jit_tracing():
|
||||
return x
|
||||
else:
|
||||
return x.tolist()
|
||||
|
||||
x = torch.rand(3)
|
||||
# test without trace
|
||||
assert isinstance(foo(x), list)
|
||||
|
||||
# test with trace
|
||||
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
|
||||
assert isinstance(traced_foo(x), torch.Tensor)
|
Loading…
x
Reference in New Issue
Block a user