This commit is contained in:
xinke-wang 2022-06-23 17:28:48 +08:00 committed by gaotongxiao
parent 7be4dc1bca
commit fe43b4e767

View File

@ -72,7 +72,10 @@ class CTCLoss(BaseRecogLoss):
outputs = torch.log_softmax(outputs, dim=2)
bsz, seq_len = outputs.size(0), outputs.size(1)
outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C
targets = [data_sample.gt_text.indexes for data_sample in data_samples]
targets = [
data_sample.gt_text.indexes[:seq_len]
for data_sample in data_samples
]
target_lengths = torch.IntTensor([len(t) for t in targets])
target_lengths = torch.clamp(target_lengths, min=1, max=seq_len).long()
input_lengths = torch.full(