fix shape unmatch error

pull/6781/head
WenmuZhou 2022-07-04 04:26:54 +00:00
parent 8d84144029
commit b541cc12be
1 changed files with 2 additions and 0 deletions

View File

@ -58,6 +58,8 @@ class PSEPostProcess(object):
kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :]
text_mask = paddle.unsqueeze(text_mask, axis=1)
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy()