Add torch_meshgrid wrapper due to PyTorch change (#2044)

* Add torch_meshgrid_ij wrapper due to PyTorch change

* Update torch_meshgrid name/doc/version implementation

* Make imports local

* add ut

* ignore ut when torch is not available

Co-authored-by: zhouzaida <zhouzaida@163.com>
pull/2061/head
Philipp Allgeuer 2022-06-15 14:36:48 +02:00 committed by GitHub
parent b062468015
commit f5425ab761
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 2 deletions

View File

@ -45,7 +45,22 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_device/test_ipu --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py
pytest tests/ \
--ignore=tests/test_runner \
--ignore=tests/test_device/test_ipu \
--ignore=tests/test_optimizer.py \
--ignore=tests/test_cnn \
--ignore=tests/test_parallel.py \
--ignore=tests/test_ops \
--ignore=tests/test_load_model_zoo.py \
--ignore=tests/test_utils/test_logging.py \
--ignore=tests/test_image/test_io.py \
--ignore=tests/test_utils/test_registry.py \
--ignore=tests/test_utils/test_parrots_jit.py \
--ignore=tests/test_utils/test_trace.py \
--ignore=tests/test_utils/test_hub.py \
--ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
--ignore=tests/test_utils/test_torch_ops.py
build_without_ops:
runs-on: ubuntu-18.04

View File

@ -53,6 +53,7 @@ else:
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
@ -74,5 +75,6 @@ else:
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE'
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
'torch_meshgrid'
]

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from .parrots_wrapper import TORCH_VERSION
from .version_utils import digit_version
_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
def torch_meshgrid(*tensors):
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
So we implement a wrapper here to avoid warning when using high-version
PyTorch and avoid compatibility issues when using previous versions of
PyTorch.
Args:
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
Returns:
Sequence[Tensor]: Sequence of meshgrid tensors.
"""
if _torch_version_meshgrid_indexing:
return torch.meshgrid(*tensors, indexing='ij')
else:
return torch.meshgrid(*tensors) # Uses indexing='ij' by default

View File

@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import torch_meshgrid
def test_torch_meshgrid():
# torch_meshgrid should not throw warning
with pytest.warns(None) as record:
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
grid_x, grid_y = torch_meshgrid(x, y)
assert len(record) == 0