mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Add missing softmax in ASTER forward_test (#1718)
* add missing softmax * updatepull/1701/head
parent
b3be8cfbb3
commit
332089ca11
|
@ -71,6 +71,7 @@ class ASTERDecoder(BaseDecoder):
|
|||
|
||||
# Prediction layer
|
||||
self.fc = nn.Linear(hidden_size, self.dictionary.num_classes)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _attention(self, feat: torch.Tensor, prev_hidden: torch.Tensor,
|
||||
prev_char: torch.Tensor
|
||||
|
@ -177,4 +178,4 @@ class ASTERDecoder(BaseDecoder):
|
|||
outputs.append(output)
|
||||
_, predicted = output.max(-1)
|
||||
outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
||||
return outputs
|
||||
return self.softmax(outputs)
|
||||
|
|
Loading…
Reference in New Issue