mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] CTCLoss
parent 3aae157aec
commit 7582fdea41
@@ -20,7 +20,6 @@ class BaseRecogLoss(nn.Module):
             generated from decoder. Defaults to 40.
         letter_case (str): There are three options to alter the letter cases
             of gt texts:
-
             - unchanged: Do not change gt texts.
             - upper: Convert gt texts into uppercase characters.
             - lower: Convert gt texts into lowercase characters.
@@ -61,7 +60,7 @@ class BaseRecogLoss(nn.Module):
 
             - indexes (torch.LongTensor): Character indexes representing gt
               texts.
-            - padding_indexes (torch.LongTensor) Character indexes
+            - padded_indexes (torch.LongTensor): Character indexes
               representing gt texts, following several padding_idxs until
               reaching the length of ``max_seq_len``.
         """
@@ -88,14 +87,14 @@ class BaseRecogLoss(nn.Module):
                 slice_end = src_target.size(0) - 1
             src_target = src_target[slice_start:slice_end]
             if self.dictionary.padding_idx is not None:
-                padding_indexes = (torch.ones(self.max_seq_len) *
-                                   self.dictionary.padding_idx).long()
+                padded_indexes = (torch.ones(self.max_seq_len) *
+                                  self.dictionary.padding_idx).long()
                 char_num = min(src_target.size(0), self.max_seq_len)
-                padding_indexes[:char_num] = src_target[:char_num]
+                padded_indexes[:char_num] = src_target[:char_num]
             else:
-                padding_indexes = src_target
+                padded_indexes = src_target
 
             # put in DataSample
             data_sample.gt_text.indexes = indexes
-            data_sample.gt_text.padding_indexes = padding_indexes
+            data_sample.gt_text.padded_indexes = padded_indexes
         return data_samples
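The renamed ``padded_indexes`` above is simply the ground-truth index tensor right-padded with the dictionary's padding index up to ``max_seq_len``. A minimal self-contained sketch of that step (the ``padding_idx`` and ``max_seq_len`` values here are illustrative assumptions, not taken from the diff):

import torch

# Illustrative stand-ins for the Dictionary attributes used in the diff;
# the real values come from the loaded dict file.
padding_idx = 36
max_seq_len = 10

# Character indexes for one ground-truth text, as produced by str2idx().
src_target = torch.LongTensor([2, 0, 19])

# Fill a max_seq_len-long tensor with the padding index, then copy the
# (possibly truncated) ground-truth indexes into its head.
padded_indexes = (torch.ones(max_seq_len) * padding_idx).long()
char_num = min(src_target.size(0), max_seq_len)
padded_indexes[:char_num] = src_target[:char_num]
# -> tensor([ 2,  0, 19, 36, 36, 36, 36, 36, 36, 36])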
@@ -1,19 +1,31 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Dict, Sequence, Union
 
 import torch
 import torch.nn as nn
 
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary.dictionary import Dictionary
 from mmocr.registry import MODELS
+from .base_recog_loss import BaseRecogLoss
 
 
 @MODELS.register_module()
-class CTCLoss(nn.Module):
+class CTCLoss(BaseRecogLoss):
     """Implementation of loss module for CTC-loss based text recognition.
 
     Args:
+        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+            the instance of `Dictionary`.
+        letter_case (str): There are three options to alter the letter cases
+            of gt texts:
+            - unchanged: Do not change gt texts.
+            - upper: Convert gt texts into uppercase characters.
+            - lower: Convert gt texts into lowercase characters.
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
         flatten (bool): If True, use flattened targets, else padded targets.
-        blank (int): Blank label. Default 0.
         reduction (str): Specifies the reduction to apply to the output,
             should be one of the following: ('none', 'mean', 'sum').
         zero_infinity (bool): Whether to zero infinite losses and
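Under the new signature the blank label comes from the dictionary's padding index instead of a standalone ``blank`` argument. A construction sketch, assuming a one-character-per-line dict file at a placeholder path ('my_dict.txt' is not from the diff):

from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.models.textrecog.losses import CTCLoss

# 'my_dict.txt' is a placeholder; any one-character-per-line file works,
# as the dummy dict files in the tests below show.
dictionary = Dictionary(dict_file='my_dict.txt', with_padding=True)
ctc_loss = CTCLoss(dictionary=dictionary, letter_case='lower')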
@@ -23,81 +35,93 @@ class CTCLoss(nn.Module):
     """
 
     def __init__(self,
-                 flatten=True,
-                 blank=0,
-                 reduction='mean',
-                 zero_infinity=False,
-                 **kwargs):
-        super().__init__()
+                 dictionary: Union[Dict, Dictionary],
+                 letter_case: str = 'unchanged',
+                 flatten: bool = True,
+                 reduction: str = 'mean',
+                 zero_infinity: bool = False,
+                 **kwargs) -> None:
+        super().__init__(dictionary=dictionary, letter_case=letter_case)
         assert isinstance(flatten, bool)
-        assert isinstance(blank, int)
         assert isinstance(reduction, str)
         assert isinstance(zero_infinity, bool)
 
         self.flatten = flatten
-        self.blank = blank
         self.ctc_loss = nn.CTCLoss(
-            blank=blank, reduction=reduction, zero_infinity=zero_infinity)
+            blank=self.dictionary.padding_idx,
+            reduction=reduction,
+            zero_infinity=zero_infinity)
 
-    def forward(self, outputs, targets_dict, img_metas=None):
+    def forward(self, outputs: torch.Tensor,
+                data_samples: Sequence[TextRecogDataSample]) -> Dict:
         """
         Args:
             outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
-            targets_dict (dict): A dict with 3 keys ``target_lengths``,
-                ``flatten_targets`` and ``targets``.
-
-                - | ``target_lengths`` (Tensor): A tensor of shape :math:`(N)`.
-                    Each item is the length of a word.
-
-                - | ``flatten_targets`` (Tensor): Used if ``self.flatten=True``
-                    (default). A tensor of shape
-                    (sum(targets_dict['target_lengths'])). Each item is the
-                    index of a character.
-
-                - | ``targets`` (Tensor): Used if ``self.flatten=False``. A
-                    tensor of :math:`(N, T)`. Empty slots are padded with
-                    ``self.blank``.
-
-            img_metas (dict): A dict that contains meta information of input
-                images. Preferably with the key ``valid_ratio``.
+            data_samples (list[TextRecogDataSample]): List of
+                ``TextRecogDataSample`` which are processed by ``get_targets``.
 
         Returns:
             dict: The loss dict with key ``loss_ctc``.
         """
         valid_ratios = None
-        if img_metas is not None:
+        if data_samples is not None:
             valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+                img_meta.get('valid_ratio', 1.0) for img_meta in data_samples
             ]
 
         outputs = torch.log_softmax(outputs, dim=2)
         bsz, seq_len = outputs.size(0), outputs.size(1)
         outputs_for_loss = outputs.permute(1, 0, 2).contiguous()  # T * N * C
 
-        if self.flatten:
-            targets = targets_dict['flatten_targets']
-        else:
-            targets = torch.full(
-                size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long)
-            for idx, tensor in enumerate(targets_dict['targets']):
-                valid_len = min(tensor.size(0), seq_len)
-                targets[idx, :valid_len] = tensor[:valid_len]
-
-        target_lengths = targets_dict['target_lengths']
+        targets = [data_sample.gt_text.indexes for data_sample in data_samples]
+        target_lengths = torch.IntTensor([len(t) for t in targets])
         target_lengths = torch.clamp(target_lengths, min=1, max=seq_len).long()
 
         input_lengths = torch.full(
             size=(bsz, ), fill_value=seq_len, dtype=torch.long)
-        if not self.flatten and valid_ratios is not None:
-            input_lengths = [
-                math.ceil(valid_ratio * seq_len)
-                for valid_ratio in valid_ratios
-            ]
-            input_lengths = torch.Tensor(input_lengths).long()
+        if self.flatten:
+            targets = torch.cat(targets)
+        else:
+            padded_targets = torch.full(
+                size=(bsz, seq_len),
+                fill_value=self.dictionary.padding_idx,
+                dtype=torch.long)
+            for idx, valid_len in enumerate(target_lengths):
+                padded_targets[idx, :valid_len] = targets[idx][:valid_len]
+            targets = padded_targets
+
+        if valid_ratios is not None:
+            input_lengths = [
+                math.ceil(valid_ratio * seq_len)
+                for valid_ratio in valid_ratios
+            ]
+            input_lengths = torch.Tensor(input_lengths).long()
         loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
                                  target_lengths)
 
         losses = dict(loss_ctc=loss_ctc)
 
         return losses
+
+    def get_targets(
+        self, data_samples: Sequence[TextRecogDataSample]
+    ) -> Sequence[TextRecogDataSample]:
+        """Target generator.
+
+        Args:
+            data_samples (list[TextRecogDataSample]): It usually includes
+                ``gt_text`` information.
+
+        Returns:
+            list[TextRecogDataSample]: Updated data_samples. It will add the
+            following key in each data_sample:
+
+            - indexes (torch.LongTensor): The index corresponding to the item.
+        """
+
+        for data_sample in data_samples:
+            text = data_sample.gt_text.item
+            if self.letter_case in ['upper', 'lower']:
+                text = getattr(text, self.letter_case)()
+            indexes = self.dictionary.str2idx(text)
+            indexes = torch.IntTensor(indexes)
+            data_sample.gt_text.indexes = indexes
+        return data_samples
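The refactored forward pass ultimately feeds ``nn.CTCLoss``, which expects ``(T, N, C)`` log-probabilities plus per-sample input and target lengths. A stripped-down sketch of the same flow with dummy shapes (values are illustrative; ``padding_idx = 36`` stands in for ``dictionary.padding_idx``):

import torch
import torch.nn as nn

# Dummy shapes matching the tests below: N=2 samples, T=40 steps, C=37 classes.
N, T, C = 2, 40, 37
padding_idx = 36  # assumption: the blank label equals dictionary.padding_idx

outputs = torch.randn(N, T, C)                                  # raw logits
log_probs = torch.log_softmax(outputs, dim=2).permute(1, 0, 2)  # (T, N, C)

# Flattened targets: per-sample index tensors concatenated end to end.
targets = torch.cat([torch.LongTensor([0, 1, 2, 2]),
                     torch.LongTensor([3, 4, 5, 6])])
target_lengths = torch.LongTensor([4, 4])
input_lengths = torch.full((N, ), T, dtype=torch.long)

ctc = nn.CTCLoss(blank=padding_idx, reduction='mean', zero_infinity=False)
loss = ctc(log_probs, targets, input_lengths, target_lengths)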
@@ -76,7 +76,7 @@ class TestBaseRecogLoss(TestCase):
                           torch.LongTensor([0, 1, 2, 3]))
         padding_idx = dictionary.padding_idx
         assert self._equal(
-            target_data_samples[0].gt_text.padding_indexes,
+            target_data_samples[0].gt_text.padded_indexes,
             torch.LongTensor([
                 dictionary.start_idx, 0, 1, 2, 3, dictionary.end_idx,
                 padding_idx, padding_idx, padding_idx, padding_idx
@@ -95,7 +95,7 @@ class TestBaseRecogLoss(TestCase):
         assert self._equal(target_data_samples[0].gt_text.indexes,
                            torch.LongTensor([0, 1, 2, 3]))
         padding_idx = dictionary.padding_idx
-        assert self._equal(target_data_samples[0].gt_text.padding_indexes,
+        assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                            torch.LongTensor([0, 1, 2]))
 
         dict_cfg = dict(
@@ -112,7 +112,7 @@ class TestBaseRecogLoss(TestCase):
         target_data_samples = base_recog_loss.get_targets([data_sample])
         assert self._equal(target_data_samples[0].gt_text.indexes,
                            torch.LongTensor([0, 1, 2, 3]))
-        assert self._equal(target_data_samples[0].gt_text.padding_indexes,
+        assert self._equal(target_data_samples[0].gt_text.padded_indexes,
                            torch.LongTensor([0, 1, 2, 3]))
 
         target_data_samples = base_recog_loss.get_targets([])
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from mmengine.data import LabelData
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.models.textrecog.losses import CTCLoss
+
+
+class TestCTCLoss(TestCase):
+
+    def test_ctc_loss(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        # create dummy data
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+
+        dictionary = Dictionary(dict_file=dict_file, with_padding=True)
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, flatten='flatten')
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, reduction=1)
+        with self.assertRaises(AssertionError):
+            CTCLoss(dictionary=dictionary, zero_infinity='zero')
+
+        outputs = torch.zeros(2, 40, 37)
+        datasample1 = TextRecogDataSample()
+        gt_text1 = LabelData(item='hell')
+        datasample1.gt_text = gt_text1
+        datasample2 = datasample1.clone()
+        gt_text2 = LabelData(item='owrd')
+        datasample2.gt_text = gt_text2
+        data_samples = [datasample1, datasample2]
+        ctc_loss = CTCLoss(dictionary=dictionary)
+        data_samples = ctc_loss.get_targets(data_samples)
+        losses = ctc_loss(outputs, data_samples)
+        assert isinstance(losses, dict)
+        assert 'loss_ctc' in losses
+        assert torch.allclose(losses['loss_ctc'],
+                              torch.tensor(losses['loss_ctc'].item()).float())
+        # test flatten = False
+        ctc_loss = CTCLoss(dictionary=dictionary, flatten=False)
+        losses = ctc_loss(outputs, data_samples)
+        assert isinstance(losses, dict)
+        assert 'loss_ctc' in losses
+        assert torch.allclose(losses['loss_ctc'],
+                              torch.tensor(losses['loss_ctc'].item()).float())
+        tmp_dir.cleanup()
+
+    def _create_dummy_dict_file(self, dict_file):
+        chars = list('helowrd')
+        with open(dict_file, 'w') as fw:
+            for char in chars:
+                fw.write(char + '\n')
+
+    def test_get_targets(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        # create dummy data
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+
+        dictionary = Dictionary(dict_file=dict_file, with_padding=True)
+        loss = CTCLoss(dictionary=dictionary, letter_case='lower')
+        # test encode str to tensor
+        datasample1 = TextRecogDataSample()
+        gt_text1 = LabelData(item='hell')
+        datasample1.gt_text = gt_text1
+        datasample2 = datasample1.clone()
+        gt_text2 = LabelData(item='owrd')
+        datasample2.gt_text = gt_text2
+
+        data_samples = [datasample1, datasample2]
+        expect_tensor1 = torch.IntTensor([0, 1, 2, 2])
+        expect_tensor2 = torch.IntTensor([3, 4, 5, 6])
+
+        data_samples = loss.get_targets(data_samples)
+        self.assertTrue(
+            torch.allclose(data_samples[0].gt_text.indexes, expect_tensor1))
+        self.assertTrue(
+            torch.allclose(data_samples[1].gt_text.indexes, expect_tensor2))
+        tmp_dir.cleanup()
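As a sanity check on ``expect_tensor1`` and ``expect_tensor2`` above, assuming ``Dictionary`` assigns indexes in file order, the dummy dict 'helowrd' maps h=0, e=1, l=2, o=3, w=4, r=5, d=6:

# Quick verification of the expected index tensors under that assumption.
char2idx = {c: i for i, c in enumerate('helowrd')}
print([char2idx[c] for c in 'hell'])  # [0, 1, 2, 2] -> expect_tensor1
print([char2idx[c] for c in 'owrd'])  # [3, 4, 5, 6] -> expect_tensor2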