diff --git a/mmocr/visualization/kie_visualizer.py b/mmocr/visualization/kie_visualizer.py index 34aa4eca..753bac2e 100644 --- a/mmocr/visualization/kie_visualizer.py +++ b/mmocr/visualization/kie_visualizer.py @@ -70,6 +70,8 @@ class KIELocalVisualizer(BaseLocalVisualizer): np.ndarray: The image with edge labels drawn. """ pairs = np.where(edge_labels > 0) + if torch.is_tensor(pairs): + pairs = pairs.cpu() key_bboxes = bboxes[pairs[0]] value_bboxes = bboxes[pairs[1]] x_data = np.stack([(key_bboxes[:, 2] + key_bboxes[:, 0]) / 2,