mirror of https://github.com/open-mmlab/mmocr.git
[ABINet] Refactor ABILoss
parent ee1212a5cd
commit a844b497db
mmocr/models/textrecog/losses/__init__.py

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .abi_loss import ABILoss
 from .base_recog_loss import BaseRecogLoss
 from .ce_loss import CELoss
 from .ctc_loss import CTCLoss
-from .mix_loss import ABILoss
 from .seg_loss import SegLoss

 __all__ = ['BaseRecogLoss', 'CELoss', 'CTCLoss', 'SegLoss', 'ABILoss']
mmocr/models/textrecog/losses/abi_loss.py (new)

@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence, Union
+
+import torch
+
+from mmocr.data import TextRecogDataSample
+from mmocr.models.textrecog.dictionary.dictionary import Dictionary
+from mmocr.registry import MODELS
+from .base_recog_loss import BaseRecogLoss
+from .ce_loss import CELoss
+
+
+@MODELS.register_module()
+class ABILoss(BaseRecogLoss):
+    """Implementation of ABINet multiloss that allows mixing different types
+    of losses with weights.
+
+    Args:
+        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+            the instance of `Dictionary`.
+        max_seq_len (int): Maximum sequence length. The sequence is usually
+            generated from decoder. Defaults to 40.
+        letter_case (str): There are three options to alter the letter cases
+            of gt texts:
+            - unchanged: Do not change gt texts.
+            - upper: Convert gt texts into uppercase characters.
+            - lower: Convert gt texts into lowercase characters.
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
+        weight_vis (float or int): The weight of vision decoder loss. Defaults
+            to 1.0.
+        weight_lang (float or int): The weight of language decoder loss.
+            Defaults to 1.0.
+        weight_fusion (float or int): The weight of fuser (aligner) loss.
+            Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 dictionary: Union[Dict, Dictionary],
+                 max_seq_len: int = 40,
+                 letter_case: str = 'unchanged',
+                 weight_vis: Union[float, int] = 1.0,
+                 weight_lang: Union[float, int] = 1.0,
+                 weight_fusion: Union[float, int] = 1.0,
+                 **kwargs) -> None:
+        assert isinstance(weight_vis, (float, int))
+        assert isinstance(weight_lang, (float, int))
+        assert isinstance(weight_fusion, (float, int))
+        super().__init__(
+            dictionary=dictionary,
+            max_seq_len=max_seq_len,
+            letter_case=letter_case)
+        self.weight_vis = weight_vis
+        self.weight_lang = weight_lang
+        self.weight_fusion = weight_fusion
+        self._ce_loss = CELoss(
+            self.dictionary,
+            max_seq_len,
+            letter_case,
+            reduction='mean',
+            ignore_first_char=True)
+
+    def forward(self, outputs: Dict,
+                data_samples: Sequence[TextRecogDataSample]) -> Dict:
+        """
+        Args:
+            outputs (dict): The output dictionary with at least one of
+                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
+            data_samples (list[TextRecogDataSample]): List of
+                ``TextRecogDataSample`` which are processed by ``get_targets``.
+
+        Returns:
+            dict: A loss dictionary with ``loss_visual``, ``loss_lang`` and
+            ``loss_fusion``. Each should either be the loss tensor or None if
+            the output of its corresponding module is not given.
+        """
+        assert 'out_vis' in outputs or \
+            'out_langs' in outputs or 'out_fusers' in outputs
+        losses = {}
+
+        if outputs.get('out_vis', None):
+            losses['loss_visual'] = self.weight_vis * self._ce_loss(
+                outputs['out_vis']['logits'], data_samples)['loss_ce']
+        if outputs.get('out_langs', None):
+            lang_losses = []
+            for out_lang in outputs['out_langs']:
+                lang_losses.append(
+                    self._ce_loss(out_lang['logits'], data_samples)['loss_ce'])
+            losses['loss_lang'] = self.weight_lang * torch.mean(
+                torch.stack(lang_losses))
+        if outputs.get('out_fusers', None):
+            fuser_losses = []
+            for out_fuser in outputs['out_fusers']:
+                fuser_losses.append(
+                    self._ce_loss(out_fuser['logits'],
+                                  data_samples)['loss_ce'])
+            losses['loss_fusion'] = self.weight_fusion * torch.mean(
+                torch.stack(fuser_losses))
+        return losses
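A minimal usage sketch of the refactored ABILoss, assuming the API introduced in this diff; the dictionary config and the (1, 10, 38) logit shapes mirror the new unit test further below and are illustrative only:

    import torch
    from mmengine.data import LabelData

    from mmocr.data import TextRecogDataSample
    from mmocr.models.textrecog.losses import ABILoss

    # 36-char lower-English-digits dict + shared start/end + padding = 38 classes
    dict_cfg = dict(
        type='Dictionary',
        dict_file='dicts/lower_english_digits.txt',
        with_start=True,
        with_end=True,
        same_start_end=True,
        with_padding=True,
        with_unknown=False)
    abi_loss = ABILoss(dict_cfg, max_seq_len=10)

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='hello')
    data_samples = abi_loss.get_targets([sample])  # fills gt_text.(padded_)indexes

    outputs = dict(
        out_vis=dict(logits=torch.randn(1, 10, 38)),       # vision branch
        out_langs=[dict(logits=torch.randn(1, 10, 38))],   # language decoder iterations
        out_fusers=[dict(logits=torch.randn(1, 10, 38))])  # fuser iterations
    losses = abi_loss(outputs, data_samples)
    # {'loss_visual': tensor(...), 'loss_lang': tensor(...), 'loss_fusion': tensor(...)}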
mmocr/models/textrecog/losses/base_recog_loss.py

@@ -25,12 +25,21 @@ class BaseRecogLoss(nn.Module):
         - lower: Convert gt texts into lowercase characters.
           Usually, it only works for English characters. Defaults to
           'unchanged'.
+        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
+            Defaults to 'auto'. Options are:
+            - 'auto': Use dictionary.padding_idx to pad gt texts, or
+              dictionary.end_idx if dictionary.padding_idx
+              is None.
+            - 'padding': Always use dictionary.padding_idx to pad gt texts.
+            - 'end': Always use dictionary.end_idx to pad gt texts.
+            - 'none': Do not pad gt texts.
     """

     def __init__(self,
                  dictionary: Union[Dict, Dictionary],
                  max_seq_len: int = 40,
                  letter_case: str = 'unchanged',
+                 pad_with: str = 'auto',
                  **kwargs) -> None:
         super().__init__()
         if isinstance(dictionary, dict):
@@ -45,6 +54,25 @@ class BaseRecogLoss(nn.Module):
         assert letter_case in ['unchanged', 'upper', 'lower']
         self.letter_case = letter_case

+        assert pad_with in ['auto', 'padding', 'end', 'none']
+        if pad_with == 'auto':
+            self.pad_idx = self.dictionary.padding_idx or \
+                self.dictionary.end_idx
+        elif pad_with == 'padding':
+            self.pad_idx = self.dictionary.padding_idx
+        elif pad_with == 'end':
+            self.pad_idx = self.dictionary.end_idx
+        else:
+            self.pad_idx = None
+        if self.pad_idx is None and pad_with != 'none':
+            if pad_with == 'auto':
+                raise ValueError('pad_with="auto", but dictionary.end_idx'
+                                 ' and dictionary.padding_idx are both None')
+            else:
+                raise ValueError(
+                    f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is'
+                    ' None')
+
     def get_targets(
         self, data_samples: Sequence[TextRecogDataSample]
     ) -> Sequence[TextRecogDataSample]:
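The pad_with resolution above reduces to a few cases. A small sketch, assuming a hypothetical dict_cfg built with with_end=True but with_padding=False, so padding_idx is None and end_idx exists:

    loss = BaseRecogLoss(dict_cfg, pad_with='auto')     # pad_idx falls back to end_idx
    loss = BaseRecogLoss(dict_cfg, pad_with='end')      # pad_idx = end_idx
    loss = BaseRecogLoss(dict_cfg, pad_with='none')     # pad_idx = None, no padding applied
    loss = BaseRecogLoss(dict_cfg, pad_with='padding')  # raises ValueError: padding_idx is None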
@@ -59,10 +87,12 @@ class BaseRecogLoss(nn.Module):
         added to data_sample:

         - indexes (torch.LongTensor): Character indexes representing gt
-          texts.
-        - padded_indexes (torch.LongTensor) Character indexes
-          representing gt texts, following several padding_idxs until
-          reaching the length of ``max_seq_len``.
+          texts. All special tokens are excluded, except for UKN.
+        - padded_indexes (torch.LongTensor): Character indexes
+          representing gt texts with BOS and EOS if applicable, following
+          several padding indexes until the length reaches ``max_seq_len``.
+          In particular, if ``pad_with='none'``, no padding will be
+          applied.
         """

         for data_sample in data_samples:
@@ -88,14 +118,13 @@ class BaseRecogLoss(nn.Module):
             else:
                 slice_end = src_target.size(0) - 1
             src_target = src_target[slice_start:slice_end]
-            if self.dictionary.padding_idx is not None:
-                padded_indexes = (torch.ones(self.max_seq_len) *
-                                  self.dictionary.padding_idx).long()
+            if self.pad_idx is not None:
+                padded_indexes = (torch.ones(self.max_seq_len) *
+                                  self.pad_idx).long()
                 char_num = min(src_target.size(0), self.max_seq_len)
                 padded_indexes[:char_num] = src_target[:char_num]
             else:
                 padded_indexes = src_target

             # put in DataSample
             data_sample.gt_text.indexes = indexes
             data_sample.gt_text.padded_indexes = padded_indexes
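A worked example of the updated target generation, drawn from the modified test below (lower-English-digits dictionary with with_start=False, with_end=True, so end_idx == 36):

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='0123')
    loss = BaseRecogLoss(dict_cfg, max_seq_len=10, letter_case='lower',
                         pad_with='none')
    sample = loss.get_targets([sample])[0]
    sample.gt_text.indexes         # tensor([0, 1, 2, 3])
    sample.gt_text.padded_indexes  # tensor([0, 1, 2, 3, 36]); EOS appended, no padding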
mmocr/models/textrecog/losses/ce_loss.py

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from typing import Dict, Sequence, Union

 import torch

@@ -36,6 +37,8 @@ class CELoss(BaseRecogLoss):
             and 'unknown', which refer to their corresponding special
             tokens in the dictionary. It will not ignore any special
             tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
+        flatten (bool): Whether to flatten the output and target before
+            computing CE loss. Defaults to False.
         reduction (str): Specifies the reduction to apply to the output,
             should be one of the following: ('none', 'mean', 'sum'). Defaults
             to 'none'.

@@ -82,8 +85,11 @@ class CELoss(BaseRecogLoss):
         ignore_index = mapping_table.get(
             ignore_char, self.dictionary._char2idx.get(ignore_char, None))
         if ignore_index is None:
-            raise ValueError(
-                f'{ignore_char} does not exist in the dictionary')
+            warnings.warn(
+                f'{ignore_char} does not exist in the dictionary',
+                UserWarning)
+            ignore_index = -1

         self.ignore_char = ignore_char
         self.ignore_index = ignore_index
         self.loss_ce = nn.CrossEntropyLoss(
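The behavioral change in CELoss above can be sketched as follows: an ignore_char found in neither the mapping table nor the dictionary now warns and falls back to ignore_index = -1 rather than raising (dict_cfg stands in for the test's dictionary config):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        ce_loss = CELoss(dict_cfg, ignore_char='ignore')  # 'ignore' is not a dict entry
    assert ce_loss.ignore_index == -1
    assert any('does not exist in the dictionary' in str(w.message) for w in caught)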
mmocr/models/textrecog/losses/mix_loss.py (deleted)

@@ -1,109 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from mmocr.registry import MODELS
-
-
-@MODELS.register_module()
-class ABILoss(nn.Module):
-    """Implementation of ABINet multiloss that allows mixing different types of
-    losses with weights.
-
-    Args:
-        enc_weight (float): The weight of encoder loss. Defaults to 1.0.
-        dec_weight (float): The weight of decoder loss. Defaults to 1.0.
-        fusion_weight (float): The weight of fuser (aligner) loss.
-            Defaults to 1.0.
-        num_classes (int): Number of unique output language tokens.
-
-    Returns:
-        A dictionary whose key/value pairs are the losses of three modules.
-    """
-
-    def __init__(self,
-                 enc_weight=1.0,
-                 dec_weight=1.0,
-                 fusion_weight=1.0,
-                 num_classes=37,
-                 **kwargs):
-        assert isinstance(enc_weight, float) or isinstance(enc_weight, int)
-        assert isinstance(dec_weight, float) or isinstance(dec_weight, int)
-        assert isinstance(fusion_weight, float) or \
-            isinstance(fusion_weight, int)
-        super().__init__()
-        self.enc_weight = enc_weight
-        self.dec_weight = dec_weight
-        self.fusion_weight = fusion_weight
-        self.num_classes = num_classes
-
-    def _flatten(self, logits, target_lens):
-        flatten_logits = torch.cat(
-            [s[:target_lens[i]] for i, s in enumerate(logits)])
-        return flatten_logits
-
-    def _ce_loss(self, logits, targets):
-        targets_one_hot = F.one_hot(targets, self.num_classes)
-        log_prob = F.log_softmax(logits, dim=-1)
-        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
-        return loss.mean()
-
-    def _loss_over_iters(self, outputs, targets):
-        """
-        Args:
-            outputs (list[Tensor]): Each tensor has shape (N, T, C) where N is
-                the batch size, T is the sequence length and C is the number of
-                classes.
-            targets_dicts (dict): The dictionary with at least `padded_targets`
-                defined.
-        """
-        iter_num = len(outputs)
-        dec_outputs = torch.cat(outputs, dim=0)
-        flatten_targets_iternum = targets.repeat(iter_num)
-        return self._ce_loss(dec_outputs, flatten_targets_iternum)
-
-    def forward(self, outputs, targets_dict, img_metas=None):
-        """
-        Args:
-            outputs (dict): The output dictionary with at least one of
-                ``out_enc``, ``out_dec`` and ``out_fusers`` specified.
-            targets_dict (dict): The target dictionary containing the key
-                ``padded_targets``, which represents target sequences in
-                shape (batch_size, sequence_length).
-
-        Returns:
-            A loss dictionary with ``loss_visual``, ``loss_lang`` and
-            ``loss_fusion``. Each should either be the loss tensor or ``0`` if
-            the output of its corresponding module is not given.
-        """
-        assert 'out_enc' in outputs or \
-            'out_dec' in outputs or 'out_fusers' in outputs
-        losses = {}
-
-        target_lens = [len(t) for t in targets_dict['targets']]
-        flatten_targets = torch.cat([t for t in targets_dict['targets']])
-
-        if outputs.get('out_enc', None):
-            enc_input = self._flatten(outputs['out_enc']['logits'],
-                                      target_lens)
-            enc_loss = self._ce_loss(enc_input,
-                                     flatten_targets) * self.enc_weight
-            losses['loss_visual'] = enc_loss
-        if outputs.get('out_decs', None):
-            dec_logits = [
-                self._flatten(o['logits'], target_lens)
-                for o in outputs['out_decs']
-            ]
-            dec_loss = self._loss_over_iters(dec_logits,
-                                             flatten_targets) * self.dec_weight
-            losses['loss_lang'] = dec_loss
-        if outputs.get('out_fusers', None):
-            fusion_logits = [
-                self._flatten(o['logits'], target_lens)
-                for o in outputs['out_fusers']
-            ]
-            fusion_loss = self._loss_over_iters(
-                fusion_logits, flatten_targets) * self.fusion_weight
-            losses['loss_fusion'] = fusion_loss
-        return losses
test_abi_loss.py (new)

@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
+import torch
+from mmengine.data import LabelData
+
+from mmocr.data import TextRecogDataSample
+from mmocr.models.textrecog.losses import ABILoss
+
+
+class TestABILoss(TestCase):
+
+    def setUp(self) -> None:
+
+        data_sample1 = TextRecogDataSample()
+        data_sample1.gt_text = LabelData(item='hello')
+        data_sample2 = TextRecogDataSample()
+        data_sample2.gt_text = LabelData(item='123')
+        self.gt = [data_sample1, data_sample2]
+
+    def _equal(self, a, b):
+        if isinstance(a, (torch.Tensor, np.ndarray)):
+            return (a == b).all()
+        else:
+            return a == b
+
+    def test_forward(self):
+        dict_cfg = dict(
+            type='Dictionary',
+            dict_file='dicts/lower_english_digits.txt',
+            with_start=True,
+            with_end=True,
+            same_start_end=True,
+            with_padding=True,
+            with_unknown=False)
+        abi_loss = ABILoss(dict_cfg, max_seq_len=10)
+        abi_loss.get_targets(self.gt)
+        outputs = dict(
+            out_vis=dict(logits=torch.randn(2, 10, 38)),
+            out_langs=[
+                dict(logits=torch.randn(2, 10, 38)),
+                dict(logits=torch.randn(2, 10, 38))
+            ],
+            out_fusers=[
+                dict(logits=torch.randn(2, 10, 38)),
+                dict(logits=torch.randn(2, 10, 38))
+            ])
+        losses = abi_loss(outputs, self.gt)
+        self.assertIsInstance(losses, dict)
+        self.assertIn('loss_visual', losses)
+        self.assertIn('loss_lang', losses)
+        self.assertIn('loss_fusion', losses)
+        print(losses['loss_lang'])
+        print(losses['loss_fusion'])
+
+        outputs.pop('out_vis')
+        abi_loss(outputs, self.gt)
+        out_langs = outputs.pop('out_langs')
+        abi_loss(outputs, self.gt)
+        outputs.pop('out_fusers')
+        with self.assertRaises(AssertionError):
+            abi_loss(outputs, self.gt)
+        outputs['out_langs'] = out_langs
+        abi_loss(outputs, self.gt)
test_base_recog_loss.py

@@ -45,6 +45,22 @@ class TestBaseRecogLoss(TestCase):
         # test case mode
         with self.assertRaises(AssertionError):
             base_recog_loss = BaseRecogLoss(dict_cfg, letter_case='no')
+        # test invalid pad_with
+        with self.assertRaises(AssertionError):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='test')
+        # test invalid combination of dictionary and pad_with
+        dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
+        for pad_with in ['end', 'padding']:
+            with self.assertRaisesRegex(
+                    ValueError, f'pad_with="{pad_with}", but'
+                    f' dictionary.{pad_with}_idx is None'):
+                base_recog_loss = BaseRecogLoss(dict_cfg, pad_with=pad_with)
+        with self.assertRaisesRegex(
+                ValueError, 'pad_with="auto", but'
+                ' dictionary.end_idx and dictionary.padding_idx are both'
+                ' None'):
+            base_recog_loss = BaseRecogLoss(dict_cfg, pad_with='auto')
+
         # test dictionary is invalid type
         dict_cfg = ['tmp']
         with self.assertRaisesRegex(

@@ -82,8 +98,10 @@ class TestBaseRecogLoss(TestCase):
                 padding_idx, padding_idx, padding_idx, padding_idx
             ]))
         self.assertTrue(target_data_samples[0].have_target)

+        target_data_samples = base_recog_loss.get_targets(target_data_samples)
+        data_sample.set_metainfo(dict(have_target=False))

         dictionary = Dictionary(
             dict_file=dict_file,
             with_start=False,

@@ -100,23 +118,25 @@ class TestBaseRecogLoss(TestCase):
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                            torch.LongTensor([0, 1, 2]))
         data_sample.set_metainfo(dict(have_target=False))

         dict_cfg = dict(
             type='Dictionary',
-            dict_file=dict_file,
+            dict_file='dicts/lower_english_digits.txt',
             with_start=False,
-            with_end=False,
+            with_end=True,
             same_start_end=False,
-            with_padding=False,
+            with_padding=True,
             with_unknown=True)
         base_recog_loss = BaseRecogLoss(
-            dict_cfg, max_seq_len=10, letter_case='lower')
+            dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
         data_sample.gt_text.item = '0123'
         target_data_samples = base_recog_loss.get_targets([data_sample])
         assert self._equal(target_data_samples[0].gt_text.indexes,
                            torch.LongTensor([0, 1, 2, 3]))
         assert self._equal(target_data_samples[0].gt_text.padded_indexes,
-                           torch.LongTensor([0, 1, 2, 3]))
+                           torch.LongTensor([0, 1, 2, 3, 36]))
+
+        target_data_samples = base_recog_loss.get_targets([])
+        self.assertListEqual(target_data_samples, [])

         tmp_dir.cleanup()
test_ce_loss.py

@@ -44,12 +44,13 @@ class TestCELoss(TestCase):
         self.assertEqual(ce_loss.ignore_index, 37)
         ce_loss = CELoss(dict_cfg, ignore_char=-1)
         self.assertEqual(ce_loss.ignore_index, -1)
-        with self.assertRaises(ValueError):
+        # with self.assertRaises(ValueError):
+        with self.assertWarns(UserWarning):
             ce_loss = CELoss(dict_cfg, ignore_char='ignore')
+        ce_loss = CELoss(dict_cfg, ignore_char='1')
+        self.assertEqual(ce_loss.ignore_index, 1)

-    def test_ce_loss(self):
+    def test_forward(self):
         dict_cfg = dict(
             type='Dictionary',
             dict_file='dicts/lower_english_digits.txt',