[ABINet] Refactor ABILoss

pull/1178/head
gaotongxiao 2022-07-13 02:53:59 +00:00
parent ee1212a5cd
commit a844b497db
8 changed files with 237 additions and 126 deletions

mmocr/models/textrecog/losses/__init__.py

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .abi_loss import ABILoss
 from .base_recog_loss import BaseRecogLoss
 from .ce_loss import CELoss
 from .ctc_loss import CTCLoss
-from .mix_loss import ABILoss
 from .seg_loss import SegLoss

 __all__ = ['BaseRecogLoss', 'CELoss', 'CTCLoss', 'SegLoss', 'ABILoss']

mmocr/models/textrecog/losses/abi_loss.py

@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
from mmocr.registry import MODELS
from .base_recog_loss import BaseRecogLoss
from .ce_loss import CELoss


@MODELS.register_module()
class ABILoss(BaseRecogLoss):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated by the decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        weight_vis (float or int): The weight of vision decoder loss. Defaults
            to 1.0.
        weight_lang (float or int): The weight of language decoder loss.
            Defaults to 1.0.
        weight_fusion (float or int): The weight of fuser (aligner) loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 weight_vis: Union[float, int] = 1.0,
                 weight_lang: Union[float, int] = 1.0,
                 weight_fusion: Union[float, int] = 1.0,
                 **kwargs) -> None:
        assert isinstance(weight_vis, (float, int))
        assert isinstance(weight_lang, (float, int))
        assert isinstance(weight_fusion, (float, int))
        super().__init__(
            dictionary=dictionary,
            max_seq_len=max_seq_len,
            letter_case=letter_case)
        self.weight_vis = weight_vis
        self.weight_lang = weight_lang
        self.weight_fusion = weight_fusion
        self._ce_loss = CELoss(
            self.dictionary,
            max_seq_len,
            letter_case,
            reduction='mean',
            ignore_first_char=True)

    def forward(self, outputs: Dict,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. Each key is present only when the output of its
            corresponding module is given.
        """
        assert 'out_vis' in outputs or \
            'out_langs' in outputs or 'out_fusers' in outputs
        losses = {}

        if outputs.get('out_vis', None):
            losses['loss_visual'] = self.weight_vis * self._ce_loss(
                outputs['out_vis']['logits'], data_samples)['loss_ce']
        if outputs.get('out_langs', None):
            lang_losses = []
            for out_lang in outputs['out_langs']:
                lang_losses.append(
                    self._ce_loss(out_lang['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_lang'] = self.weight_lang * torch.mean(
                torch.stack(lang_losses))
        if outputs.get('out_fusers', None):
            fuser_losses = []
            for out_fuser in outputs['out_fusers']:
                fuser_losses.append(
                    self._ce_loss(out_fuser['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_fusion'] = self.weight_fusion * torch.mean(
                torch.stack(fuser_losses))
        return losses
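
A quick usage sketch of the refactored interface, mirroring the unit test
added below (the dictionary file and the 38-class logit shape are borrowed
from the test fixtures; the 0.5 language weight is illustrative):

import torch
from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.losses import ABILoss

sample = TextRecogDataSample()
sample.gt_text = LabelData(item='hello')

dict_cfg = dict(
    type='Dictionary',
    dict_file='dicts/lower_english_digits.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_padding=True,
    with_unknown=False)

# Targets must be attached before computing the loss.
loss = ABILoss(dict_cfg, max_seq_len=10, weight_vis=1.0, weight_lang=0.5)
data_samples = loss.get_targets([sample])

# out_vis holds a single logits dict; out_langs/out_fusers hold one dict
# per refinement iteration. 38 = 36 characters + <BOS/EOS> + <PAD>.
outputs = dict(
    out_vis=dict(logits=torch.randn(1, 10, 38)),
    out_langs=[dict(logits=torch.randn(1, 10, 38))])
losses = loss(outputs, data_samples)
# -> keys 'loss_visual' and 'loss_lang'; 'loss_fusion' is absent because
#    no 'out_fusers' output was given.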

mmocr/models/textrecog/losses/base_recog_loss.py

@@ -25,12 +25,21 @@ class BaseRecogLoss(nn.Module):
             - lower: Convert gt texts into lowercase characters.
             Usually, it only works for English characters. Defaults to
             'unchanged'.
+        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
+            Defaults to 'auto'. Options are:
+            - 'auto': Use dictionary.padding_idx to pad gt texts, or
+              dictionary.end_idx if dictionary.padding_idx
+              is None.
+            - 'padding': Always use dictionary.padding_idx to pad gt texts.
+            - 'end': Always use dictionary.end_idx to pad gt texts.
+            - 'none': Do not pad gt texts.
     """

     def __init__(self,
                  dictionary: Union[Dict, Dictionary],
                  max_seq_len: int = 40,
                  letter_case: str = 'unchanged',
+                 pad_with: str = 'auto',
                  **kwargs) -> None:
         super().__init__()
         if isinstance(dictionary, dict):
@@ -45,6 +54,25 @@ class BaseRecogLoss(nn.Module):
         assert letter_case in ['unchanged', 'upper', 'lower']
         self.letter_case = letter_case

+        assert pad_with in ['auto', 'padding', 'end', 'none']
+        if pad_with == 'auto':
+            self.pad_idx = self.dictionary.padding_idx or \
+                self.dictionary.end_idx
+        elif pad_with == 'padding':
+            self.pad_idx = self.dictionary.padding_idx
+        elif pad_with == 'end':
+            self.pad_idx = self.dictionary.end_idx
+        else:
+            self.pad_idx = None
+        if self.pad_idx is None and pad_with != 'none':
+            if pad_with == 'auto':
+                raise ValueError('pad_with="auto", but dictionary.end_idx'
+                                 ' and dictionary.padding_idx are both None')
+            else:
+                raise ValueError(
+                    f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is'
+                    ' None')
+
     def get_targets(
         self, data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
@@ -59,10 +87,12 @@ class BaseRecogLoss(nn.Module):
             added to data_sample:

             - indexes (torch.LongTensor): Character indexes representing gt
-              texts.
-            - padded_indexes (torch.LongTensor) Character indexes
-              representing gt texts, following several padding_idxs until
-              reaching the length of ``max_seq_len``.
+              texts. All special tokens are excluded, except for UKN.
+            - padded_indexes (torch.LongTensor): Character indexes
+              representing gt texts with BOS and EOS if applicable, following
+              several padding indexes until the length reaches
+              ``max_seq_len``. In particular, if ``pad_with='none'``, no
+              padding will be applied.
         """

         for data_sample in data_samples:
@@ -88,14 +118,13 @@ class BaseRecogLoss(nn.Module):
             else:
                 slice_end = src_target.size(0) - 1
             src_target = src_target[slice_start:slice_end]
-            if self.dictionary.padding_idx is not None:
+            if self.pad_idx is not None:
                 padded_indexes = (torch.ones(self.max_seq_len) *
-                                  self.dictionary.padding_idx).long()
+                                  self.pad_idx).long()
                 char_num = min(src_target.size(0), self.max_seq_len)
                 padded_indexes[:char_num] = src_target[:char_num]
             else:
                 padded_indexes = src_target
             # put in DataSample
             data_sample.gt_text.indexes = indexes
             data_sample.gt_text.padded_indexes = padded_indexes
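
To make the new ``pad_with`` semantics concrete, a small sketch (the
dictionary config follows the tests in this PR; the exact index values
assume '0'-'9' occupy the first slots of lower_english_digits.txt, with
<EOS> at 36 and <PAD> at 37 under this config):

from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.losses import BaseRecogLoss

sample = TextRecogDataSample()
sample.gt_text = LabelData(item='0123')

dict_cfg = dict(
    type='Dictionary',
    dict_file='dicts/lower_english_digits.txt',
    with_start=False,
    with_end=True,
    same_start_end=False,
    with_padding=True,
    with_unknown=True)

# 'auto' resolves to padding_idx here: [0, 1, 2, 3, 36, 37, 37, 37, 37, 37]
loss = BaseRecogLoss(dict_cfg, max_seq_len=10, pad_with='auto')
print(loss.get_targets([sample])[0].gt_text.padded_indexes)

# 'none' keeps the unpadded sequence: [0, 1, 2, 3, 36]
loss = BaseRecogLoss(dict_cfg, max_seq_len=10, pad_with='none')
print(loss.get_targets([sample])[0].gt_text.padded_indexes)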

mmocr/models/textrecog/losses/ce_loss.py

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Dict, Sequence, Union

 import torch
@@ -36,6 +37,8 @@ class CELoss(BaseRecogLoss):
             and 'unknown', which refer to their corresponding special
             tokens in the dictionary. It will not ignore any special
             tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
+        flatten (bool): Whether to flatten the output and target before
+            computing CE loss. Defaults to False.
         reduction (str): Specifies the reduction to apply to the output,
             should be one of the following: ('none', 'mean', 'sum'). Defaults
             to 'none'.
@@ -82,8 +85,11 @@ class CELoss(BaseRecogLoss):
         ignore_index = mapping_table.get(
             ignore_char, self.dictionary._char2idx.get(ignore_char, None))
         if ignore_index is None:
-            raise ValueError(
-                f'{ignore_char} does not exist in the dictionary')
+            warnings.warn(
+                f'{ignore_char} does not exist in the dictionary',
+                UserWarning)
+            ignore_index = -1
         self.ignore_char = ignore_char
         self.ignore_index = ignore_index
         self.loss_ce = nn.CrossEntropyLoss(
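
The behavioral change here is worth a sketch: an ``ignore_char`` that cannot
be resolved now degrades to a warning with ``ignore_index = -1`` (ignore
nothing) instead of raising. A minimal check, assuming the same dictionary
fixture as the tests:

import warnings

from mmocr.models.textrecog.losses import CELoss

dict_cfg = dict(
    type='Dictionary',
    dict_file='dicts/lower_english_digits.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_padding=True,
    with_unknown=False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # 'ignore' is neither a special-token alias nor a dictionary character.
    ce_loss = CELoss(dict_cfg, ignore_char='ignore')
assert ce_loss.ignore_index == -1
assert 'does not exist in the dictionary' in str(caught[-1].message)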

mmocr/models/textrecog/losses/mix_loss.py

@@ -1,109 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.registry import MODELS


@MODELS.register_module()
class ABILoss(nn.Module):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        enc_weight (float): The weight of encoder loss. Defaults to 1.0.
        dec_weight (float): The weight of decoder loss. Defaults to 1.0.
        fusion_weight (float): The weight of fuser (aligner) loss.
            Defaults to 1.0.
        num_classes (int): Number of unique output language tokens.

    Returns:
        A dictionary whose key/value pairs are the losses of three modules.
    """

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        assert isinstance(enc_weight, float) or isinstance(enc_weight, int)
        assert isinstance(dec_weight, float) or isinstance(dec_weight, int)
        assert isinstance(fusion_weight, float) or \
            isinstance(fusion_weight, int)
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
        flatten_logits = torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])
        return flatten_logits

    def _ce_loss(self, logits, targets):
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()

    def _loss_over_iters(self, outputs, targets):
        """
        Args:
            outputs (list[Tensor]): Each tensor has shape (N, T, C) where N is
                the batch size, T is the sequence length and C is the number
                of classes.
            targets_dicts (dict): The dictionary with at least `padded_targets`
                defined.
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_dec`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``padded_targets``, which represents target sequences in
                shape (batch_size, sequence_length).

        Returns:
            A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. Each should either be the loss tensor or ``0`` if
            the output of its corresponding module is not given.
        """
        assert 'out_enc' in outputs or \
            'out_dec' in outputs or 'out_fusers' in outputs
        losses = {}

        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat([t for t in targets_dict['targets']])

        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(dec_logits,
                                             flatten_targets) * self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss

        return losses

test_abi_loss.py

@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.losses import ABILoss


class TestABILoss(TestCase):

    def setUp(self) -> None:
        data_sample1 = TextRecogDataSample()
        data_sample1.gt_text = LabelData(item='hello')
        data_sample2 = TextRecogDataSample()
        data_sample2.gt_text = LabelData(item='123')
        self.gt = [data_sample1, data_sample2]

    def _equal(self, a, b):
        if isinstance(a, (torch.Tensor, np.ndarray)):
            return (a == b).all()
        else:
            return a == b

    def test_forward(self):
        dict_cfg = dict(
            type='Dictionary',
            dict_file='dicts/lower_english_digits.txt',
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=False)
        abi_loss = ABILoss(dict_cfg, max_seq_len=10)
        abi_loss.get_targets(self.gt)
        outputs = dict(
            out_vis=dict(logits=torch.randn(2, 10, 38)),
            out_langs=[
                dict(logits=torch.randn(2, 10, 38)),
                dict(logits=torch.randn(2, 10, 38))
            ],
            out_fusers=[
                dict(logits=torch.randn(2, 10, 38)),
                dict(logits=torch.randn(2, 10, 38))
            ])
        losses = abi_loss(outputs, self.gt)
        self.assertIsInstance(losses, dict)
        self.assertIn('loss_visual', losses)
        self.assertIn('loss_lang', losses)
        self.assertIn('loss_fusion', losses)
        print(losses['loss_lang'])
        print(losses['loss_fusion'])

        outputs.pop('out_vis')
        abi_loss(outputs, self.gt)
        out_langs = outputs.pop('out_langs')
        abi_loss(outputs, self.gt)
        outputs.pop('out_fusers')
        with self.assertRaises(AssertionError):
            abi_loss(outputs, self.gt)
        outputs['out_langs'] = out_langs
        abi_loss(outputs, self.gt)

test_base_recog_loss.py

@@ -45,6 +45,22 @@ class TestBaseRecogLoss(TestCase):
         # test case mode
         with self.assertRaises(AssertionError):
             base_recog_loss = BaseRecogLoss(dict_cfg, letter_case='no')
+        # test invalid pad_with
+        with self.assertRaises(AssertionError):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='test')
+        # test invalid combination of dictionary and pad_with
+        dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
+        for pad_with in ['end', 'padding']:
+            with self.assertRaisesRegex(
+                    ValueError, f'pad_with="{pad_with}", but'
+                    f' dictionary.{pad_with}_idx is None'):
+                base_recog_loss = BaseRecogLoss(dict_cfg, pad_with=pad_with)
+        with self.assertRaisesRegex(
+                ValueError, 'pad_with="auto", but'
+                ' dictionary.end_idx and dictionary.padding_idx are both'
+                ' None'):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='auto')
+
         # test dictionary is invalid type
         dict_cfg = ['tmp']
         with self.assertRaisesRegex(
@@ -82,8 +98,10 @@ class TestBaseRecogLoss(TestCase):
                 padding_idx, padding_idx, padding_idx, padding_idx
             ]))
         self.assertTrue(target_data_samples[0].have_target)
+        target_data_samples = base_recog_loss.get_targets(target_data_samples)

         data_sample.set_metainfo(dict(have_target=False))
+
         dictionary = Dictionary(
             dict_file=dict_file,
             with_start=False,
@@ -100,23 +118,25 @@ class TestBaseRecogLoss(TestCase):
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                            torch.LongTensor([0, 1, 2]))

         data_sample.set_metainfo(dict(have_target=False))
         dict_cfg = dict(
             type='Dictionary',
-            dict_file=dict_file,
+            dict_file='dicts/lower_english_digits.txt',
             with_start=False,
-            with_end=False,
+            with_end=True,
             same_start_end=False,
-            with_padding=False,
+            with_padding=True,
             with_unknown=True)
         base_recog_loss = BaseRecogLoss(
-            dict_cfg, max_seq_len=10, letter_case='lower')
+            dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
         data_sample.gt_text.item = '0123'
         target_data_samples = base_recog_loss.get_targets([data_sample])
         assert self._equal(target_data_samples[0].gt_text.indexes,
                            torch.LongTensor([0, 1, 2, 3]))
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
-                           torch.LongTensor([0, 1, 2, 3]))
+                           torch.LongTensor([0, 1, 2, 3, 36]))

         target_data_samples = base_recog_loss.get_targets([])
         self.assertListEqual(target_data_samples, [])

         tmp_dir.cleanup()

test_ce_loss.py

@@ -44,12 +44,13 @@
         self.assertEqual(ce_loss.ignore_index, 37)
         ce_loss = CELoss(dict_cfg, ignore_char=-1)
         self.assertEqual(ce_loss.ignore_index, -1)
-        with self.assertRaises(ValueError):
+        # with self.assertRaises(ValueError):
+        with self.assertWarns(UserWarning):
             ce_loss = CELoss(dict_cfg, ignore_char='ignore')
         ce_loss = CELoss(dict_cfg, ignore_char='1')
         self.assertEqual(ce_loss.ignore_index, 1)

-    def test_ce_loss(self):
+    def test_forward(self):
         dict_cfg = dict(
             type='Dictionary',
             dict_file='dicts/lower_english_digits.txt',