diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py index 12567d25..31bc9b79 100644 --- a/mmocr/models/textrecog/losses/ctc_loss.py +++ b/mmocr/models/textrecog/losses/ctc_loss.py @@ -72,7 +72,10 @@ class CTCLoss(BaseRecogLoss): outputs = torch.log_softmax(outputs, dim=2) bsz, seq_len = outputs.size(0), outputs.size(1) outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C - targets = [data_sample.gt_text.indexes for data_sample in data_samples] + targets = [ + data_sample.gt_text.indexes[:seq_len] + for data_sample in data_samples + ] target_lengths = torch.IntTensor([len(t) for t in targets]) target_lengths = torch.clamp(target_lengths, min=1, max=seq_len).long() input_lengths = torch.full(