[Bug] Bug generated during kie inference visualization (#1830)

* Update kie_visualizer.py

* Update kie_visualizer.py

* Update kie_visualizer.py
This commit is contained in:
YangLy 2023-04-03 17:39:02 +08:00 committed by GitHub
parent 4842599191
commit e6174b29fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -70,6 +70,8 @@ class KIELocalVisualizer(BaseLocalVisualizer):
np.ndarray: The image with edge labels drawn.
"""
pairs = np.where(edge_labels > 0)
if torch.is_tensor(pairs):
pairs = pairs.cpu()
key_bboxes = bboxes[pairs[0]]
value_bboxes = bboxes[pairs[1]]
x_data = np.stack([(key_bboxes[:, 2] + key_bboxes[:, 0]) / 2,