mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add is_tracing
to wrap torch.jit.is_tracing
in different versions. (#1187)
* Add `is_tracing` to wrap `torch.jit.is_tracing` in different versions. * Remame `is_tracing` to `is_jit_tracing` * Ignore `is_jit_tracing` tests in CI.
This commit is contained in:
parent
c3ddcf9d38
commit
6659c38dd5
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
|||||||
- name: Run unittests and generate coverage report
|
- name: Run unittests and generate coverage report
|
||||||
run: |
|
run: |
|
||||||
pip install -r requirements/test.txt
|
pip install -r requirements/test.txt
|
||||||
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
|
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py
|
||||||
|
|
||||||
build_without_ops:
|
build_without_ops:
|
||||||
runs-on: ubuntu-18.04
|
runs-on: ubuntu-18.04
|
||||||
|
@ -45,6 +45,7 @@ else:
|
|||||||
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
|
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
|
||||||
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
|
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
|
||||||
from .registry import Registry, build_from_cfg
|
from .registry import Registry, build_from_cfg
|
||||||
|
from .trace import is_jit_tracing
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
||||||
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
|
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
|
||||||
@ -63,5 +64,5 @@ else:
|
|||||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||||
'assert_params_all_zeros', 'check_python_script',
|
'assert_params_all_zeros', 'check_python_script',
|
||||||
'is_method_overridden'
|
'is_method_overridden', 'is_jit_tracing'
|
||||||
]
|
]
|
||||||
|
21
mmcv/utils/trace.py
Normal file
21
mmcv/utils/trace.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import warnings
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def is_jit_tracing() -> bool:
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
|
||||||
|
on_trace = torch.jit.is_tracing()
|
||||||
|
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
||||||
|
# Refers to https://github.com/pytorch/pytorch/issues/42448
|
||||||
|
if isinstance(on_trace, bool):
|
||||||
|
return on_trace
|
||||||
|
else:
|
||||||
|
return torch._C._is_tracing()
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'torch.jit.is_tracing is only supported after v1.6.0. '
|
||||||
|
'Therefore is_tracing returns False automatically. Please '
|
||||||
|
'set on_trace manually if you are using trace.', UserWarning)
|
||||||
|
return False
|
26
tests/test_utils/test_trace.py
Normal file
26
tests/test_utils/test_trace.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmcv.utils import is_jit_tracing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
|
||||||
|
reason='torch.jit.is_tracing is not available before 1.6.0')
|
||||||
|
def test_is_jit_tracing():
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
if is_jit_tracing():
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x.tolist()
|
||||||
|
|
||||||
|
x = torch.rand(3)
|
||||||
|
# test without trace
|
||||||
|
assert isinstance(foo(x), list)
|
||||||
|
|
||||||
|
# test with trace
|
||||||
|
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
|
||||||
|
assert isinstance(traced_foo(x), torch.Tensor)
|
Loading…
x
Reference in New Issue
Block a user