diff --git a/mmocr/models/textrecog/losses/base_recog_loss.py b/mmocr/models/textrecog/losses/base_recog_loss.py
index b38246dd..2cc81d89 100644
--- a/mmocr/models/textrecog/losses/base_recog_loss.py
+++ b/mmocr/models/textrecog/losses/base_recog_loss.py
@@ -20,7 +20,6 @@ class BaseRecogLoss(nn.Module):
             generated from decoder. Defaults to 40.
         letter_case (str): There are three options to alter the letter cases
             of gt texts:
-
             - unchanged: Do not change gt texts.
             - upper: Convert gt texts into uppercase characters.
             - lower: Convert gt texts into lowercase characters.
@@ -61,7 +60,7 @@ class BaseRecogLoss(nn.Module):
             - indexes (torch.LongTensor): Character indexes representing gt
               texts.
-            - padding_indexes (torch.LongTensor) Character indexes
+            - padded_indexes (torch.LongTensor): Character indexes
               representing gt texts, following several padding_idxs until
               reaching the length of ``max_seq_len``.
         """
@@ -88,14 +87,14 @@ class BaseRecogLoss(nn.Module):
                 slice_end = src_target.size(0) - 1
             src_target = src_target[slice_start:slice_end]
             if self.dictionary.padding_idx is not None:
-                padding_indexes = (torch.ones(self.max_seq_len) *
-                                   self.dictionary.padding_idx).long()
+                padded_indexes = (torch.ones(self.max_seq_len) *
+                                  self.dictionary.padding_idx).long()
                 char_num = min(src_target.size(0), self.max_seq_len)
-                padding_indexes[:char_num] = src_target[:char_num]
+                padded_indexes[:char_num] = src_target[:char_num]
             else:
-                padding_indexes = src_target
+                padded_indexes = src_target
             # put in DataSample
             data_sample.gt_text.indexes = indexes
-            data_sample.gt_text.padding_indexes = padding_indexes
+            data_sample.gt_text.padded_indexes = padded_indexes
         return data_samples
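For context, after this rename ``get_targets`` leaves two fields on each sample: ``gt_text.indexes`` (raw character indexes) and ``gt_text.padded_indexes`` (formerly ``padding_indexes``). A minimal sketch of the resulting contract; the dict-file path and characters are placeholders, not part of this patch:

    import torch
    from mmengine.data import LabelData

    from mmocr.core.data_structures import TextRecogDataSample
    from mmocr.models.textrecog.dictionary import Dictionary
    from mmocr.models.textrecog.losses import BaseRecogLoss

    # Hypothetical dict file: one character per line.
    dictionary = Dictionary(dict_file='dicts/digits.txt', with_start=True,
                            with_end=True, with_padding=True)
    loss = BaseRecogLoss(dictionary=dictionary, max_seq_len=10)

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='123')
    sample, = loss.get_targets([sample])
    print(sample.gt_text.indexes)         # raw character indexes
    print(sample.gt_text.padded_indexes)  # [<BOS>, ..., <EOS>, <PAD>, ...],
                                          # filled with padding_idx up to
                                          # max_seq_len (new field name)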
diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py
index dbf8a4ab..12567d25 100644
--- a/mmocr/models/textrecog/losses/ctc_loss.py
+++ b/mmocr/models/textrecog/losses/ctc_loss.py
@@ -1,19 +1,31 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Dict, Sequence, Union
 
 import torch
 import torch.nn as nn
 
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary.dictionary import Dictionary
 from mmocr.registry import MODELS
+from .base_recog_loss import BaseRecogLoss
 
 
 @MODELS.register_module()
-class CTCLoss(nn.Module):
+class CTCLoss(BaseRecogLoss):
     """Implementation of loss module for CTC-loss based text recognition.
 
     Args:
+        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+            the instance of `Dictionary`.
+        letter_case (str): There are three options to alter the letter cases
+            of gt texts:
+            - unchanged: Do not change gt texts.
+            - upper: Convert gt texts into uppercase characters.
+            - lower: Convert gt texts into lowercase characters.
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
         flatten (bool): If True, use flattened targets, else padded targets.
-        blank (int): Blank label. Default 0.
         reduction (str): Specifies the reduction to apply to the output,
             should be one of the following: ('none', 'mean', 'sum').
         zero_infinity (bool): Whether to zero infinite losses and
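Since ``CTCLoss`` now inherits from ``BaseRecogLoss``, the dictionary can be passed either as a ready-made ``Dictionary`` instance or as a registry config dict, and the CTC blank index is derived from ``dictionary.padding_idx`` in the ``__init__`` hunk below instead of the removed ``blank`` argument. A sketch of both construction styles; the dict-file path is a placeholder:

    from mmocr.models.textrecog.dictionary import Dictionary
    from mmocr.models.textrecog.losses import CTCLoss

    # As a config dict, resolved through the registry by the base class.
    dict_cfg = dict(
        type='Dictionary',
        dict_file='dicts/lower_english_digits.txt',  # hypothetical path
        with_padding=True)  # padding_idx doubles as the CTC blank
    ctc_loss = CTCLoss(dictionary=dict_cfg, letter_case='lower')

    # Or as an instance.
    dictionary = Dictionary(
        dict_file='dicts/lower_english_digits.txt', with_padding=True)
    ctc_loss = CTCLoss(dictionary=dictionary)

Note that a dictionary built without ``with_padding=True`` would presumably leave ``padding_idx`` as ``None``, which ``nn.CTCLoss`` cannot accept as a blank index.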
""" valid_ratios = None - if img_metas is not None: + if data_samples is not None: valid_ratios = [ - img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + img_meta.get('valid_ratio', 1.0) for img_meta in data_samples ] outputs = torch.log_softmax(outputs, dim=2) bsz, seq_len = outputs.size(0), outputs.size(1) outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C - - if self.flatten: - targets = targets_dict['flatten_targets'] - else: - targets = torch.full( - size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long) - for idx, tensor in enumerate(targets_dict['targets']): - valid_len = min(tensor.size(0), seq_len) - targets[idx, :valid_len] = tensor[:valid_len] - - target_lengths = targets_dict['target_lengths'] + targets = [data_sample.gt_text.indexes for data_sample in data_samples] + target_lengths = torch.IntTensor([len(t) for t in targets]) target_lengths = torch.clamp(target_lengths, min=1, max=seq_len).long() - input_lengths = torch.full( size=(bsz, ), fill_value=seq_len, dtype=torch.long) - if not self.flatten and valid_ratios is not None: - input_lengths = [ - math.ceil(valid_ratio * seq_len) - for valid_ratio in valid_ratios - ] - input_lengths = torch.Tensor(input_lengths).long() + if self.flatten: + targets = torch.cat(targets) + else: + padded_targets = torch.full( + size=(bsz, seq_len), + fill_value=self.dictionary.padding_idx, + dtype=torch.long) + for idx, valid_len in enumerate(target_lengths): + padded_targets[idx, :valid_len] = targets[idx][:valid_len] + targets = padded_targets + if valid_ratios is not None: + input_lengths = [ + math.ceil(valid_ratio * seq_len) + for valid_ratio in valid_ratios + ] + input_lengths = torch.Tensor(input_lengths).long() loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths, target_lengths) - losses = dict(loss_ctc=loss_ctc) return losses + + def get_targets( + self, data_samples: Sequence[TextRecogDataSample] + ) -> Sequence[TextRecogDataSample]: + """Target generator. + + Args: + data_samples (list[TextRecogDataSample]): It usually includes + ``gt_text`` information. + + Returns: + + list[TextRecogDataSample]: updated data_samples. It will add two + key in data_sample: + + - indexes (torch.LongTensor): The index corresponding to the item. 
+ """ + + for data_sample in data_samples: + text = data_sample.gt_text.item + if self.letter_case in ['upper', 'lower']: + text = getattr(text, self.letter_case)() + indexes = self.dictionary.str2idx(text) + indexes = torch.IntTensor(indexes) + data_sample.gt_text.indexes = indexes + return data_samples diff --git a/tests/test_models/test_textrecog/test_loss/test_base_recog_loss.py b/tests/test_models/test_textrecog/test_loss/test_base_recog_loss.py index 28d23008..18f3343f 100644 --- a/tests/test_models/test_textrecog/test_loss/test_base_recog_loss.py +++ b/tests/test_models/test_textrecog/test_loss/test_base_recog_loss.py @@ -76,7 +76,7 @@ class TestBaseRecogLoss(TestCase): torch.LongTensor([0, 1, 2, 3])) padding_idx = dictionary.padding_idx assert self._equal( - target_data_samples[0].gt_text.padding_indexes, + target_data_samples[0].gt_text.padded_indexes, torch.LongTensor([ dictionary.start_idx, 0, 1, 2, 3, dictionary.end_idx, padding_idx, padding_idx, padding_idx, padding_idx @@ -95,7 +95,7 @@ class TestBaseRecogLoss(TestCase): assert self._equal(target_data_samples[0].gt_text.indexes, torch.LongTensor([0, 1, 2, 3])) padding_idx = dictionary.padding_idx - assert self._equal(target_data_samples[0].gt_text.padding_indexes, + assert self._equal(target_data_samples[0].gt_text.padded_indexes, torch.LongTensor([0, 1, 2])) dict_cfg = dict( @@ -112,7 +112,7 @@ class TestBaseRecogLoss(TestCase): target_data_samples = base_recog_loss.get_targets([data_sample]) assert self._equal(target_data_samples[0].gt_text.indexes, torch.LongTensor([0, 1, 2, 3])) - assert self._equal(target_data_samples[0].gt_text.padding_indexes, + assert self._equal(target_data_samples[0].gt_text.padded_indexes, torch.LongTensor([0, 1, 2, 3])) target_data_samples = base_recog_loss.get_targets([]) diff --git a/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py b/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py new file mode 100644 index 00000000..fa2db098 --- /dev/null +++ b/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py b/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py
new file mode 100644
index 00000000..fa2db098
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_loss/test_ctc_loss.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from mmengine.data import LabelData
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.models.textrecog.losses import CTCLoss
+
+
+class TestCTCLoss(TestCase):
+
+    def test_ctc_loss(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        # create dummy data
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+
+        dictionary = Dictionary(dict_file=dict_file, with_padding=True)
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, flatten='flatten')
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, reduction=1)
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, zero_infinity='zero')
+
+        outputs = torch.zeros(2, 40, 37)
+        datasample1 = TextRecogDataSample()
+        gt_text1 = LabelData(item='hell')
+        datasample1.gt_text = gt_text1
+        datasample2 = datasample1.clone()
+        gt_text2 = LabelData(item='owrd')
+        datasample2.gt_text = gt_text2
+        data_samples = [datasample1, datasample2]
+        ctc_loss = CTCLoss(dictionary=dictionary)
+        data_samples = ctc_loss.get_targets(data_samples)
+        losses = ctc_loss(outputs, data_samples)
+        assert isinstance(losses, dict)
+        assert 'loss_ctc' in losses
+        assert torch.allclose(losses['loss_ctc'],
+                              torch.tensor(losses['loss_ctc'].item()).float())
+        # test flatten = False
+        ctc_loss = CTCLoss(dictionary=dictionary, flatten=False)
+        losses = ctc_loss(outputs, data_samples)
+        assert isinstance(losses, dict)
+        assert 'loss_ctc' in losses
+        assert torch.allclose(losses['loss_ctc'],
+                              torch.tensor(losses['loss_ctc'].item()).float())
+        tmp_dir.cleanup()
+
+    def _create_dummy_dict_file(self, dict_file):
+        chars = list('helowrd')
+        with open(dict_file, 'w') as fw:
+            for char in chars:
+                fw.write(char + '\n')
+
+    def test_get_targets(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        # create dummy data
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+
+        dictionary = Dictionary(dict_file=dict_file, with_padding=True)
+        loss = CTCLoss(dictionary=dictionary, letter_case='lower')
+        # test encode str to tensor
+        datasample1 = TextRecogDataSample()
+        gt_text1 = LabelData(item='hell')
+        datasample1.gt_text = gt_text1
+        datasample2 = datasample1.clone()
+        gt_text2 = LabelData(item='owrd')
+        datasample2.gt_text = gt_text2
+
+        data_samples = [datasample1, datasample2]
+        expect_tensor1 = torch.IntTensor([0, 1, 2, 2])
+        expect_tensor2 = torch.IntTensor([3, 4, 5, 6])
+
+        data_samples = loss.get_targets(data_samples)
+        self.assertTrue(
+            torch.allclose(data_samples[0].gt_text.indexes, expect_tensor1))
+        self.assertTrue(
+            torch.allclose(data_samples[1].gt_text.indexes, expect_tensor2))
+        tmp_dir.cleanup()
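One behavioural detail the new tests only exercise indirectly: with ``flatten=True`` (the default) the per-sample index tensors are concatenated into a single 1-D target tensor, while ``flatten=False`` pads them into an ``(N, T)`` tensor filled with the blank/padding index, and only that branch rescales ``input_lengths`` by ``valid_ratio``. A small sketch of the two target layouts, using the 'helowrd' dummy dictionary from the tests above (7 characters, so ``padding_idx`` is 7):

    import torch

    # indexes produced by get_targets for 'hell' and 'owrd'
    targets = [torch.IntTensor([0, 1, 2, 2]), torch.IntTensor([3, 4, 5, 6])]

    flat = torch.cat(targets)  # flatten=True: shape (8,)

    padded = torch.full((2, 40), fill_value=7, dtype=torch.long)
    padded[0, :4] = targets[0]  # flatten=False: shape (2, 40), blank-padded
    padded[1, :4] = targets[1]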