diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py index 93e36579..d99dbe69 100755 --- a/mmocr/utils/ocr.py +++ b/mmocr/utils/ocr.py @@ -329,8 +329,8 @@ class MMOCR: self.kie = kie self.device = device if self.device is None: - self.device = torch.cuda.current_device() \ - if torch.cuda.is_available() else 'cpu' + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') # Check if the det/recog model choice is valid if self.td and self.td not in textdet_models: diff --git a/tests/test_utils/test_ocr.py b/tests/test_utils/test_ocr.py index 9aedf95a..c2332abe 100644 --- a/tests/test_utils/test_ocr.py +++ b/tests/test_utils/test_ocr.py @@ -82,6 +82,8 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector, def loadcheckpoint_assert(*args, **kwargs): assert args[1] == gt_ckpt[-1] + assert kwargs['map_location'] == torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') mock_loading.side_effect = loadcheckpoint_assert with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'): @@ -106,8 +108,7 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector, mock_config.assert_called_with(gt_cfg[-1]) mock_build_detector.assert_called_once() mock_loading.assert_called_once() - device = torch.cuda.current_device() if \ - torch.cuda.is_available() else 'cpu' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') calls = [ mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range ] @@ -168,8 +169,7 @@ def test_ocr_init_customize_config(mock_loading, mock_config, mock_config.assert_called_with(gt_cfg[-1]) mock_build_detector.assert_called_once() mock_loading.assert_called_once() - device = torch.cuda.current_device() if \ - torch.cuda.is_available() else 'cpu' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') calls = [ mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range ]