[TODO] Fix score

jiangqing.vendor 2022-07-13 11:42:29 +00:00 committed by gaotongxiao
parent e73665029b
commit 7813e18a6c
11 changed files with 68 additions and 54 deletions
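Taken together, the changes below move the final softmax out of the postprocessors and into the decoders: each decoder registers an ``nn.Softmax(dim=-1)`` module in ``__init__`` and applies it to its raw logits in ``forward_test``, the CTC postprocessor drops its in-place ``F.softmax`` TODO, and every postprocessor now receives ready-made character probabilities. A minimal sketch of the new contract, with a hypothetical ``DummyDecoder`` standing in for the real classes (not part of the diff):

import torch
import torch.nn as nn

# Illustrative stand-in for the decoders touched by this commit: the
# prediction layer emits raw logits and forward_test() applies the
# softmax module registered in __init__ before returning.
class DummyDecoder(nn.Module):

    def __init__(self, d_model: int = 512, num_classes: int = 37):
        super().__init__()
        self.classifier = nn.Linear(d_model, num_classes)
        self.softmax = nn.Softmax(dim=-1)  # the module each decoder adds

    def forward_test(self, feat: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(feat)  # (N, T, C) raw logits
        return self.softmax(logits)  # (N, T, C) probabilities

decoder = DummyDecoder()
probs = decoder.forward_test(torch.randn(2, 40, 512))
# Every step's class distribution sums to one, so postprocessors can
# read confidence scores directly off the tensor.
assert torch.allclose(probs.sum(-1), torch.ones(2, 40))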

View File

@@ -77,6 +77,7 @@ class ABIFuser(BaseDecoder):
                 warnings.warn(f"Using max_seq_len {cfg['max_seq_len']} "
                               "in decoder's config.")
             setattr(self, cfg_name, MODELS.build(cfg))
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -138,7 +139,9 @@ class ABIFuser(BaseDecoder):
                 DataSample placeholder. Defaults to None.
 
         Returns:
-            torch.Tensor: Raw logits.
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         raw_result = self.forward_train(feat, logits, data_samples)
@@ -149,7 +152,7 @@ class ABIFuser(BaseDecoder):
         else:
             ret = raw_result['out_vis']['logits']
 
-        return ret
+        return self.softmax(ret)
 
     def fuse(self, l_feature: torch.Tensor, v_feature: torch.Tensor) -> Dict:
         """Mix and align visual feature and linguistic feature.

View File

@@ -54,6 +54,7 @@ class CRNNDecoder(BaseDecoder):
             self.dictionary.num_classes,
             kernel_size=1,
             stride=1)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -101,7 +102,8 @@ class CRNNDecoder(BaseDecoder):
                 Defaults to None.
 
         Returns:
-            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
-            :math:`C` is ``num_classes``.
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
-        return self.forward_train(feat, out_enc, data_samples)
+        return self.softmax(self.forward_train(feat, out_enc, data_samples))

View File

@@ -5,7 +5,6 @@ from typing import Dict, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList
 
@@ -148,6 +147,7 @@ class MasterDecoder(BaseDecoder):
         self.feat_positional_encoding = PositionalEncoding(
             d_hid=d_model, n_position=self.feat_size, dropout=feat_pe_drop)
         self.norm = nn.LayerNorm(d_model)
+        self.softmax = nn.Softmax(dim=-1)
 
     def make_target_mask(self, tgt: torch.Tensor,
                          device: torch.device) -> torch.Tensor:
@@ -248,8 +248,8 @@ class MasterDecoder(BaseDecoder):
             data_samples (list[TextRecogDataSample]): Unused.
 
         Returns:
-            Tensor: The raw logit tensor.
-            Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
             ``num_classes``.
         """
@@ -270,7 +270,6 @@ class MasterDecoder(BaseDecoder):
             target_mask = self.make_target_mask(input, device=feat.device)
             out = self.decode(input, feat, None, target_mask)
             output = out
-            prob = F.softmax(out, dim=-1)
-            _, next_word = torch.max(prob, dim=-1)
+            _, next_word = torch.max(out, dim=-1)
             input = torch.cat([input, next_word[:, -1].unsqueeze(-1)], dim=1)
-        return output
+        return self.softmax(output)
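Dropping the in-loop ``F.softmax`` before ``torch.max`` above is safe because softmax is strictly increasing in each logit, so the greedy argmax over the class dimension is identical on raw logits and on probabilities. A quick check of that equivalence:

import torch

# Argmax is invariant under softmax: the (N, seq_len) index tensors match.
out = torch.randn(2, 5, 37)  # stand-in for the decoder's raw logits
_, next_word_logits = torch.max(out, dim=-1)
_, next_word_probs = torch.max(torch.softmax(out, dim=-1), dim=-1)
assert torch.equal(next_word_logits, next_word_probs)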

View File

@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.runner import ModuleList
 
 from mmocr.data import TextRecogDataSample
@@ -86,6 +85,7 @@ class NRTRDecoder(BaseDecoder):
         pred_num_class = self.dictionary.num_classes
         self.classifier = nn.Linear(d_model, pred_num_class)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
         """Generate mask for target sequence.
@@ -225,8 +225,8 @@ class NRTRDecoder(BaseDecoder):
                 information. Defaults to None.
 
         Returns:
-            Tensor: The raw logit tensor.
-            Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
             ``num_classes``.
         """
         valid_ratios = []
@@ -246,8 +246,7 @@ class NRTRDecoder(BaseDecoder):
             decoder_output = self._attention(
                 init_target_seq, out_enc, src_mask=src_mask)
             # bsz * seq_len * C
-            step_result = F.softmax(
-                self.classifier(decoder_output[:, step, :]), dim=-1)
+            step_result = self.classifier(decoder_output[:, step, :])
             # bsz * num_classes
             outputs.append(step_result)
             _, step_max_index = torch.max(step_result, dim=-1)
@@ -255,4 +254,4 @@ class NRTRDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, dim=1)
-        return outputs
+        return self.softmax(outputs)
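The accumulation change is likewise lossless: ``nn.Softmax(dim=-1)`` normalizes each (batch, step) slice independently, so softmaxing once after ``torch.stack(outputs, dim=1)`` matches the old per-step normalization exactly. A small sketch:

import torch
import torch.nn.functional as F

# Softmax over the class dim commutes with stacking along the step dim.
steps = [torch.randn(4, 37) for _ in range(6)]  # per-step raw logits
old = torch.stack([F.softmax(s, dim=-1) for s in steps], dim=1)  # old path
new = F.softmax(torch.stack(steps, dim=1), dim=-1)  # new path
assert torch.allclose(old, new)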

View File

@@ -83,6 +83,7 @@ class PositionAttentionDecoder(BaseDecoder):
         self.prediction = nn.Linear(
             dim_model if encode_value else dim_input,
             self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _get_position_index(self,
                             length: int,
@@ -174,7 +175,7 @@ class PositionAttentionDecoder(BaseDecoder):
                 to None.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)` if
+            Tensor: Character probabilities of shape :math:`(N, T, C)` if
             ``return_feature=False``. Otherwise it would be the hidden feature
             before the prediction projection layer, whose shape is
             :math:`(N, T, D_m)`.
@@ -216,4 +217,4 @@ class PositionAttentionDecoder(BaseDecoder):
         if self.return_feature:
             return attn_out
 
-        return self.prediction(attn_out)
+        return self.softmax(self.prediction(attn_out))

View File

@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.models.textrecog.dictionary import Dictionary
@@ -78,6 +77,7 @@ class RobustScannerFuser(BaseDecoder):
         self.glu_layer = nn.GLU(dim=dim)
         self.prediction = nn.Linear(
             int(in_channels / 2), self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -117,6 +117,11 @@ class RobustScannerFuser(BaseDecoder):
             data_samples (Sequence[TextRecogDataSample]): Batch of
                 TextRecogDataSample, containing valid_ratio information.
                 Defaults to None.
+
+        Returns:
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         position_glimpse = self.position_decoder(feat, out_enc, data_samples)
@@ -133,10 +138,9 @@ class RobustScannerFuser(BaseDecoder):
             output = self.linear_layer(fusion_input)
             output = self.glu_layer(output)
             output = self.prediction(output)
-            output = F.softmax(output, -1)
             _, max_idx = torch.max(output, dim=1, keepdim=False)
             if step < self.max_seq_len - 1:
                 decode_sequence[:, step + 1] = max_idx
             outputs.append(output)
 
         outputs = torch.stack(outputs, 1)
-        return outputs
+        return self.softmax(outputs)

View File

@@ -115,6 +115,7 @@ class ParallelSARDecoder(BaseDecoder):
         else:
             fc_in_channel = d_model
         self.prediction = nn.Linear(fc_in_channel, self.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _2d_attention(self,
                       decoder_input: torch.Tensor,
@@ -239,7 +240,9 @@ class ParallelSARDecoder(BaseDecoder):
                 information. Defaults to None.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         if data_samples is not None:
             assert len(data_samples) == feat.size(0)
@@ -273,7 +276,6 @@ class ParallelSARDecoder(BaseDecoder):
             decoder_output = self._2d_attention(
                 decoder_input, feat, out_enc, valid_ratios=valid_ratios)
             char_output = decoder_output[:, i, :]  # bsz * num_classes
-            char_output = F.softmax(char_output, -1)
             outputs.append(char_output)
             _, max_idx = torch.max(char_output, dim=1, keepdim=False)
             char_embedding = self.embedding(max_idx)  # bsz * emb_dim
@@ -282,7 +284,7 @@ class ParallelSARDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes
 
-        return outputs
+        return self.softmax(outputs)
 
 
 @MODELS.register_module()
@@ -386,6 +388,7 @@ class SequentialSARDecoder(BaseDecoder):
         else:
             fc_in_channel = d_model
         self.prediction = nn.Linear(fc_in_channel, self.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def _2d_attention(self,
                       y_prev: torch.Tensor,
@@ -525,7 +528,9 @@ class SequentialSARDecoder(BaseDecoder):
                 information.
 
         Returns:
-            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         valid_ratios = None
         if data_samples is not None:
@@ -559,8 +564,6 @@ class SequentialSARDecoder(BaseDecoder):
                                     hx2,
                                     cx2,
                                     valid_ratios=valid_ratios)
-            y = F.softmax(y, -1)
             _, max_idx = torch.max(y, dim=1, keepdim=False)
             char_embedding = self.embedding(max_idx)
             y_prev = char_embedding
@@ -568,4 +571,4 @@ class SequentialSARDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)
-        return outputs
+        return self.softmax(outputs)

View File

@@ -4,7 +4,6 @@ from typing import Dict, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.models.textrecog.dictionary import Dictionary
@@ -92,6 +91,7 @@ class SequenceAttentionDecoder(BaseDecoder):
         self.prediction = nn.Linear(
             dim_model if encode_value else dim_input,
             self.dictionary.num_classes)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward_train(
         self,
@@ -176,8 +176,9 @@ class SequenceAttentionDecoder(BaseDecoder):
                 to None.
 
         Returns:
-            Tensor: The output logit sequence tensor of shape
-            :math:`(N, T, C)`.
+            Tensor: Character probabilities of shape
+            :math:`(N, self.max_seq_len, C)` where :math:`C` is
+            ``num_classes``.
         """
         seq_len = self.max_seq_len
         batch_size = feat.size(0)
@@ -196,7 +197,7 @@ class SequenceAttentionDecoder(BaseDecoder):
 
         outputs = torch.stack(outputs, 1)
 
-        return outputs
+        return self.softmax(outputs)
 
     def forward_test_step(self, feat: torch.Tensor, out_enc: torch.Tensor,
                           decode_sequence: torch.Tensor, current_step: int,
@@ -255,6 +256,5 @@ class SequenceAttentionDecoder(BaseDecoder):
 
         if not self.return_feature:
             out = self.prediction(out)
-            out = F.softmax(out, dim=-1)
 
         return out

View File

@@ -14,20 +14,22 @@ class AttentionPostprocessor(BaseTextRecogPostprocessor):
 
     def get_single_prediction(
         self,
-        output: torch.Tensor,
+        probs: torch.Tensor,
         data_sample: Optional[TextRecogDataSample] = None,
     ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample, optional): Datasample of an
                 image. Defaults to None.
 
         Returns:
             tuple(list[int], list[float]): index and score.
         """
-        max_value, max_idx = torch.max(output, -1)
+        max_value, max_idx = torch.max(probs, -1)
         index, score = [], []
         output_index = max_idx.cpu().detach().numpy().tolist()
         output_score = max_value.cpu().detach().numpy().tolist()
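Since the input is now pre-normalized, the values returned by ``torch.max`` double as confidence scores in [0, 1], which appears to be the scoring fix the commit title refers to. A hedged usage sketch of the shape contract, with synthetic tensors:

import torch

# Shape contract of get_single_prediction after the rename: probs is a
# (T, C) tensor of per-step character probabilities.
probs = torch.softmax(torch.randn(40, 37), dim=-1)
max_value, max_idx = torch.max(probs, -1)
index = max_idx.cpu().detach().numpy().tolist()  # T character indices
score = max_value.cpu().detach().numpy().tolist()  # T confidences
assert all(0.0 <= s <= 1.0 for s in score)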

View File

@@ -65,13 +65,15 @@ class BaseTextRecogPostprocessor:
 
     def get_single_prediction(
         self,
-        output: torch.Tensor,
+        probs: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
     ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample): Datasample of an image.
 
         Returns:
@@ -80,13 +82,13 @@ class BaseTextRecogPostprocessor:
         raise NotImplementedError
 
     def __call__(
-        self, outputs: torch.Tensor,
-        data_samples: Sequence[TextRecogDataSample]
+        self, probs: torch.Tensor, data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
         """Convert outputs to strings and scores.
 
         Args:
-            outputs (torch.Tensor): The model outputs in size: N * T * C
+            probs (torch.Tensor): Batched character probabilities, the model's
+                softmaxed output in size: :math:`(N, T, C)`.
             data_samples (list[TextRecogDataSample]): The list of
                 TextRecogDataSample.
 
@@ -94,10 +96,10 @@ class BaseTextRecogPostprocessor:
             list(TextRecogDataSample): The list of TextRecogDataSample. It
             usually contains ``pred_text`` information.
         """
-        batch_size = outputs.size(0)
+        batch_size = probs.size(0)
 
         for idx in range(batch_size):
-            index, score = self.get_single_prediction(outputs[idx, :, :],
+            index, score = self.get_single_prediction(probs[idx, :, :],
                                                       data_samples[idx])
             text = self.dictionary.idx2str(index)
             pred_text = LabelData()

View File

@@ -3,7 +3,6 @@ import math
 from typing import Sequence, Tuple
 
 import torch
-import torch.nn.functional as F
 
 from mmocr.data import TextRecogDataSample
 from mmocr.registry import MODELS
@@ -15,20 +14,22 @@ from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
 class CTCPostProcessor(BaseTextRecogPostprocessor):
     """PostProcessor for CTC."""
 
-    def get_single_prediction(self, output: torch.Tensor,
+    def get_single_prediction(self, probs: torch.Tensor,
                               data_sample: TextRecogDataSample
                               ) -> Tuple[Sequence[int], Sequence[float]]:
-        """Convert the output of a single image to index and score.
+        """Convert the output probabilities of a single image to index and
+        score.
 
         Args:
-            output (torch.Tensor): Single image output.
+            probs (torch.Tensor): Character probabilities with shape
+                :math:`(T, C)`.
             data_sample (TextRecogDataSample): Datasample of an image.
 
         Returns:
             tuple(list[int], list[float]): index and score.
         """
-        feat_len = output.size(0)
-        max_value, max_idx = torch.max(output, -1)
+        feat_len = probs.size(0)
+        max_value, max_idx = torch.max(probs, -1)
         valid_ratio = data_sample.get('valid_ratio', 1)
         decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
         index = []
@@ -47,7 +48,5 @@ class CTCPostProcessor(BaseTextRecogPostprocessor):
         self, outputs: torch.Tensor,
         data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
-        # TODO move to decoder
-        outputs = F.softmax(outputs, dim=2)
         outputs = outputs.cpu().detach()
         return super().__call__(outputs, data_samples)
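The loop that fills ``index`` and ``score`` sits below this hunk and is not shown. As a sketch only, under standard greedy-CTC assumptions (the ``greedy_ctc`` helper and its ``blank`` default are illustrative, not from the diff), repeats are collapsed and blanks dropped, all computed on the probabilities the decoder now supplies:

import math
import torch

def greedy_ctc(probs, blank=0, valid_ratio=1.0):
    """Greedy CTC decode on (T, C) probabilities: collapse repeats,
    drop the blank label, keep the max probabilities as scores."""
    feat_len = probs.size(0)
    decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
    max_value, max_idx = torch.max(probs, -1)
    index, score, prev = [], [], blank
    for t in range(decode_len):
        cur = int(max_idx[t])
        if cur != blank and cur != prev:
            index.append(cur)
            score.append(float(max_value[t]))
        prev = cur
    return index, score

index, score = greedy_ctc(torch.softmax(torch.randn(40, 37), dim=-1))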