2022-07-07 16:05:49 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from mmcv.device import get_device
|
2022-09-30 21:05:37 +08:00
|
|
|
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
|
|
|
|
IS_NPU_AVAILABLE)
|
2022-07-07 16:05:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_device():
    """Verify ``get_device`` reports the highest-priority available backend.

    Backends are checked in the order NPU, CUDA, MLU, MPS. The first one
    whose availability flag is set must be the reported device. When none
    of the flags is set, ``'cpu'`` is expected.
    """
    reported = get_device()

    # Ordered from highest to lowest priority; the first available wins.
    backend_priority = (
        (IS_NPU_AVAILABLE, 'npu'),
        (IS_CUDA_AVAILABLE, 'cuda'),
        (IS_MLU_AVAILABLE, 'mlu'),
        (IS_MPS_AVAILABLE, 'mps'),
    )
    expected = next(
        (name for is_available, name in backend_priority if is_available),
        'cpu',
    )

    assert reported == expected
|