mmengine/tests/test_device/test_device.py
Alex Yang 2994195be2
[Feat] Support training on MPS (#331)
* [Feat] Support mps

* fix docstring
2022-06-23 16:53:19 +08:00

16 lines
445 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
def test_get_device():
    """Check that ``get_device`` reports the first available backend.

    The expected value is derived from the availability helpers in
    priority order: ``'cuda'``, then ``'mlu'``, then ``'mps'``. When none
    of them reports an available backend the expected value is ``'cpu'``.
    """
    # Query the device first, before any availability helper is invoked.
    actual = get_device()

    # (availability check, expected device string) pairs, highest priority
    # first. The checks are stored uncalled so they are evaluated lazily.
    backend_priority = (
        (is_cuda_available, 'cuda'),
        (is_mlu_available, 'mlu'),
        (is_mps_available, 'mps'),
    )

    # ``next`` stops at the first check that returns a truthy value, so
    # lower-priority checks are not called once a backend has matched.
    expected = next(
        (name for is_available, name in backend_priority if is_available()),
        'cpu',
    )

    assert actual == expected