mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] Standardize the type of torch.device in ocr.py (#800)
This commit is contained in:
parent
ac4462f374
commit
fb77352eb2
@ -329,8 +329,8 @@ class MMOCR:
|
||||
self.kie = kie
|
||||
self.device = device
|
||||
if self.device is None:
|
||||
self.device = torch.cuda.current_device() \
|
||||
if torch.cuda.is_available() else 'cpu'
|
||||
self.device = torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Check if the det/recog model choice is valid
|
||||
if self.td and self.td not in textdet_models:
|
||||
|
@ -82,6 +82,8 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector,
|
||||
|
||||
def loadcheckpoint_assert(*args, **kwargs):
|
||||
assert args[1] == gt_ckpt[-1]
|
||||
assert kwargs['map_location'] == torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
mock_loading.side_effect = loadcheckpoint_assert
|
||||
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
|
||||
@ -106,8 +108,7 @@ def test_ocr_init(mock_loading, mock_config, mock_build_detector,
|
||||
mock_config.assert_called_with(gt_cfg[-1])
|
||||
mock_build_detector.assert_called_once()
|
||||
mock_loading.assert_called_once()
|
||||
device = torch.cuda.current_device() if \
|
||||
torch.cuda.is_available() else 'cpu'
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
calls = [
|
||||
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
|
||||
]
|
||||
@ -168,8 +169,7 @@ def test_ocr_init_customize_config(mock_loading, mock_config,
|
||||
mock_config.assert_called_with(gt_cfg[-1])
|
||||
mock_build_detector.assert_called_once()
|
||||
mock_loading.assert_called_once()
|
||||
device = torch.cuda.current_device() if \
|
||||
torch.cuda.is_available() else 'cpu'
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
calls = [
|
||||
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user