Mirror of https://github.com/open-mmlab/mmocr.git (synced 2025-06-03 21:54:47 +08:00)
[TODO] Fix score
Commit 7813e18a6c (parent e73665029b)
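The change repeated across the decoder hunks below is to register an `nn.Softmax(dim=-1)` module in each decoder's `__init__`, drop the per-step `F.softmax` calls inside the decoding loops, and apply softmax once to the final output, so `forward_test` returns character probabilities of shape `(N, T, C)` instead of raw logits, and the postprocessors receive `probs` rather than `outputs`. A minimal sketch of the pattern (illustrative names only, not the actual mmocr classes):

    import torch
    import torch.nn as nn


    class ToyRecogDecoder(nn.Module):
        """Illustrative decoder following the pattern applied in this commit."""

        def __init__(self, d_model: int = 512, num_classes: int = 37):
            super().__init__()
            self.classifier = nn.Linear(d_model, num_classes)
            # Softmax now lives in the decoder instead of the postprocessor.
            self.softmax = nn.Softmax(dim=-1)

        def forward_test(self, feat: torch.Tensor) -> torch.Tensor:
            logits = self.classifier(feat)  # (N, T, C) raw logits
            return self.softmax(logits)     # (N, T, C) character probabilities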
@@ -77,6 +77,7 @@ class ABIFuser(BaseDecoder):
             warnings.warn(f"Using max_seq_len {cfg['max_seq_len']} "
                           "in decoder's config.")
             setattr(self, cfg_name, MODELS.build(cfg))
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -138,7 +139,9 @@ class ABIFuser(BaseDecoder):
                 DataSample placeholder. Defaults to None.
 
         Returns:
-            torch.Tensor: Raw logits.
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         raw_result = self.forward_train(feat, logits, data_samples)
 
@@ -149,7 +152,7 @@ class ABIFuser(BaseDecoder):
         else:
            ret = raw_result['out_vis']['logits']
 
-        return ret
+        return self.softmax(ret)
 
     def fuse(self, l_feature: torch.Tensor, v_feature: torch.Tensor) -> Dict:
         """Mix and align visual feature and linguistic feature.
@@ -54,6 +54,7 @@ class CRNNDecoder(BaseDecoder):
             self.dictionary.num_classes,
             kernel_size=1,
             stride=1)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -101,7 +102,8 @@ class CRNNDecoder(BaseDecoder):
                 Defaults to None.
 
         Returns:
-            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
-            :math:`C` is ``num_classes``.
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
-        return self.forward_train(feat, out_enc, data_samples)
+        return self.softmax(self.forward_train(feat, out_enc, data_samples))
@@ -5,7 +5,6 @@ from typing import Dict, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList
 
@@ -148,6 +147,7 @@ class MasterDecoder(BaseDecoder):
         self.feat_positional_encoding = PositionalEncoding(
             d_hid=d_model, n_position=self.feat_size, dropout=feat_pe_drop)
         self.norm = nn.LayerNorm(d_model)
+        self.softmax = nn.Softmax(dim=-1)
 
     def make_target_mask(self, tgt: torch.Tensor,
                          device: torch.device) -> torch.Tensor:
@@ -248,8 +248,8 @@ class MasterDecoder(BaseDecoder):
             data_samples (list[TextRecogDataSample]): Unused.
 
         Returns:
-            Tensor: The raw logit tensor.
-            Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
             ``num_classes``.
         """
 
@@ -270,7 +270,6 @@ class MasterDecoder(BaseDecoder):
             target_mask = self.make_target_mask(input, device=feat.device)
             out = self.decode(input, feat, None, target_mask)
             output = out
-            prob = F.softmax(out, dim=-1)
-            _, next_word = torch.max(prob, dim=-1)
+            _, next_word = torch.max(out, dim=-1)
             input = torch.cat([input, next_word[:, -1].unsqueeze(-1)], dim=1)
-        return output
+        return self.softmax(output)
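Dropping the in-loop `F.softmax` before the `torch.max` is safe for greedy decoding: softmax is strictly increasing along the class dimension, so the argmax (and hence the next predicted token) is unchanged. A quick sanity check, purely illustrative:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 37)  # arbitrary (N, T, C) logits for the check
    assert torch.equal(logits.argmax(dim=-1),
                       F.softmax(logits, dim=-1).argmax(dim=-1))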
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.runner import ModuleList
 
 from mmocr.data import TextRecogDataSample
@@ -86,6 +85,7 @@ class NRTRDecoder(BaseDecoder):
 
         pred_num_class = self.dictionary.num_classes
         self.classifier = nn.Linear(d_model, pred_num_class)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
         """Generate mask for target sequence.
@@ -225,8 +225,8 @@ class NRTRDecoder(BaseDecoder):
                 information. Defaults to None.
 
         Returns:
-            Tensor: The raw logit tensor.
-            Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
             ``num_classes``.
         """
         valid_ratios = []
@@ -246,8 +246,7 @@ class NRTRDecoder(BaseDecoder):
             decoder_output = self._attention(
                 init_target_seq, out_enc, src_mask=src_mask)
             # bsz * seq_len * C
-            step_result = F.softmax(
-                self.classifier(decoder_output[:, step, :]), dim=-1)
+            step_result = self.classifier(decoder_output[:, step, :])
             # bsz * num_classes
             outputs.append(step_result)
             _, step_max_index = torch.max(step_result, dim=-1)
@@ -255,4 +254,4 @@ class NRTRDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, dim=1)
 
-        return outputs
+        return self.softmax(outputs)
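Moving the softmax from each step's result to the stacked output is also equivalent in value, because `nn.Softmax(dim=-1)` normalises every `(n, t)` slice independently: softmaxing per step and stacking gives the same tensor as stacking the raw logits and softmaxing once. An illustrative check:

    import torch
    import torch.nn as nn

    softmax = nn.Softmax(dim=-1)
    steps = [torch.randn(2, 37) for _ in range(5)]        # per-step logits, (N, C) each
    per_step = torch.stack([softmax(s) for s in steps], dim=1)
    stacked = softmax(torch.stack(steps, dim=1))          # (N, T, C)
    assert torch.allclose(per_step, stacked)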
@@ -83,6 +83,7 @@ class PositionAttentionDecoder(BaseDecoder):
             self.prediction = nn.Linear(
                 dim_model if encode_value else dim_input,
                 self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _get_position_index(self,
                             length: int,
@@ -174,7 +175,7 @@ class PositionAttentionDecoder(BaseDecoder):
                 to None.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)` if
+            Tensor: Character probabilities of shape :math:`(N, T, C)` if
             ``return_feature=False``. Otherwise it would be the hidden feature
             before the prediction projection layer, whose shape is
             :math:`(N, T, D_m)`.
@@ -216,4 +217,4 @@ class PositionAttentionDecoder(BaseDecoder):
         if self.return_feature:
             return attn_out
 
-        return self.prediction(attn_out)
+        return self.softmax(self.prediction(attn_out))
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.models.textrecog.dictionary import Dictionary
@@ -78,6 +77,7 @@ class RobustScannerFuser(BaseDecoder):
         self.glu_layer = nn.GLU(dim=dim)
         self.prediction = nn.Linear(
             int(in_channels / 2), self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -117,6 +117,11 @@ class RobustScannerFuser(BaseDecoder):
             data_samples (Sequence[TextRecogDataSample]): Batch of
                 TextRecogDataSample, containing vaild_ratio information.
                 Defaults to None.
+
+        Returns:
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         position_glimpse = self.position_decoder(feat, out_enc, data_samples)
 
@@ -133,10 +138,9 @@ class RobustScannerFuser(BaseDecoder):
             output = self.linear_layer(fusion_input)
             output = self.glu_layer(output)
             output = self.prediction(output)
-            output = F.softmax(output, -1)
             _, max_idx = torch.max(output, dim=1, keepdim=False)
             if step < self.max_seq_len - 1:
                 decode_sequence[:, step + 1] = max_idx
             outputs.append(output)
         outputs = torch.stack(outputs, 1)
-        return outputs
+        return self.softmax(outputs)
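For context on the fuser's layer sizes shown above: `nn.GLU(dim=dim)` splits its input into two halves along `dim` and gates one half with the sigmoid of the other, which is why `self.prediction` takes `int(in_channels / 2)` input features. A small illustration with an assumed channel count:

    import torch
    import torch.nn as nn

    glu = nn.GLU(dim=-1)
    x = torch.randn(2, 7, 512)   # e.g. fused features with 512 channels
    print(glu(x).shape)          # torch.Size([2, 7, 256]); the gated dim is halved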
@@ -115,6 +115,7 @@ class ParallelSARDecoder(BaseDecoder):
         else:
             fc_in_channel = d_model
         self.prediction = nn.Linear(fc_in_channel, self.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _2d_attention(self,
                       decoder_input: torch.Tensor,
@@ -239,7 +240,9 @@ class ParallelSARDecoder(BaseDecoder):
                 information. Defaults to None.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         if data_samples is not None:
             assert len(data_samples) == feat.size(0)
@@ -273,7 +276,6 @@ class ParallelSARDecoder(BaseDecoder):
             decoder_output = self._2d_attention(
                 decoder_input, feat, out_enc, valid_ratios=valid_ratios)
             char_output = decoder_output[:, i, :]  # bsz * num_classes
-            char_output = F.softmax(char_output, -1)
             outputs.append(char_output)
             _, max_idx = torch.max(char_output, dim=1, keepdim=False)
             char_embedding = self.embedding(max_idx)  # bsz * emb_dim
@@ -282,7 +284,7 @@ class ParallelSARDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes
 
-        return outputs
+        return self.softmax(outputs)
 
 
 @MODELS.register_module()
@@ -386,6 +388,7 @@ class SequentialSARDecoder(BaseDecoder):
         else:
             fc_in_channel = d_model
         self.prediction = nn.Linear(fc_in_channel, self.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _2d_attention(self,
                       y_prev: torch.Tensor,
@@ -525,7 +528,9 @@ class SequentialSARDecoder(BaseDecoder):
                 information.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         valid_ratios = None
         if data_samples is not None:
@@ -559,8 +564,6 @@ class SequentialSARDecoder(BaseDecoder):
                 hx2,
                 cx2,
                 valid_ratios=valid_ratios)
-
-            y = F.softmax(y, -1)
             _, max_idx = torch.max(y, dim=1, keepdim=False)
             char_embedding = self.embedding(max_idx)
             y_prev = char_embedding
@@ -568,4 +571,4 @@ class SequentialSARDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)
 
-        return outputs
+        return self.softmax(outputs)
@@ -4,7 +4,6 @@ from typing import Dict, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.models.textrecog.dictionary import Dictionary
@@ -92,6 +91,7 @@ class SequenceAttentionDecoder(BaseDecoder):
             self.prediction = nn.Linear(
                 dim_model if encode_value else dim_input,
                 self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -176,8 +176,9 @@ class SequenceAttentionDecoder(BaseDecoder):
                 to None.
 
         Returns:
-            Tensor: The output logit sequence tensor of shape
-            :math:`(N, T, C)`.
+            Tensor: Character probabilities. of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         seq_len = self.max_seq_len
         batch_size = feat.size(0)
@@ -196,7 +197,7 @@ class SequenceAttentionDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)
 
-        return outputs
+        return self.softmax(outputs)
 
     def forward_test_step(self, feat: torch.Tensor, out_enc: torch.Tensor,
                           decode_sequence: torch.Tensor, current_step: int,
@@ -255,6 +256,5 @@ class SequenceAttentionDecoder(BaseDecoder):
 
         if not self.return_feature:
             out = self.prediction(out)
-            out = F.softmax(out, dim=-1)
 
         return out
@@ -14,20 +14,22 @@ class AttentionPostprocessor(BaseTextRecogPostprocessor):
 
     def get_single_prediction(
         self,
-        output: torch.Tensor,
+        probs: torch.Tensor,
         data_sample: Optional[TextRecogDataSample] = None,
     ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample, optional): Datasample of an
                 image. Defaults to None.
 
         Returns:
             tuple(list[int], list[float]): index and score.
         """
-        max_value, max_idx = torch.max(output, -1)
+        max_value, max_idx = torch.max(probs, -1)
         index, score = [], []
         output_index = max_idx.cpu().detach().numpy().tolist()
         output_score = max_value.cpu().detach().numpy().tolist()
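Because the decoders now hand over softmaxed outputs, the `max_value` collected here is a per-character probability in [0, 1] that can be used directly as a confidence score. A hedged end-to-end sketch of what the postprocessor does per image (simplified; the real class also consults its dictionary and ignore indexes, and the character set below is a toy assumption):

    import torch

    def toy_single_prediction(probs: torch.Tensor):
        """probs: (T, C) character probabilities for one image."""
        max_value, max_idx = torch.max(probs, -1)
        index = max_idx.cpu().detach().numpy().tolist()
        score = max_value.cpu().detach().numpy().tolist()
        return index, score

    chars = list('0123456789abcdefghijklmnopqrstuvwxyz') + ['<PAD>']  # toy 37-class dictionary
    probs = torch.softmax(torch.randn(8, len(chars)), dim=-1)          # fake (T, C) probabilities
    index, score = toy_single_prediction(probs)
    print(''.join(chars[i] for i in index), score)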
@@ -65,13 +65,15 @@ class BaseTextRecogPostprocessor:
 
     def get_single_prediction(
         self,
-        output: torch.Tensor,
+        probs: torch.Tensor,
         data_sample: Optional[TextRecogDataSample] = None,
     ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample): Datasample of an image.
 
         Returns:
@@ -80,13 +82,13 @@ class BaseTextRecogPostprocessor:
         raise NotImplementedError
 
     def __call__(
-        self, outputs: torch.Tensor,
-        data_samples: Sequence[TextRecogDataSample]
+        self, probs: torch.Tensor, data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
         """Convert outputs to strings and scores.
 
         Args:
-            outputs (torch.Tensor): The model outputs in size: N * T * C
+            probs (torch.Tensor): Batched character probabilities, the model's
+                softmaxed output in size: :math:`(N, T, C)`.
             data_samples (list[TextRecogDataSample]): The list of
                 TextRecogDataSample.
 
@@ -94,10 +96,10 @@ class BaseTextRecogPostprocessor:
             list(TextRecogDataSample): The list of TextRecogDataSample. It
             usually contain ``pred_text`` information.
         """
-        batch_size = outputs.size(0)
+        batch_size = probs.size(0)
 
         for idx in range(batch_size):
-            index, score = self.get_single_prediction(outputs[idx, :, :],
+            index, score = self.get_single_prediction(probs[idx, :, :],
                                                       data_samples[idx])
             text = self.dictionary.idx2str(index)
             pred_text = LabelData()
@@ -3,7 +3,6 @@ import math
 from typing import Sequence, Tuple
 
 import torch
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.registry import MODELS
@@ -15,20 +14,22 @@ from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
 class CTCPostProcessor(BaseTextRecogPostprocessor):
     """PostProcessor for CTC."""
 
-    def get_single_prediction(self, output: torch.Tensor,
+    def get_single_prediction(self, probs: torch.Tensor,
                               data_sample: TextRecogDataSample
                               ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample): Datasample of an image.
 
         Returns:
             tuple(list[int], list[float]): index and score.
         """
-        feat_len = output.size(0)
-        max_value, max_idx = torch.max(output, -1)
+        feat_len = probs.size(0)
+        max_value, max_idx = torch.max(probs, -1)
         valid_ratio = data_sample.get('valid_ratio', 1)
         decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
         index = []
@@ -47,7 +48,5 @@ class CTCPostProcessor(BaseTextRecogPostprocessor):
             self, outputs: torch.Tensor,
             data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
-        # TODO move to decoder
-        outputs = F.softmax(outputs, dim=2)
         outputs = outputs.cpu().detach()
         return super().__call__(outputs, data_samples)
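The CTC postprocessor's `get_single_prediction` (only partially shown above) goes on to collapse the framewise argmax sequence in the usual CTC fashion: repeated symbols are merged and blank frames are dropped within the first `decode_len` frames. A generic sketch of that collapse rule, with the blank index as an explicit assumption (which index serves as blank depends on the dictionary configuration):

    import math
    import torch

    def ctc_greedy_collapse(probs: torch.Tensor, blank: int, valid_ratio: float = 1.0):
        """probs: (T, C) framewise character probabilities; returns indexes and scores."""
        feat_len = probs.size(0)
        decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
        max_value, max_idx = torch.max(probs, -1)
        index, score = [], []
        prev = blank
        for t in range(decode_len):
            cur = int(max_idx[t])
            if cur != blank and cur != prev:  # drop blanks, merge repeats
                index.append(cur)
                score.append(float(max_value[t]))
            prev = cur
        return index, score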