diff --git a/mmocr/models/textrecog/losses/__init__.py b/mmocr/models/textrecog/losses/__init__.py
index 3d2a363e..1f6e2631 100755
--- a/mmocr/models/textrecog/losses/__init__.py
+++ b/mmocr/models/textrecog/losses/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .abi_loss import ABILoss
 from .base_recog_loss import BaseRecogLoss
 from .ce_loss import CELoss
 from .ctc_loss import CTCLoss
-from .mix_loss import ABILoss
 from .seg_loss import SegLoss

 __all__ = ['BaseRecogLoss', 'CELoss', 'CTCLoss', 'SegLoss', 'ABILoss']
diff --git a/mmocr/models/textrecog/losses/abi_loss.py b/mmocr/models/textrecog/losses/abi_loss.py
new file mode 100644
index 00000000..840e2c07
--- /dev/null
+++ b/mmocr/models/textrecog/losses/abi_loss.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence, Union
+
+import torch
+
+from mmocr.data import TextRecogDataSample
+from mmocr.models.textrecog.dictionary.dictionary import Dictionary
+from mmocr.registry import MODELS
+from .base_recog_loss import BaseRecogLoss
+from .ce_loss import CELoss
+
+
+@MODELS.register_module()
+class ABILoss(BaseRecogLoss):
+    """Implementation of ABINet multiloss that allows mixing different types
+    of losses with weights.
+
+    Args:
+        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+            the instance of `Dictionary`.
+        max_seq_len (int): Maximum sequence length. The sequence is usually
+            generated from decoder. Defaults to 40.
+        letter_case (str): There are three options to alter the letter cases
+            of gt texts:
+            - unchanged: Do not change gt texts.
+            - upper: Convert gt texts into uppercase characters.
+            - lower: Convert gt texts into lowercase characters.
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
+        weight_vis (float or int): The weight of vision decoder loss. Defaults
+            to 1.0.
+        weight_lang (float or int): The weight of language decoder loss.
+            Defaults to 1.0.
+        weight_fusion (float or int): The weight of fuser (aligner) loss.
+            Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 dictionary: Union[Dict, Dictionary],
+                 max_seq_len: int = 40,
+                 letter_case: str = 'unchanged',
+                 weight_vis: Union[float, int] = 1.0,
+                 weight_lang: Union[float, int] = 1.0,
+                 weight_fusion: Union[float, int] = 1.0,
+                 **kwargs) -> None:
+        assert isinstance(weight_vis, (float, int))
+        assert isinstance(weight_lang, (float, int))
+        assert isinstance(weight_fusion, (float, int))
+        super().__init__(
+            dictionary=dictionary,
+            max_seq_len=max_seq_len,
+            letter_case=letter_case)
+        self.weight_vis = weight_vis
+        self.weight_lang = weight_lang
+        self.weight_fusion = weight_fusion
+        self._ce_loss = CELoss(
+            self.dictionary,
+            max_seq_len,
+            letter_case,
+            reduction='mean',
+            ignore_first_char=True)
+
+    def forward(self, outputs: Dict,
+                data_samples: Sequence[TextRecogDataSample]) -> Dict:
+        """
+        Args:
+            outputs (dict): The output dictionary with at least one of
+                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
+            data_samples (list[TextRecogDataSample]): List of
+                ``TextRecogDataSample`` which are processed by
+                ``get_targets``.
+
+        Returns:
+            dict: A loss dictionary with ``loss_visual``, ``loss_lang`` and
+            ``loss_fusion``. Each should either be the loss tensor or None if
+            the output of its corresponding module is not given.
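+
+        Example:
+            >>> # A minimal sketch mirroring test_abi_loss.py: ``abi_loss``
+            >>> # and ``data_samples`` are assumed prepared as in that test,
+            >>> # with a 38-class dictionary and ``max_seq_len=10``.
+            >>> outputs = dict(
+            ...     out_vis=dict(logits=torch.randn(2, 10, 38)),
+            ...     out_langs=[dict(logits=torch.randn(2, 10, 38))],
+            ...     out_fusers=[dict(logits=torch.randn(2, 10, 38))])
+            >>> losses = abi_loss(outputs, data_samples)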
+ """ + assert 'out_vis' in outputs or \ + 'out_langs' in outputs or 'out_fusers' in outputs + losses = {} + + if outputs.get('out_vis', None): + losses['loss_visual'] = self.weight_vis * self._ce_loss( + outputs['out_vis']['logits'], data_samples)['loss_ce'] + if outputs.get('out_langs', None): + lang_losses = [] + for out_lang in outputs['out_langs']: + lang_losses.append( + self._ce_loss(out_lang['logits'], data_samples)['loss_ce']) + losses['loss_lang'] = self.weight_lang * torch.mean( + torch.stack(lang_losses)) + if outputs.get('out_fusers', None): + fuser_losses = [] + for out_fuser in outputs['out_fusers']: + fuser_losses.append( + self._ce_loss(out_fuser['logits'], + data_samples)['loss_ce']) + losses['loss_fusion'] = self.weight_fusion * torch.mean( + torch.stack(fuser_losses)) + return losses diff --git a/mmocr/models/textrecog/losses/base_recog_loss.py b/mmocr/models/textrecog/losses/base_recog_loss.py index 201a3586..31220d0f 100644 --- a/mmocr/models/textrecog/losses/base_recog_loss.py +++ b/mmocr/models/textrecog/losses/base_recog_loss.py @@ -25,12 +25,21 @@ class BaseRecogLoss(nn.Module): - lower: Convert gt texts into lowercase characters. Usually, it only works for English characters. Defaults to 'unchanged'. + pad_with (str): The padding strategy for ``gt_text.padded_indexes``. + Defaults to 'auto'. Options are: + - 'auto': Use dictionary.padding_idx to pad gt texts, or + dictionary.end_idx if dictionary.padding_idx + is None. + - 'padding': Always use dictionary.padding_idx to pad gt texts. + - 'end': Always use dictionary.end_idx to pad gt texts. + - 'none': Do not pad gt texts. """ def __init__(self, dictionary: Union[Dict, Dictionary], max_seq_len: int = 40, letter_case: str = 'unchanged', + pad_with: str = 'auto', **kwargs) -> None: super().__init__() if isinstance(dictionary, dict): @@ -45,6 +54,25 @@ class BaseRecogLoss(nn.Module): assert letter_case in ['unchanged', 'upper', 'lower'] self.letter_case = letter_case + assert pad_with in ['auto', 'padding', 'end', 'none'] + if pad_with == 'auto': + self.pad_idx = self.dictionary.padding_idx or \ + self.dictionary.end_idx + elif pad_with == 'padding': + self.pad_idx = self.dictionary.padding_idx + elif pad_with == 'end': + self.pad_idx = self.dictionary.end_idx + else: + self.pad_idx = None + if self.pad_idx is None and pad_with != 'none': + if pad_with == 'auto': + raise ValueError('pad_with="auto", but dictionary.end_idx' + ' and dictionary.padding_idx are both None') + else: + raise ValueError( + f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is' + ' None') + def get_targets( self, data_samples: Sequence[TextRecogDataSample] ) -> Sequence[TextRecogDataSample]: @@ -59,10 +87,12 @@ class BaseRecogLoss(nn.Module): added to data_sample: - indexes (torch.LongTensor): Character indexes representing gt - texts. - - padded_indexes (torch.LongTensor) Character indexes - representing gt texts, following several padding_idxs until - reaching the length of ``max_seq_len``. + texts. All special tokens are excluded, except for UKN. + - padded_indexes (torch.LongTensor): Character indexes + representing gt texts with BOS and EOS if applicable, following + several padding indexes until the length reaches ``max_seq_len``. + In particular, if ``pad_with='none'``, no padding will be + applied. 
""" for data_sample in data_samples: @@ -88,14 +118,13 @@ class BaseRecogLoss(nn.Module): else: slice_end = src_target.size(0) - 1 src_target = src_target[slice_start:slice_end] - if self.dictionary.padding_idx is not None: + if self.pad_idx is not None: padded_indexes = (torch.ones(self.max_seq_len) * - self.dictionary.padding_idx).long() + self.pad_idx).long() char_num = min(src_target.size(0), self.max_seq_len) padded_indexes[:char_num] = src_target[:char_num] else: padded_indexes = src_target - # put in DataSample data_sample.gt_text.indexes = indexes data_sample.gt_text.padded_indexes = padded_indexes diff --git a/mmocr/models/textrecog/losses/ce_loss.py b/mmocr/models/textrecog/losses/ce_loss.py index eff4545a..164f42ff 100644 --- a/mmocr/models/textrecog/losses/ce_loss.py +++ b/mmocr/models/textrecog/losses/ce_loss.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings from typing import Dict, Sequence, Union import torch @@ -36,6 +37,8 @@ class CELoss(BaseRecogLoss): and 'unknown', which refer to their corresponding special tokens in the dictionary. It will not ignore any special tokens when ignore_char == -1 or 'none'. Defaults to 'padding'. + flatten (bool): Whether to flatten the output and target before + computing CE loss. Defaults to False. reduction (str): Specifies the reduction to apply to the output, should be one of the following: ('none', 'mean', 'sum'). Defaults to 'none'. @@ -82,8 +85,11 @@ class CELoss(BaseRecogLoss): ignore_index = mapping_table.get( ignore_char, self.dictionary._char2idx.get(ignore_char, None)) if ignore_index is None: - raise ValueError( - f'{ignore_char} does not exist in the dictionary') + warnings.warn( + f'{ignore_char} does not exist in the dictionary', + UserWarning) + ignore_index = -1 + self.ignore_char = ignore_char self.ignore_index = ignore_index self.loss_ce = nn.CrossEntropyLoss( diff --git a/mmocr/models/textrecog/losses/mix_loss.py b/mmocr/models/textrecog/losses/mix_loss.py deleted file mode 100644 index 0d5f92e8..00000000 --- a/mmocr/models/textrecog/losses/mix_loss.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn -import torch.nn.functional as F - -from mmocr.registry import MODELS - - -@MODELS.register_module() -class ABILoss(nn.Module): - """Implementation of ABINet multiloss that allows mixing different types of - losses with weights. - - Args: - enc_weight (float): The weight of encoder loss. Defaults to 1.0. - dec_weight (float): The weight of decoder loss. Defaults to 1.0. - fusion_weight (float): The weight of fuser (aligner) loss. - Defaults to 1.0. - num_classes (int): Number of unique output language tokens. - - Returns: - A dictionary whose key/value pairs are the losses of three modules. 
- """ - - def __init__(self, - enc_weight=1.0, - dec_weight=1.0, - fusion_weight=1.0, - num_classes=37, - **kwargs): - assert isinstance(enc_weight, float) or isinstance(enc_weight, int) - assert isinstance(dec_weight, float) or isinstance(dec_weight, int) - assert isinstance(fusion_weight, float) or \ - isinstance(fusion_weight, int) - super().__init__() - self.enc_weight = enc_weight - self.dec_weight = dec_weight - self.fusion_weight = fusion_weight - self.num_classes = num_classes - - def _flatten(self, logits, target_lens): - flatten_logits = torch.cat( - [s[:target_lens[i]] for i, s in enumerate(logits)]) - return flatten_logits - - def _ce_loss(self, logits, targets): - targets_one_hot = F.one_hot(targets, self.num_classes) - log_prob = F.log_softmax(logits, dim=-1) - loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1) - return loss.mean() - - def _loss_over_iters(self, outputs, targets): - """ - Args: - outputs (list[Tensor]): Each tensor has shape (N, T, C) where N is - the batch size, T is the sequence length and C is the number of - classes. - targets_dicts (dict): The dictionary with at least `padded_targets` - defined. - """ - iter_num = len(outputs) - dec_outputs = torch.cat(outputs, dim=0) - flatten_targets_iternum = targets.repeat(iter_num) - return self._ce_loss(dec_outputs, flatten_targets_iternum) - - def forward(self, outputs, targets_dict, img_metas=None): - """ - Args: - outputs (dict): The output dictionary with at least one of - ``out_enc``, ``out_dec`` and ``out_fusers`` specified. - targets_dict (dict): The target dictionary containing the key - ``padded_targets``, which represents target sequences in - shape (batch_size, sequence_length). - - Returns: - A loss dictionary with ``loss_visual``, ``loss_lang`` and - ``loss_fusion``. Each should either be the loss tensor or ``0`` if - the output of its corresponding module is not given. - """ - assert 'out_enc' in outputs or \ - 'out_dec' in outputs or 'out_fusers' in outputs - losses = {} - - target_lens = [len(t) for t in targets_dict['targets']] - flatten_targets = torch.cat([t for t in targets_dict['targets']]) - - if outputs.get('out_enc', None): - enc_input = self._flatten(outputs['out_enc']['logits'], - target_lens) - enc_loss = self._ce_loss(enc_input, - flatten_targets) * self.enc_weight - losses['loss_visual'] = enc_loss - if outputs.get('out_decs', None): - dec_logits = [ - self._flatten(o['logits'], target_lens) - for o in outputs['out_decs'] - ] - dec_loss = self._loss_over_iters(dec_logits, - flatten_targets) * self.dec_weight - losses['loss_lang'] = dec_loss - if outputs.get('out_fusers', None): - fusion_logits = [ - self._flatten(o['logits'], target_lens) - for o in outputs['out_fusers'] - ] - fusion_loss = self._loss_over_iters( - fusion_logits, flatten_targets) * self.fusion_weight - losses['loss_fusion'] = fusion_loss - return losses diff --git a/tests/test_models/test_textrecog/test_losses/test_abi_loss.py b/tests/test_models/test_textrecog/test_losses/test_abi_loss.py new file mode 100644 index 00000000..e7857e88 --- /dev/null +++ b/tests/test_models/test_textrecog/test_losses/test_abi_loss.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase
+
+import numpy as np
+import torch
+from mmengine.data import LabelData
+
+from mmocr.data import TextRecogDataSample
+from mmocr.models.textrecog.losses import ABILoss
+
+
+class TestABILoss(TestCase):
+
+    def setUp(self) -> None:
+
+        data_sample1 = TextRecogDataSample()
+        data_sample1.gt_text = LabelData(item='hello')
+        data_sample2 = TextRecogDataSample()
+        data_sample2.gt_text = LabelData(item='123')
+        self.gt = [data_sample1, data_sample2]
+
+    def _equal(self, a, b):
+        if isinstance(a, (torch.Tensor, np.ndarray)):
+            return (a == b).all()
+        else:
+            return a == b
+
+    def test_forward(self):
+        dict_cfg = dict(
+            type='Dictionary',
+            dict_file='dicts/lower_english_digits.txt',
+            with_start=True,
+            with_end=True,
+            same_start_end=True,
+            with_padding=True,
+            with_unknown=False)
+        abi_loss = ABILoss(dict_cfg, max_seq_len=10)
+        abi_loss.get_targets(self.gt)
+        outputs = dict(
+            out_vis=dict(logits=torch.randn(2, 10, 38)),
+            out_langs=[
+                dict(logits=torch.randn(2, 10, 38)),
+                dict(logits=torch.randn(2, 10, 38))
+            ],
+            out_fusers=[
+                dict(logits=torch.randn(2, 10, 38)),
+                dict(logits=torch.randn(2, 10, 38))
+            ])
+        losses = abi_loss(outputs, self.gt)
+        self.assertIsInstance(losses, dict)
+        self.assertIn('loss_visual', losses)
+        self.assertIn('loss_lang', losses)
+        self.assertIn('loss_fusion', losses)
+
+        outputs.pop('out_vis')
+        abi_loss(outputs, self.gt)
+        out_langs = outputs.pop('out_langs')
+        abi_loss(outputs, self.gt)
+        outputs.pop('out_fusers')
+        with self.assertRaises(AssertionError):
+            abi_loss(outputs, self.gt)
+        outputs['out_langs'] = out_langs
+        abi_loss(outputs, self.gt)
diff --git a/tests/test_models/test_textrecog/test_losses/test_base_recog_loss.py b/tests/test_models/test_textrecog/test_losses/test_base_recog_loss.py
index 3a1110d3..d8a3dc2b 100644
--- a/tests/test_models/test_textrecog/test_losses/test_base_recog_loss.py
+++ b/tests/test_models/test_textrecog/test_losses/test_base_recog_loss.py
@@ -45,6 +45,22 @@ class TestBaseRecogLoss(TestCase):
         # test case mode
         with self.assertRaises(AssertionError):
             base_recog_loss = BaseRecogLoss(dict_cfg, letter_case='no')
+        # test invalid pad_with
+        with self.assertRaises(AssertionError):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='test')
+        # test invalid combination of dictionary and pad_with
+        dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
+        for pad_with in ['end', 'padding']:
+            with self.assertRaisesRegex(
+                    ValueError, f'pad_with="{pad_with}", but'
+                    f' dictionary.{pad_with}_idx is None'):
+                base_recog_loss = BaseRecogLoss(dict_cfg, pad_with=pad_with)
+        with self.assertRaisesRegex(
+                ValueError, 'pad_with="auto", but'
+                ' dictionary.end_idx and dictionary.padding_idx are both'
+                ' None'):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='auto')
+
         # test dictionary is invalid type
         dict_cfg = ['tmp']
         with self.assertRaisesRegex(
@@ -82,8 +98,10 @@ class TestBaseRecogLoss(TestCase):
                 padding_idx, padding_idx, padding_idx, padding_idx
             ]))
         self.assertTrue(target_data_samples[0].have_target)
+        target_data_samples = base_recog_loss.get_targets(target_data_samples)

         data_sample.set_metainfo(dict(have_target=False))
+
         dictionary = Dictionary(
             dict_file=dict_file,
             with_start=False,
@@ -100,23 +118,25 @@ class TestBaseRecogLoss(TestCase):
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                            torch.LongTensor([0, 1, 2]))
         data_sample.set_metainfo(dict(have_target=False))
+
         dict_cfg = dict(
             type='Dictionary',
-            dict_file=dict_file,
+            dict_file='dicts/lower_english_digits.txt',
             with_start=False,
-            with_end=False,
+            with_end=True,
             same_start_end=False,
-            with_padding=False,
+            with_padding=True,
             with_unknown=True)
         base_recog_loss = BaseRecogLoss(
-            dict_cfg, max_seq_len=10, letter_case='lower')
+            dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
         data_sample.gt_text.item = '0123'
         target_data_samples = base_recog_loss.get_targets([data_sample])
         assert self._equal(target_data_samples[0].gt_text.indexes,
                            torch.LongTensor([0, 1, 2, 3]))
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
-                           torch.LongTensor([0, 1, 2, 3]))
+                           torch.LongTensor([0, 1, 2, 3, 36]))
         target_data_samples = base_recog_loss.get_targets([])
         self.assertListEqual(target_data_samples, [])

+        tmp_dir.cleanup()
diff --git a/tests/test_models/test_textrecog/test_losses/test_ce_loss.py b/tests/test_models/test_textrecog/test_losses/test_ce_loss.py
index 2beb46dd..0a6d374e 100644
--- a/tests/test_models/test_textrecog/test_losses/test_ce_loss.py
+++ b/tests/test_models/test_textrecog/test_losses/test_ce_loss.py
@@ -44,12 +44,12 @@ class TestCELoss(TestCase):
         self.assertEqual(ce_loss.ignore_index, 37)
         ce_loss = CELoss(dict_cfg, ignore_char=-1)
         self.assertEqual(ce_loss.ignore_index, -1)
-        with self.assertRaises(ValueError):
+        with self.assertWarns(UserWarning):
             ce_loss = CELoss(dict_cfg, ignore_char='ignore')
         ce_loss = CELoss(dict_cfg, ignore_char='1')
         self.assertEqual(ce_loss.ignore_index, 1)

-    def test_ce_loss(self):
+    def test_forward(self):
         dict_cfg = dict(
             type='Dictionary',
             dict_file='dicts/lower_english_digits.txt',
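
A minimal usage sketch of the CELoss fallback behaviour exercised by the test above (not part of the patch; the dictionary config mirrors the one in test_abi_loss.py): when ignore_char resolves to neither a special-token alias nor a dictionary character, CELoss now emits a UserWarning and falls back to ignore_index = -1, meaning no class is ignored, instead of raising ValueError.

import warnings

from mmocr.models.textrecog.losses import CELoss

dict_cfg = dict(
    type='Dictionary',
    dict_file='dicts/lower_english_digits.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_padding=True,
    with_unknown=False)

# 'ignore' is neither a special-token alias ('start', 'end', 'padding',
# 'unknown') nor a character in the dictionary, so CELoss warns and
# falls back to ignore_index = -1.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    ce_loss = CELoss(dict_cfg, ignore_char='ignore')

assert len(caught) == 1 and issubclass(caught[0].category, UserWarning)
assert ce_loss.ignore_index == -1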