29 lines
831 B
Python
29 lines
831 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
from unittest.mock import patch
|
|
|
|
import mmcv
|
|
|
|
from mmcls.utils import auto_select_device
|
|
|
|
|
|
class TestAutoSelectDevice(TestCase):
|
|
|
|
@patch.object(mmcv, '__version__', '1.6.0')
|
|
@patch('mmcv.device.get_device', create=True)
|
|
def test_mmcv(self, mock):
|
|
auto_select_device()
|
|
mock.assert_called_once()
|
|
|
|
@patch.object(mmcv, '__version__', '1.5.0')
|
|
@patch('torch.cuda.is_available', return_value=True)
|
|
def test_cuda(self, mock):
|
|
device = auto_select_device()
|
|
self.assertEqual(device, 'cuda')
|
|
|
|
@patch.object(mmcv, '__version__', '1.5.0')
|
|
@patch('torch.cuda.is_available', return_value=False)
|
|
def test_cpu(self, mock):
|
|
device = auto_select_device()
|
|
self.assertEqual(device, 'cpu')
|