# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.registry import MODELS


@MODELS.register_module()
class ABILoss(nn.Module):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        enc_weight (float): The weight of encoder loss. Defaults to 1.0.
        dec_weight (float): The weight of decoder loss. Defaults to 1.0.
        fusion_weight (float): The weight of fuser (aligner) loss.
            Defaults to 1.0.
        num_classes (int): Number of unique output language tokens.
            Defaults to 37.

    Returns:
        A dictionary whose key/value pairs are the losses of three modules.
    """

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        assert isinstance(enc_weight, (float, int))
        assert isinstance(dec_weight, (float, int))
        assert isinstance(fusion_weight, (float, int))
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
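        # Keep only the first ``target_lens[i]`` time steps of each sample's
        # logits, then concatenate samples into one (sum(target_lens), C)
        # tensor aligned with the flattened targets.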
        flatten_logits = torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])
        return flatten_logits

    def _ce_loss(self, logits, targets):
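        # Manual cross entropy over one-hot targets; equivalent to
        # ``F.cross_entropy(logits, targets)`` averaged over all positions.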
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()

    def _loss_over_iters(self, outputs, targets):
        """
        Args:
            outputs (list[Tensor]): Flattened logits from each refinement
                iteration, each of shape (D, C) where D is the total number
                of target tokens in the batch and C is the number of
                classes.
            targets (Tensor): Flattened target indices of shape (D, ).
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_decs`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``targets``, which holds a list of per-sample 1D target
                index tensors.

        Returns:
            A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. A key is present only when the output of its
            corresponding module is given.
        """
|
|
assert 'out_enc' in outputs or \
|
|
'out_dec' in outputs or 'out_fusers' in outputs
|
|
losses = {}
        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat(targets_dict['targets'])

        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(
                dec_logits, flatten_targets) * self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss
        return losses
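

# A minimal usage sketch, not part of the original file. The batch size,
# target lengths, iteration count and random logits below are illustrative
# assumptions; only ``num_classes=37`` mirrors the class default.
if __name__ == '__main__':
    loss_fn = ABILoss(enc_weight=1.0, dec_weight=1.0, fusion_weight=1.0)

    # Two samples whose ground-truth sequences have lengths 5 and 3.
    targets_dict = dict(targets=[
        torch.randint(0, 37, (5, )),
        torch.randint(0, 37, (3, )),
    ])

    # Logits of shape (N=2, T=6, C=37); the decoder and fuser emit one
    # tensor per refinement iteration.
    logits = torch.randn(2, 6, 37)
    outputs = dict(
        out_enc=dict(logits=logits),
        out_decs=[dict(logits=logits.clone()) for _ in range(3)],
        out_fusers=[dict(logits=logits.clone()) for _ in range(3)])

    losses = loss_fn(outputs, targets_dict)
    # Expect the keys loss_visual, loss_lang and loss_fusion.
    print({k: float(v) for k, v in losses.items()})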