[Fix] Add missing softmax in ASTER forward_test ()

* add missing softmax

* update
pull/1701/head
Qing Jiang 2023-02-13 10:32:55 +08:00 committed by GitHub
parent b3be8cfbb3
commit 332089ca11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions
mmocr/models/textrecog/decoders

View File

@ -71,6 +71,7 @@ class ASTERDecoder(BaseDecoder):
# Prediction layer
self.fc = nn.Linear(hidden_size, self.dictionary.num_classes)
self.softmax = nn.Softmax(dim=-1)
def _attention(self, feat: torch.Tensor, prev_hidden: torch.Tensor,
prev_char: torch.Tensor
@ -177,4 +178,4 @@ class ASTERDecoder(BaseDecoder):
outputs.append(output)
_, predicted = output.max(-1)
outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
return outputs
return self.softmax(outputs)