[ABINet] Refactor ABILoss

gaotongxiao 2022-07-13 02:53:59 +00:00
parent ee1212a5cd
commit a844b497db
8 changed files with 237 additions and 126 deletions

mmocr/models/textrecog/losses/__init__.py

@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .abi_loss import ABILoss
from .base_recog_loss import BaseRecogLoss
from .ce_loss import CELoss
from .ctc_loss import CTCLoss
-from .mix_loss import ABILoss
from .seg_loss import SegLoss

__all__ = ['BaseRecogLoss', 'CELoss', 'CTCLoss', 'SegLoss', 'ABILoss']
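The public import path is unchanged by the move from ``mix_loss.py`` to ``abi_loss.py``, since ``__all__`` still exports ``ABILoss``; downstream code like the following sketch keeps working:

    # Re-exported name; only the defining module changed.
    from mmocr.models.textrecog.losses import ABILoss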

mmocr/models/textrecog/losses/abi_loss.py (new file)

@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
from mmocr.registry import MODELS
from .base_recog_loss import BaseRecogLoss
from .ce_loss import CELoss


@MODELS.register_module()
class ABILoss(BaseRecogLoss):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:

            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.

            Usually, it only works for English characters. Defaults to
            'unchanged'.
        weight_vis (float or int): The weight of vision decoder loss. Defaults
            to 1.0.
        weight_lang (float or int): The weight of language decoder loss.
            Defaults to 1.0.
        weight_fusion (float or int): The weight of fuser (aligner) loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 weight_vis: Union[float, int] = 1.0,
                 weight_lang: Union[float, int] = 1.0,
                 weight_fusion: Union[float, int] = 1.0,
                 **kwargs) -> None:
        assert isinstance(weight_vis, (float, int))
        assert isinstance(weight_lang, (float, int))
        assert isinstance(weight_fusion, (float, int))
        super().__init__(
            dictionary=dictionary,
            max_seq_len=max_seq_len,
            letter_case=letter_case)
        self.weight_vis = weight_vis
        self.weight_lang = weight_lang
        self.weight_fusion = weight_fusion
        self._ce_loss = CELoss(
            self.dictionary,
            max_seq_len,
            letter_case,
            reduction='mean',
            ignore_first_char=True)

    def forward(self, outputs: Dict,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. Each key is present only if the output of its
            corresponding module is given.
        """
        assert 'out_vis' in outputs or \
            'out_langs' in outputs or 'out_fusers' in outputs
        losses = {}

        if outputs.get('out_vis', None):
            losses['loss_visual'] = self.weight_vis * self._ce_loss(
                outputs['out_vis']['logits'], data_samples)['loss_ce']
        if outputs.get('out_langs', None):
            lang_losses = []
            for out_lang in outputs['out_langs']:
                lang_losses.append(
                    self._ce_loss(out_lang['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_lang'] = self.weight_lang * torch.mean(
                torch.stack(lang_losses))
        if outputs.get('out_fusers', None):
            fuser_losses = []
            for out_fuser in outputs['out_fusers']:
                fuser_losses.append(
                    self._ce_loss(out_fuser['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_fusion'] = self.weight_fusion * torch.mean(
                torch.stack(fuser_losses))
        return losses
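For orientation, a minimal usage sketch of the refactored loss, modeled on the unit test added later in this commit; the dictionary file path and the 38-class logits shape are assumptions carried over from that test:

    import torch
    from mmengine.data import LabelData

    from mmocr.data import TextRecogDataSample
    from mmocr.models.textrecog.losses import ABILoss

    # Dictionary config assumed from the unit test below.
    dict_cfg = dict(
        type='Dictionary',
        dict_file='dicts/lower_english_digits.txt',
        with_start=True,
        with_end=True,
        same_start_end=True,
        with_padding=True,
        with_unknown=False)
    abi_loss = ABILoss(dict_cfg, max_seq_len=10)

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='hello')
    data_samples = abi_loss.get_targets([sample])

    # Logits of each module have shape (N, max_seq_len, num_classes).
    outputs = dict(
        out_vis=dict(logits=torch.randn(1, 10, 38)),
        out_langs=[dict(logits=torch.randn(1, 10, 38))],
        out_fusers=[dict(logits=torch.randn(1, 10, 38))])
    losses = abi_loss(outputs, data_samples)
    # -> {'loss_visual': ..., 'loss_lang': ..., 'loss_fusion': ...}

Omitting any of ``out_vis``, ``out_langs`` or ``out_fusers`` simply drops the corresponding entry; omitting all three trips the assertion at the top of ``forward()``.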

mmocr/models/textrecog/losses/base_recog_loss.py

@@ -25,12 +25,21 @@ class BaseRecogLoss(nn.Module):
            - lower: Convert gt texts into lowercase characters.
              Usually, it only works for English characters. Defaults to
              'unchanged'.
+        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
+            Defaults to 'auto'. Options are:
+            - 'auto': Use dictionary.padding_idx to pad gt texts, or
+              dictionary.end_idx if dictionary.padding_idx is None.
+            - 'padding': Always use dictionary.padding_idx to pad gt texts.
+            - 'end': Always use dictionary.end_idx to pad gt texts.
+            - 'none': Do not pad gt texts.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
+                 pad_with: str = 'auto',
                 **kwargs) -> None:
        super().__init__()
        if isinstance(dictionary, dict):
@@ -45,6 +54,25 @@ class BaseRecogLoss(nn.Module):
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case

+        assert pad_with in ['auto', 'padding', 'end', 'none']
+        if pad_with == 'auto':
+            self.pad_idx = self.dictionary.padding_idx or \
+                self.dictionary.end_idx
+        elif pad_with == 'padding':
+            self.pad_idx = self.dictionary.padding_idx
+        elif pad_with == 'end':
+            self.pad_idx = self.dictionary.end_idx
+        else:
+            self.pad_idx = None
+        if self.pad_idx is None and pad_with != 'none':
+            if pad_with == 'auto':
+                raise ValueError('pad_with="auto", but dictionary.end_idx'
+                                 ' and dictionary.padding_idx are both None')
+            else:
+                raise ValueError(
+                    f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is'
+                    ' None')
+
    def get_targets(
        self, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
@@ -59,10 +87,12 @@ class BaseRecogLoss(nn.Module):
            added to data_sample:

            - indexes (torch.LongTensor): Character indexes representing gt
-              texts.
-            - padded_indexes (torch.LongTensor) Character indexes
-              representing gt texts, following several padding_idxs until
-              reaching the length of ``max_seq_len``.
+              texts. All special tokens are excluded, except for UKN.
+            - padded_indexes (torch.LongTensor): Character indexes
+              representing gt texts with BOS and EOS if applicable, following
+              several padding indexes until the length reaches
+              ``max_seq_len``. In particular, if ``pad_with='none'``, no
+              padding will be applied.
        """
        for data_sample in data_samples:
@@ -88,14 +118,13 @@ class BaseRecogLoss(nn.Module):
                else:
                    slice_end = src_target.size(0) - 1
                src_target = src_target[slice_start:slice_end]
-            if self.dictionary.padding_idx is not None:
+            if self.pad_idx is not None:
                padded_indexes = (torch.ones(self.max_seq_len) *
-                                  self.dictionary.padding_idx).long()
+                                  self.pad_idx).long()
                char_num = min(src_target.size(0), self.max_seq_len)
                padded_indexes[:char_num] = src_target[:char_num]
            else:
                padded_indexes = src_target
            # put in DataSample
            data_sample.gt_text.indexes = indexes
            data_sample.gt_text.padded_indexes = padded_indexes
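To make the new ``pad_with`` strategies concrete, here is a small sketch of what ``get_targets`` attaches under two settings; the dictionary config and the concrete lengths are assumptions modeled on the updated test further down:

    from mmengine.data import LabelData

    from mmocr.data import TextRecogDataSample
    from mmocr.models.textrecog.losses import BaseRecogLoss

    dict_cfg = dict(
        type='Dictionary',
        dict_file='dicts/lower_english_digits.txt',  # assumed path
        with_start=False,
        with_end=True,
        with_padding=True,
        with_unknown=True)

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='0123')

    # 'auto' prefers dictionary.padding_idx and falls back to end_idx.
    loss = BaseRecogLoss(dict_cfg, max_seq_len=10, pad_with='auto')
    padded = loss.get_targets([sample])[0].gt_text.padded_indexes
    # len(padded) == 10: four characters, the end index, then padding
    # indexes up to max_seq_len.

    # 'none' skips padding entirely; only BOS/EOS (if any) are appended.
    sample.set_metainfo(dict(have_target=False))  # force re-generation
    loss = BaseRecogLoss(dict_cfg, max_seq_len=10, pad_with='none')
    unpadded = loss.get_targets([sample])[0].gt_text.padded_indexes
    # len(unpadded) == 5: four characters plus the end index.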

mmocr/models/textrecog/losses/ce_loss.py

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
from typing import Dict, Sequence, Union

import torch
@@ -36,6 +37,8 @@ class CELoss(BaseRecogLoss):
            and 'unknown', which refer to their corresponding special
            tokens in the dictionary. It will not ignore any special
            tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
+        flatten (bool): Whether to flatten the output and target before
+            computing CE loss. Defaults to False.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum'). Defaults
            to 'none'.
@@ -82,8 +85,11 @@ class CELoss(BaseRecogLoss):
        ignore_index = mapping_table.get(
            ignore_char, self.dictionary._char2idx.get(ignore_char, None))
        if ignore_index is None:
-            raise ValueError(
-                f'{ignore_char} does not exist in the dictionary')
+            warnings.warn(
+                f'{ignore_char} does not exist in the dictionary',
+                UserWarning)
+            ignore_index = -1
        self.ignore_char = ignore_char
        self.ignore_index = ignore_index
        self.loss_ce = nn.CrossEntropyLoss(
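The change above relaxes an unknown ``ignore_char`` from a hard ValueError to a UserWarning with a fallback of ``ignore_index = -1`` (i.e. no special token is ignored). A quick sketch of the new behavior, with the dictionary config assumed from the tests:

    import warnings

    from mmocr.models.textrecog.losses import CELoss

    dict_cfg = dict(
        type='Dictionary',
        dict_file='dicts/lower_english_digits.txt',  # assumed path
        with_start=True,
        with_end=True,
        with_padding=True,
        with_unknown=True)

    # Previously this raised ValueError; now it only warns and falls
    # back to ignore_index = -1.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        ce_loss = CELoss(dict_cfg, ignore_char='ignore')
    assert ce_loss.ignore_index == -1
    assert any('does not exist in the dictionary' in str(w.message)
               for w in caught)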

mmocr/models/textrecog/losses/mix_loss.py (deleted)

@@ -1,109 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.registry import MODELS


@MODELS.register_module()
class ABILoss(nn.Module):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        enc_weight (float): The weight of encoder loss. Defaults to 1.0.
        dec_weight (float): The weight of decoder loss. Defaults to 1.0.
        fusion_weight (float): The weight of fuser (aligner) loss.
            Defaults to 1.0.
        num_classes (int): Number of unique output language tokens.

    Returns:
        A dictionary whose key/value pairs are the losses of three modules.
    """

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        assert isinstance(enc_weight, float) or isinstance(enc_weight, int)
        assert isinstance(dec_weight, float) or isinstance(dec_weight, int)
        assert isinstance(fusion_weight, float) or \
            isinstance(fusion_weight, int)
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
        flatten_logits = torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])
        return flatten_logits

    def _ce_loss(self, logits, targets):
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()

    def _loss_over_iters(self, outputs, targets):
        """
        Args:
            outputs (list[Tensor]): Each tensor has shape (N, T, C) where N
                is the batch size, T is the sequence length and C is the
                number of classes.
            targets (Tensor): The flattened target sequences.
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_dec`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``padded_targets``, which represents target sequences in
                shape (batch_size, sequence_length).

        Returns:
            A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. Each should either be the loss tensor or ``0``
            if the output of its corresponding module is not given.
        """
        assert 'out_enc' in outputs or \
            'out_dec' in outputs or 'out_fusers' in outputs
        losses = {}

        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat([t for t in targets_dict['targets']])

        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(dec_logits,
                                             flatten_targets) * \
                self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss
        return losses

test_abi_loss.py (new file)

@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.losses import ABILoss


class TestABILoss(TestCase):

    def setUp(self) -> None:
        data_sample1 = TextRecogDataSample()
        data_sample1.gt_text = LabelData(item='hello')
        data_sample2 = TextRecogDataSample()
        data_sample2.gt_text = LabelData(item='123')
        self.gt = [data_sample1, data_sample2]

    def _equal(self, a, b):
        if isinstance(a, (torch.Tensor, np.ndarray)):
            return (a == b).all()
        else:
            return a == b

    def test_forward(self):
        dict_cfg = dict(
            type='Dictionary',
            dict_file='dicts/lower_english_digits.txt',
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=False)
        abi_loss = ABILoss(dict_cfg, max_seq_len=10)
        abi_loss.get_targets(self.gt)
        outputs = dict(
            out_vis=dict(logits=torch.randn(2, 10, 38)),
            out_langs=[
                dict(logits=torch.randn(2, 10, 38)),
                dict(logits=torch.randn(2, 10, 38))
            ],
            out_fusers=[
                dict(logits=torch.randn(2, 10, 38)),
                dict(logits=torch.randn(2, 10, 38))
            ])
        losses = abi_loss(outputs, self.gt)
        self.assertIsInstance(losses, dict)
        self.assertIn('loss_visual', losses)
        self.assertIn('loss_lang', losses)
        self.assertIn('loss_fusion', losses)

        outputs.pop('out_vis')
        abi_loss(outputs, self.gt)
        out_langs = outputs.pop('out_langs')
        abi_loss(outputs, self.gt)
        outputs.pop('out_fusers')
        with self.assertRaises(AssertionError):
            abi_loss(outputs, self.gt)
        outputs['out_langs'] = out_langs
        abi_loss(outputs, self.gt)

test_base_recog_loss.py

@@ -45,6 +45,22 @@ class TestBaseRecogLoss(TestCase):
        # test case mode
        with self.assertRaises(AssertionError):
            base_recog_loss = BaseRecogLoss(dict_cfg, letter_case='no')
+        # test invalid pad_with
+        with self.assertRaises(AssertionError):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='test')
+        # test invalid combination of dictionary and pad_with
+        dict_cfg = dict(
+            type='Dictionary', dict_file=dict_file, with_end=False)
+        for pad_with in ['end', 'padding']:
+            with self.assertRaisesRegex(
+                    ValueError, f'pad_with="{pad_with}", but'
+                    f' dictionary.{pad_with}_idx is None'):
+                base_recog_loss = BaseRecogLoss(dict_cfg, pad_with=pad_with)
+        with self.assertRaisesRegex(
+                ValueError, 'pad_with="auto", but'
+                ' dictionary.end_idx and dictionary.padding_idx are both'
+                ' None'):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='auto')
        # test dictionary is invalid type
        dict_cfg = ['tmp']
        with self.assertRaisesRegex(
@@ -82,8 +98,10 @@ class TestBaseRecogLoss(TestCase):
            padding_idx, padding_idx, padding_idx, padding_idx
        ]))
        self.assertTrue(target_data_samples[0].have_target)
+
        target_data_samples = base_recog_loss.get_targets(target_data_samples)
        data_sample.set_metainfo(dict(have_target=False))
+
        dictionary = Dictionary(
            dict_file=dict_file,
            with_start=False,
@@ -100,23 +118,25 @@ class TestBaseRecogLoss(TestCase):
        assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                           torch.LongTensor([0, 1, 2]))
        data_sample.set_metainfo(dict(have_target=False))

        dict_cfg = dict(
            type='Dictionary',
-            dict_file=dict_file,
+            dict_file='dicts/lower_english_digits.txt',
            with_start=False,
-            with_end=False,
+            with_end=True,
            same_start_end=False,
-            with_padding=False,
+            with_padding=True,
            with_unknown=True)
        base_recog_loss = BaseRecogLoss(
-            dict_cfg, max_seq_len=10, letter_case='lower')
+            dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
        data_sample.gt_text.item = '0123'
        target_data_samples = base_recog_loss.get_targets([data_sample])
        assert self._equal(target_data_samples[0].gt_text.indexes,
                           torch.LongTensor([0, 1, 2, 3]))
        assert self._equal(target_data_samples[0].gt_text.padded_indexes,
-                           torch.LongTensor([0, 1, 2, 3]))
+                           torch.LongTensor([0, 1, 2, 3, 36]))

        target_data_samples = base_recog_loss.get_targets([])
        self.assertListEqual(target_data_samples, [])

        tmp_dir.cleanup()

test_ce_loss.py

@@ -44,12 +44,13 @@ class TestCELoss(TestCase):
        self.assertEqual(ce_loss.ignore_index, 37)
        ce_loss = CELoss(dict_cfg, ignore_char=-1)
        self.assertEqual(ce_loss.ignore_index, -1)
-        with self.assertRaises(ValueError):
+        with self.assertWarns(UserWarning):
            ce_loss = CELoss(dict_cfg, ignore_char='ignore')
        ce_loss = CELoss(dict_cfg, ignore_char='1')
        self.assertEqual(ce_loss.ignore_index, 1)

-    def test_ce_loss(self):
+    def test_forward(self):
        dict_cfg = dict(
            type='Dictionary',
            dict_file='dicts/lower_english_digits.txt',