mirror of https://github.com/open-mmlab/mmcv.git
Add torch_meshgrid wrapper due to PyTorch change (#2044)
* Add torch_meshgrid_ij wrapper due to PyTorch change * Update torch_meshgrid name/doc/version implementation * Make imports local * add ut * ignore ut when torch is not available Co-authored-by: zhouzaida <zhouzaida@163.com>pull/2061/head
parent
b062468015
commit
f5425ab761
|
@ -45,7 +45,22 @@ jobs:
|
|||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
pip install -r requirements/test.txt
|
||||
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_device/test_ipu --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py
|
||||
pytest tests/ \
|
||||
--ignore=tests/test_runner \
|
||||
--ignore=tests/test_device/test_ipu \
|
||||
--ignore=tests/test_optimizer.py \
|
||||
--ignore=tests/test_cnn \
|
||||
--ignore=tests/test_parallel.py \
|
||||
--ignore=tests/test_ops \
|
||||
--ignore=tests/test_load_model_zoo.py \
|
||||
--ignore=tests/test_utils/test_logging.py \
|
||||
--ignore=tests/test_image/test_io.py \
|
||||
--ignore=tests/test_utils/test_registry.py \
|
||||
--ignore=tests/test_utils/test_parrots_jit.py \
|
||||
--ignore=tests/test_utils/test_trace.py \
|
||||
--ignore=tests/test_utils/test_hub.py \
|
||||
--ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
|
||||
--ignore=tests/test_utils/test_torch_ops.py
|
||||
|
||||
build_without_ops:
|
||||
runs-on: ubuntu-18.04
|
||||
|
|
|
@ -53,6 +53,7 @@ else:
|
|||
# yapf: enable
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .seed import worker_init_fn
|
||||
from .torch_ops import torch_meshgrid
|
||||
from .trace import is_jit_tracing
|
||||
__all__ = [
|
||||
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
||||
|
@ -74,5 +75,6 @@ else:
|
|||
'assert_params_all_zeros', 'check_python_script',
|
||||
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
|
||||
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
|
||||
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE'
|
||||
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
|
||||
'torch_meshgrid'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .version_utils import digit_version
|
||||
|
||||
_torch_version_meshgrid_indexing = (
|
||||
'parrots' not in TORCH_VERSION
|
||||
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
|
||||
|
||||
|
||||
def torch_meshgrid(*tensors):
|
||||
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
|
||||
|
||||
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
|
||||
So we implement a wrapper here to avoid warning when using high-version
|
||||
PyTorch and avoid compatibility issues when using previous versions of
|
||||
PyTorch.
|
||||
|
||||
Args:
|
||||
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
|
||||
|
||||
Returns:
|
||||
Sequence[Tensor]: Sequence of meshgrid tensors.
|
||||
"""
|
||||
if _torch_version_meshgrid_indexing:
|
||||
return torch.meshgrid(*tensors, indexing='ij')
|
||||
else:
|
||||
return torch.meshgrid(*tensors) # Uses indexing='ij' by default
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import torch_meshgrid
|
||||
|
||||
|
||||
def test_torch_meshgrid():
|
||||
# torch_meshgrid should not throw warning
|
||||
with pytest.warns(None) as record:
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6])
|
||||
grid_x, grid_y = torch_meshgrid(x, y)
|
||||
|
||||
assert len(record) == 0
|
Loading…
Reference in New Issue