Add SDMGR Loss

This commit is contained in:
gaotongxiao 2022-07-08 09:55:06 +00:00
parent 622e65926e
commit e23a2ef089
6 changed files with 102 additions and 46 deletions

View File

@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .extractors import * # NOQA
from .heads import * # NOQA
from .losses import * # NOQA
from .module_losses import * # NOQA

View File

@@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sdmgr_loss import SDMGRLoss
__all__ = ['SDMGRLoss']

View File

@@ -1,41 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.losses import accuracy
from torch import nn
from mmocr.registry import MODELS
@MODELS.register_module()
class SDMGRLoss(nn.Module):
    """Loss for the key information extraction model proposed in the paper:
    Spatial Dual-Modality Graph Reasoning for Key Information Extraction.

    https://arxiv.org/abs/2103.14470.
    """

    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=-100):
        super().__init__()
        # Nodes skip the configurable ignore label; edges always skip -1.
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.node_weight = node_weight
        self.edge_weight = edge_weight
        self.ignore = ignore

    def forward(self, node_preds, edge_preds, gts):
        # Column 0 of each gt tensor carries the node label; the remaining
        # columns are the (flattened) edge labels.
        node_gts = torch.cat([gt[:, 0] for gt in gts]).long()
        edge_gts = torch.cat(
            [gt[:, 1:].contiguous().view(-1) for gt in gts]).long()

        # Accuracy is reported over non-ignored entries only.
        node_valids = torch.nonzero(
            node_gts != self.ignore, as_tuple=False).view(-1)
        edge_valids = torch.nonzero(edge_gts != -1, as_tuple=False).view(-1)

        return dict(
            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sdmgr_module_loss import SDMGRModuleLoss
__all__ = ['SDMGRModuleLoss']

View File

@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
import torch
from mmdet.models.losses import accuracy
from torch import Tensor, nn
from mmocr.core.data_structures.kie_data_sample import KIEDataSample
from mmocr.registry import MODELS
@MODELS.register_module()
class SDMGRModuleLoss(nn.Module):
    """The implementation of the loss of key information extraction proposed
    in the paper: `Spatial Dual-Modality Graph Reasoning for Key Information
    Extraction <https://arxiv.org/abs/2103.14470>`_.

    Args:
        weight_node (float): Weight of node loss. Defaults to 1.0.
        weight_edge (float): Weight of edge loss. Defaults to 1.0.
        ignore_idx (int): Node label to ignore. Defaults to -100.
    """

    def __init__(self,
                 weight_node: float = 1.0,
                 weight_edge: float = 1.0,
                 ignore_idx: int = -100) -> None:
        super().__init__()
        # TODO: Use MODELS.build after DRRG loss has been merged
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore_idx)
        # Edge labels use -1 as the fixed ignore/padding value.
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.weight_node = weight_node
        self.weight_edge = weight_edge
        self.ignore_idx = ignore_idx

    def forward(self, preds: Tuple[Tensor, Tensor],
                data_samples: List[KIEDataSample]) -> Dict:
        """Forward function.

        Args:
            preds (tuple(Tensor, Tensor)): Node and edge class logits.
                Given N text boxes in a sample, the node logits have one
                row per box and the edge logits one row per box pair
                (N * N rows).
            data_samples (list[KIEDataSample]): A list of datasamples
                containing ``gt_instances.labels`` and
                ``gt_instances.edge_labels``.

        Returns:
            dict(str, Tensor): Loss dict, containing ``loss_node``,
            ``loss_edge``, ``acc_node`` and ``acc_edge``.
        """
        node_preds, edge_preds = preds
        # Gather node and (flattened) edge labels across all samples.
        node_gts, edge_gts = [], []
        for data_sample in data_samples:
            node_gts.append(data_sample.gt_instances.labels)
            edge_gts.append(data_sample.gt_instances.edge_labels.reshape(-1))
        node_gts = torch.cat(node_gts).long()
        edge_gts = torch.cat(edge_gts).long()

        # Accuracy is computed over non-ignored entries only; the losses
        # already skip them via ``ignore_index``.
        node_valids = torch.nonzero(
            node_gts != self.ignore_idx, as_tuple=False).reshape(-1)
        edge_valids = torch.nonzero(edge_gts != -1, as_tuple=False).reshape(-1)
        return dict(
            loss_node=self.weight_node * self.loss_node(node_preds, node_gts),
            loss_edge=self.weight_edge * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))

View File

@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine import InstanceData
from mmocr.core.data_structures.kie_data_sample import KIEDataSample
from mmocr.models.kie.module_losses import SDMGRModuleLoss
class TestSDMGRModuleLoss(TestCase):

    def test_forward(self):
        module_loss = SDMGRModuleLoss()

        # Three nodes with 26 node classes; 3 * 3 = 9 node pairs (edges)
        # with 2 edge classes.
        node_logits = torch.rand((3, 26))
        edge_logits = torch.rand((9, 2))

        sample = KIEDataSample()
        sample.gt_instances = InstanceData(
            labels=torch.randint(0, 26, (3, )).long(),
            edge_labels=torch.randint(0, 2, (3, 3)).long())

        results = module_loss((node_logits, edge_logits), [sample])
        for key in ('loss_node', 'loss_edge', 'acc_node', 'acc_edge'):
            self.assertIn(key, results)

        # The configured weights must scale each loss term linearly.
        weighted_loss = SDMGRModuleLoss(weight_edge=2, weight_node=3)
        weighted = weighted_loss((node_logits, edge_logits), [sample])
        self.assertEqual(results['loss_node'] * 3, weighted['loss_node'])
        self.assertEqual(results['loss_edge'] * 2, weighted['loss_edge'])