diff --git a/mmocr/models/textrecog/decoders/aster_decoder.py b/mmocr/models/textrecog/decoders/aster_decoder.py index acd43d41..83e249b0 100644 --- a/mmocr/models/textrecog/decoders/aster_decoder.py +++ b/mmocr/models/textrecog/decoders/aster_decoder.py @@ -71,6 +71,7 @@ class ASTERDecoder(BaseDecoder): # Prediction layer self.fc = nn.Linear(hidden_size, self.dictionary.num_classes) + self.softmax = nn.Softmax(dim=-1) def _attention(self, feat: torch.Tensor, prev_hidden: torch.Tensor, prev_char: torch.Tensor @@ -177,4 +178,4 @@ class ASTERDecoder(BaseDecoder): outputs.append(output) _, predicted = output.max(-1) outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1) - return outputs + return self.softmax(outputs)