# mmengine/tests/test_device/test_device.py
# (16 lines, 445 B, Python)

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
def test_get_device():
    """Check that ``get_device`` reports the expected backend name.

    The test expects accelerators to take precedence in the order CUDA,
    MLU, then MPS. The first backend whose availability helper returns a
    truthy value is the expected result. When none is available, the
    expected device is ``'cpu'``.
    """
    actual = get_device()

    # Probe the backends in precedence order. ``next`` stops at the first
    # probe that reports availability, so later probes are not called.
    probes = (
        (is_cuda_available, 'cuda'),
        (is_mlu_available, 'mlu'),
        (is_mps_available, 'mps'),
    )
    expected = next((name for probe, name in probes if probe()), 'cpu')

    assert actual == expected