mmcv/tests/test_device/test_device_utils.py
Zaida Zhou 6a03918f55
[Feature] Add support for mps (#2092)
* [Feature] Add support for MPS

* fix import error

* update unit tests

* fix error

* trigger CI

* use a unique basename for test file modules

* avoid breaking backward compatibility
2022-07-07 16:05:49 +08:00

16 lines
466 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def test_get_device():
    """Check that ``get_device`` reports the expected backend name.

    The expected device is the first available backend in the order
    CUDA, MLU, MPS, falling back to ``'cpu'`` when none of the
    ``IS_*_AVAILABLE`` flags is set.
    """
    # Ordered by precedence: the first backend whose flag is truthy wins.
    backend_priority = (
        (IS_CUDA_AVAILABLE, 'cuda'),
        (IS_MLU_AVAILABLE, 'mlu'),
        (IS_MPS_AVAILABLE, 'mps'),
    )
    expected_device = next(
        (name for is_available, name in backend_priority if is_available),
        'cpu')
    assert get_device() == expected_device