mmcv/tests/test_device/test_device_utils.py
wangjiangben-hw a5db5f6682
[Feature] Support training on NPU device (#2262)
* init npu

* add npu extension and focal loss adapter

* clean code

* fix autocast bugs on npu (#2273)

* code format

* bug fix

* pytorch_npu_helper.hpp clean code

* Npu dev (#2306)

* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather

* raise ImportError when compile with npu

* add npu test case (#2307)

* add npu test case

* Update focal_loss.py

* add comment

* clean lint

* update dtype assert

* update DDP forward and comment

* fix bug

Co-authored-by: ckirchhoff <515629648@qq.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2022-09-30 21:05:37 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
                        IS_NPU_AVAILABLE)


def test_get_device():
    current_device = get_device()
    # The branch order encodes the priority this test expects:
    # npu, then cuda, then mlu, then mps, with cpu as the fallback.
    if IS_NPU_AVAILABLE:
        assert current_device == 'npu'
    elif IS_CUDA_AVAILABLE:
        assert current_device == 'cuda'
    elif IS_MLU_AVAILABLE:
        assert current_device == 'mlu'
    elif IS_MPS_AVAILABLE:
        assert current_device == 'mps'
    else:
        assert current_device == 'cpu'
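
The test above only exercises whichever backend happens to be present on the machine running it. As a rough illustration of the priority order it asserts, the sketch below uses a hypothetical stand-alone helper, `pick_device`, which is not part of mmcv and is not claimed to match how `get_device` is implemented. It takes the availability flags as plain arguments so that every branch can be checked without the hardware.

```python
# Hypothetical helper for illustration only; it is NOT mmcv's get_device.
# It mirrors the priority that test_get_device asserts:
# npu > cuda > mlu > mps > cpu.
def pick_device(npu=False, cuda=False, mlu=False, mps=False):
    """Return the name of the first available backend, else 'cpu'."""
    candidates = (('npu', npu), ('cuda', cuda), ('mlu', mlu), ('mps', mps))
    for name, available in candidates:
        if available:
            return name
    return 'cpu'


def test_pick_device_priority():
    # Higher-priority backends win when several are available.
    assert pick_device(npu=True, cuda=True) == 'npu'
    assert pick_device(cuda=True, mlu=True, mps=True) == 'cuda'
    assert pick_device(mlu=True, mps=True) == 'mlu'
    assert pick_device(mps=True) == 'mps'
    # With nothing available, the fallback is the CPU.
    assert pick_device() == 'cpu'
```

Passing the flags in explicitly is a design choice for this sketch: it keeps the example independent of module-level constants such as `IS_NPU_AVAILABLE`, which are fixed at import time in the test above.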