mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
Add SDMGR Loss
This commit is contained in:
parent
622e65926e
commit
e23a2ef089
@ -1,4 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .extractors import * # NOQA
|
||||
from .heads import * # NOQA
|
||||
from .losses import * # NOQA
|
||||
from .module_losses import * # NOQA
|
||||
|
@ -1,4 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .sdmgr_loss import SDMGRLoss
|
||||
|
||||
__all__ = ['SDMGRLoss']
|
@ -1,41 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmdet.models.losses import accuracy
|
||||
from torch import nn
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class SDMGRLoss(nn.Module):
    """Loss for key information extraction, as proposed in the paper
    "Spatial Dual-Modality Graph Reasoning for Key Information Extraction".

    https://arxiv.org/abs/2103.14470.

    Args:
        node_weight: Multiplier applied to the node cross-entropy loss.
            Defaults to 1.0.
        edge_weight: Multiplier applied to the edge cross-entropy loss.
            Defaults to 1.0.
        ignore: Node label that is skipped by the node loss and left out of
            the node accuracy. Defaults to -100.
    """

    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=-100):
        super().__init__()
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
        # Edge labels equal to -1 are skipped by the edge loss.
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.node_weight = node_weight
        self.edge_weight = edge_weight
        self.ignore = ignore

    def forward(self, node_preds, edge_preds, gts):
        """Compute the weighted node/edge losses and their accuracies.

        Args:
            node_preds: Node predictions passed to the node cross-entropy.
            edge_preds: Edge predictions passed to the edge cross-entropy.
            gts: Iterable of 2-D ground-truth tensors. Column 0 of each
                tensor is read as the node labels; the remaining columns
                are flattened and read as the edge labels.

        Returns:
            dict: ``loss_node``, ``loss_edge``, ``acc_node`` and
            ``acc_edge``.
        """
        # Walk the samples once, splitting each into node and edge targets.
        per_sample = [(gt[:, 0], gt[:, 1:].contiguous().view(-1))
                      for gt in gts]
        node_targets = torch.cat([pair[0] for pair in per_sample]).long()
        edge_targets = torch.cat([pair[1] for pair in per_sample]).long()

        # Indices of the entries that are not marked as ignored; accuracy is
        # only measured on those.
        node_keep = (node_targets != self.ignore).nonzero(
            as_tuple=False).view(-1)
        edge_keep = (edge_targets != -1).nonzero(as_tuple=False).view(-1)

        node_loss = self.node_weight * self.loss_node(node_preds,
                                                      node_targets)
        edge_loss = self.edge_weight * self.loss_edge(edge_preds,
                                                      edge_targets)
        node_acc = accuracy(node_preds[node_keep], node_targets[node_keep])
        edge_acc = accuracy(edge_preds[edge_keep], edge_targets[edge_keep])

        return dict(
            loss_node=node_loss,
            loss_edge=edge_loss,
            acc_node=node_acc,
            acc_edge=edge_acc)
|
4
mmocr/models/kie/module_losses/__init__.py
Normal file
4
mmocr/models/kie/module_losses/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .sdmgr_module_loss import SDMGRModuleLoss
|
||||
|
||||
__all__ = ['SDMGRModuleLoss']
|
65
mmocr/models/kie/module_losses/sdmgr_module_loss.py
Normal file
65
mmocr/models/kie/module_losses/sdmgr_module_loss.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from mmdet.models.losses import accuracy
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmocr.core.data_structures.kie_data_sample import KIEDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class SDMGRModuleLoss(nn.Module):
    """Module loss for key information extraction, following the paper
    `Spatial Dual-Modality Graph Reasoning for Key Information Extraction
    <https://arxiv.org/abs/2103.14470>`_.

    Two cross-entropy terms are produced, one for node predictions and one
    for edge predictions, each scaled by its own weight.

    Args:
        weight_node (float): Weight of node loss. Defaults to 1.0.
        weight_edge (float): Weight of edge loss. Defaults to 1.0.
        ignore_idx (int): Node label to ignore. Defaults to -100.
    """

    def __init__(self,
                 weight_node: float = 1.0,
                 weight_edge: float = 1.0,
                 ignore_idx: int = -100) -> None:
        super().__init__()
        # TODO: Use MODELS.build after DRRG loss has been merged
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore_idx)
        # Edge labels equal to -1 are skipped by the edge loss.
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.weight_node = weight_node
        self.weight_edge = weight_edge
        self.ignore_idx = ignore_idx

    def forward(self, preds: Tuple[Tensor, Tensor],
                data_samples: List[KIEDataSample]) -> Dict:
        """Compute the weighted node/edge losses and their accuracies.

        Args:
            preds (tuple(Tensor, Tensor)): The node predictions followed by
                the edge predictions. Each is fed to its cross-entropy
                loss together with the concatenated labels of all samples.
            data_samples (list[KIEDataSample]): A list of datasamples
                containing ``gt_instances.labels`` and
                ``gt_instances.edge_labels``.

        Returns:
            dict(str, Tensor): Loss dict, containing ``loss_node``,
            ``loss_edge``, ``acc_node`` and ``acc_edge``.
        """
        node_preds, edge_preds = preds

        # Gather the ground truth of every sample in a single pass.
        instances = [sample.gt_instances for sample in data_samples]
        node_targets = torch.cat([inst.labels for inst in instances]).long()
        edge_targets = torch.cat(
            [inst.edge_labels.reshape(-1) for inst in instances]).long()

        # Indices of the entries that are not marked as ignored; accuracy is
        # only measured on those.
        node_keep = (node_targets != self.ignore_idx).nonzero(
            as_tuple=False).reshape(-1)
        edge_keep = (edge_targets != -1).nonzero(as_tuple=False).reshape(-1)

        node_loss = self.weight_node * self.loss_node(node_preds,
                                                      node_targets)
        edge_loss = self.weight_edge * self.loss_edge(edge_preds,
                                                      edge_targets)
        node_acc = accuracy(node_preds[node_keep], node_targets[node_keep])
        edge_acc = accuracy(edge_preds[edge_keep], edge_targets[edge_keep])

        return dict(
            loss_node=node_loss,
            loss_edge=edge_loss,
            acc_node=node_acc,
            acc_edge=edge_acc)
|
@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
|
||||
from mmocr.core.data_structures.kie_data_sample import KIEDataSample
|
||||
from mmocr.models.kie.module_losses import SDMGRModuleLoss
|
||||
|
||||
|
||||
class TestSDMGRModuleLoss(TestCase):
    """Unit tests for ``SDMGRModuleLoss``."""

    def test_forward(self):
        """Check the returned keys and that the weights scale the losses."""
        default_loss = SDMGRModuleLoss()

        # One sample with 3 nodes: 3 node predictions over 26 classes and
        # 3 * 3 edge predictions over 2 classes.
        node_preds = torch.rand((3, 26))
        edge_preds = torch.rand((9, 2))
        sample = KIEDataSample()
        sample.gt_instances = InstanceData(
            labels=torch.randint(0, 26, (3, )).long(),
            edge_labels=torch.randint(0, 2, (3, 3)).long())
        preds = (node_preds, edge_preds)

        losses = default_loss(preds, [sample])
        for key in ('loss_node', 'loss_edge', 'acc_node', 'acc_edge'):
            self.assertIn(key, losses)

        # The same inputs with non-default weights must give losses scaled
        # by exactly those weights.
        weighted_loss = SDMGRModuleLoss(weight_edge=2, weight_node=3)
        new_losses = weighted_loss(preds, [sample])
        self.assertEqual(losses['loss_node'] * 3, new_losses['loss_node'])
        self.assertEqual(losses['loss_edge'] * 2, new_losses['loss_edge'])
|
Loading…
x
Reference in New Issue
Block a user