[Fix] Standardize the type of torch.device in ocr.py (#800)

This commit is contained in:
Tong Gao 2022-03-03 14:18:33 +08:00 committed by GitHub
parent ac4462f374
commit fb77352eb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View File

@ -329,8 +329,8 @@ class MMOCR:
self.kie = kie
self.device = device
if self.device is None:
self.device = torch.cuda.current_device() \
if torch.cuda.is_available() else 'cpu'
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
# Check if the det/recog model choice is valid
if self.td and self.td not in textdet_models:

View File

@ -82,6 +82,8 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector,
def loadcheckpoint_assert(*args, **kwargs):
assert args[1] == gt_ckpt[-1]
assert kwargs['map_location'] == torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
mock_loading.side_effect = loadcheckpoint_assert
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
@ -106,8 +108,7 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector,
mock_config.assert_called_with(gt_cfg[-1])
mock_build_detector.assert_called_once()
mock_loading.assert_called_once()
device = torch.cuda.current_device() if \
torch.cuda.is_available() else 'cpu'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
calls = [
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
]
@ -168,8 +169,7 @@ def test_ocr_init_customize_config(mock_loading, mock_config,
mock_config.assert_called_with(gt_cfg[-1])
mock_build_detector.assert_called_once()
mock_loading.assert_called_once()
device = torch.cuda.current_device() if \
torch.cuda.is_available() else 'cpu'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
calls = [
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
]