Merge pull request #4 from cuhk-hbsun/feature/kie

[feature]: add code for kie and textsnake config
Hongbin Sun 2021-04-03 00:47:04 +08:00 committed by GitHub
commit 50ab4ef23d
17 changed files with 1598 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
# Spatial Dual-Modality Graph Reasoning for Key Information Extraction
## Introduction
[ALGORITHM]
```bibtex
@misc{sun2021spatial,
title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
year={2021},
eprint={2103.14470},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Results and models
### WildReceipt
| Method | Modality | Macro F1-Score | Download |
| :--------------------------------------------------------------------: | :--------------: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------: |
| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.880 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.871 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
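
The macro F1-score above is averaged over node (box) classes only, and a subset of WildReceipt's 26 classes is excluded from the average, as configured under `evaluation` in the configs. A minimal sketch to inspect that list, assuming `mmcv` is installed and the repository root is the working directory:

```python
from mmcv import Config

cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py')
print(cfg.evaluation.metric_options.macro_f1.ignores)
# [0, 2, 4, ..., 24, 25]: with the usual WildReceipt class list these are the
# Ignore/Others and *_key categories, so only the *_value classes contribute
# to the macro F1-score reported above.
```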

View File

@@ -0,0 +1,99 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='KIEFormatBundle'),
dict(
type='Collect',
keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='KIEFormatBundle'),
dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]
vocab_file = 'dict.txt'
class_file = 'class_list.txt'
data = dict(
samples_per_gpu=4,
workers_per_gpu=0,
train=dict(
type=dataset_type,
ann_file='train.txt',
pipeline=train_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
val=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
test=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file))
evaluation = dict(
interval=1,
metric='macro_f1',
metric_options=dict(
macro_f1=dict(
ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))
model = dict(
type='SDMGR',
backbone=dict(type='UNet', base_channels=16),
bbox_head=dict(
type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
visual_modality=False,
train_cfg=None,
test_cfg=None)
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=1,
warmup_ratio=1,
step=[40, 50])
total_epochs = 60
checkpoint_config = dict(interval=1)
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(
# type='PaviLoggerHook',
# add_last_ckpt=True,
# interval=5,
# init_kwargs=dict(project='kie')),
])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
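
As a quick sanity check, the config above can be assembled into a model with the standard mmdet builders. A rough sketch, assuming an MMOCR installation where `SDMGR`, `SDMGRHead`, `UNet` and `SDMGRLoss` are registered (the package path in the import is an assumption; it is not shown in this diff):

```python
from mmcv import Config
from mmdet.models import build_detector

import mmocr.models  # noqa: F401 -- registers the KIE modules (assumed layout)

cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py')
model = build_detector(cfg.model)
print(type(model).__name__)   # SDMGR
print(model.visual_modality)  # False for this config; True in sdmgr_unet16
```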

View File

@@ -0,0 +1,99 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='KIEFormatBundle'),
dict(
type='Collect',
keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='KIEFormatBundle'),
dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]
vocab_file = 'dict.txt'
class_file = 'class_list.txt'
data = dict(
samples_per_gpu=4,
workers_per_gpu=0,
train=dict(
type=dataset_type,
ann_file='train.txt',
pipeline=train_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
val=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
test=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file))
evaluation = dict(
interval=1,
metric='macro_f1',
metric_options=dict(
macro_f1=dict(
ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))
model = dict(
type='SDMGR',
backbone=dict(type='UNet', base_channels=16),
bbox_head=dict(
type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
visual_modality=True,
train_cfg=None,
test_cfg=None)
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=1,
warmup_ratio=1,
step=[40, 50])
total_epochs = 60
checkpoint_config = dict(interval=1)
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(
# type='PaviLoggerHook',
# add_last_ckpt=True,
# interval=5,
# init_kwargs=dict(project='kie')),
])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

View File

@@ -0,0 +1,23 @@
# TextSnake
## Introduction
[ALGORITHM]
```bibtex
@inproceedings{long2018textsnake,
title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes},
author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong},
booktitle={ECCV},
pages={20-36},
year={2018}
}
```
## Results and models
### CTW1500
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :----------------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------: |
| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [config](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) |

View File

@@ -0,0 +1,113 @@
_base_ = [
'../../_base_/schedules/schedule_1200e.py',
'../../_base_/default_runtime.py'
]
model = dict(
type='TextSnake',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
bbox_head=dict(
type='TextSnakeHead',
in_channels=32,
text_repr_type='poly',
loss=dict(type='TextSnakeLoss')),
train_cfg=None,
test_cfg=None)
dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadTextAnnotations',
with_bbox=True,
with_mask=True,
poly2mask=False),
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(
type='RandomCropPolyInstances',
instance_key='gt_masks',
crop_ratio=0.65,
min_side_ratio=0.3),
dict(
type='RandomRotatePolyInstances',
rotate_ratio=0.5,
max_angle=20,
pad_with_fixed_color=False),
dict(
type='ScaleAspectJitter',
img_scale=[(3000, 736)], # unused
ratio_range=(0.7, 1.3),
aspect_ratio_range=(0.9, 1.1),
multiscale_mode='value',
long_size_bound=800,
short_size_bound=480,
resize_type='long_short_bound',
keep_ratio=False),
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
dict(type='TextSnakeTargets'),
dict(type='Pad', size_divisor=32),
dict(
type='CustomFormatBundle',
keys=[
'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
],
visualize=dict(flag=False, boundary_key='gt_text_mask')),
dict(
type='Collect',
keys=[
'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 736),
flip=False,
transforms=[
dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
img_prefix=data_root + '/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='hmean-iou')
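
The test-time transforms above are all standard mmdet pipeline steps, so they can be composed directly. A small sketch, assuming an mmdet/MMOCR environment (the custom training transforms such as `ScaleAspectJitter` and `TextSnakeTargets` additionally require MMOCR's registrations):

```python
from mmdet.datasets.pipelines import Compose

# Builds LoadImageFromFile -> MultiScaleFlipAug(Resize, Normalize, Pad, ...)
test_transforms = Compose(test_pipeline)
print(test_transforms)
```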

View File

@@ -0,0 +1,27 @@
import torch
def compute_f1_score(preds, gts, ignores=[]):
"""Compute the F1-score of prediction.
Args:
preds (Tensor): The predicted probability NxC map
with N and C being the sample number and class
number respectively.
gts (Tensor): The ground truth vector of size N.
ignores (list): The index set of classes that are ignored when
reporting results.
Note: all samples are involved in the computation.
Returns:
The numpy array of f1-scores for the valid (non-ignored) classes.
"""
C = preds.size(1)
classes = torch.LongTensor(sorted(set(range(C)) - set(ignores)))
hist = torch.bincount(
gts * C + preds.argmax(1), minlength=C**2).view(C, C).float()
diag = torch.diag(hist)
recalls = diag / hist.sum(1).clamp(min=1)
precisions = diag / hist.sum(0).clamp(min=1)
f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)
return f1[classes].cpu().numpy()
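
A toy sanity check for the function above (illustrative values only): three classes, class 2 excluded from the report, and every prediction correct, so both reported scores are 1.0:

```python
import torch

preds = torch.tensor([[0.9, 0.05, 0.05],
                      [0.1, 0.80, 0.10],
                      [0.2, 0.70, 0.10]])   # N=3 samples, C=3 classes
gts = torch.tensor([0, 1, 1])
print(compute_f1_score(preds, gts, ignores=[2]))  # -> [1. 1.]
```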

View File

@@ -0,0 +1,295 @@
import copy
from os import path as osp
import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmocr.core import compute_f1_score
@DATASETS.register_module()
class KIEDataset(CustomDataset):
def __init__(self,
ann_file,
pipeline=None,
data_root=None,
img_prefix='',
ann_prefix='',
vocab_file=None,
class_file=None,
norm=10.,
thresholds=dict(edge=0.5),
directed=False,
**kwargs):
self.ann_prefix = ann_prefix
self.norm = norm
self.thresholds = thresholds
self.directed = directed
if data_root is not None:
if not osp.isabs(ann_file):
self.ann_file = osp.join(data_root, ann_file)
if not (ann_prefix is None or osp.isabs(ann_prefix)):
self.ann_prefix = osp.join(data_root, ann_prefix)
self.vocab = dict({'': 0})
vocab_file = osp.join(data_root, vocab_file)
if osp.exists(vocab_file):
with open(vocab_file, 'r') as fid:
for idx, char in enumerate(fid.readlines(), 1):
self.vocab[char.strip('\n')] = idx
else:
self.construct_dict(self.ann_file)
with open(vocab_file, 'w') as fid:
for key in self.vocab:
if key:
fid.write('{}\n'.format(key))
super().__init__(
ann_file,
pipeline,
data_root=data_root,
img_prefix=img_prefix,
**kwargs)
self.idx_to_cls = dict()
with open(osp.join(data_root, class_file), 'r') as fid:
for line in fid.readlines():
idx, cls = line.split()
self.idx_to_cls[int(idx)] = cls
@staticmethod
def _split_edge(line):
text = ','.join(line[8:-1])
if ';' in text and text.split(';')[0].isdecimal():
edge, text = text.split(';', 1)
edge = int(edge)
else:
edge = 0
return edge, text
def construct_dict(self, ann_file):
img_infos = mmcv.list_from_file(ann_file)
for img_info in img_infos:
_, annname = img_info.split()
if self.ann_prefix:
annname = osp.join(self.ann_prefix, annname)
with open(annname, 'r') as fid:
lines = fid.readlines()
for line in lines:
line = line.strip().split(',')
_, text = self._split_edge(line)
for c in text:
if c not in self.vocab:
self.vocab[c] = len(self.vocab)
self.vocab = dict(
{k: idx
for idx, k in enumerate(sorted(self.vocab.keys()))})
def convert_text(self, text):
return [self.vocab[c] for c in text if c in self.vocab]
def parse_lines(self, annname):
boxes, edges, texts, chars, labels = [], [], [], [], []
if self.ann_prefix:
annname = osp.join(self.ann_prefix, annname)
with open(annname, 'r') as fid:
for line in fid.readlines():
line = line.strip().split(',')
boxes.append(list(map(int, line[:8])))
edge, text = self._split_edge(line)
chars.append(text)
text = self.convert_text(text)
texts.append(text)
edges.append(edge)
labels.append(int(line[-1]))
return dict(
boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)
def format_results(self, results):
boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].cuda()
if 'nodes' in results:
nodes, edges = results['nodes'], results['edges']
labels = nodes.argmax(-1)
num_nodes = nodes.size(0)
edges = edges[:, -1].view(num_nodes, num_nodes)
else:
labels = torch.Tensor(results['labels']).cuda()
edges = torch.Tensor(results['edges']).cuda()
boxes = torch.cat([boxes, labels[:, None].float()], -1)
return {
**{
k: v
for k, v in results.items() if k not in ['boxes', 'edges']
}, 'boxes': boxes,
'edges': edges,
'points': results['boxes']
}
def plot(self, results):
img_name = osp.join(self.img_prefix, results['filename'])
img = plt.imread(img_name)
plt.imshow(img)
boxes, texts = results['points'], results['chars']
num_nodes = len(boxes)
if 'scores' in results:
scores = results['scores']
else:
scores = np.ones(num_nodes)
for box, text, score in zip(boxes, texts, scores):
xs, ys = [], []
for idx in range(0, 10, 2):
xs.append(box[idx % 8])
ys.append(box[(idx + 1) % 8])
plt.plot(xs, ys, 'g')
plt.annotate(
'{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')
if 'nodes' in results:
nodes = results['nodes']
inds = nodes.argmax(-1)
else:
nodes = np.ones((num_nodes, 3))
inds = results['labels']
for i in range(num_nodes):
plt.annotate(
'{}: {:.4f}'.format(
self.idx_to_cls[int(inds[i]) - 1], nodes[i, inds[i]]),
(boxes[i][6], boxes[i][7]),
color='r' if inds[i] == 1 else 'b')
edges = results['edges']
if 'nodes' not in results:
edges = (edges[:, None] == edges[None]).float()
for j in range(i + 1, num_nodes):
edge_score = max(edges[i][j], edges[j][i])
if edge_score > self.thresholds['edge']:
x1 = sum(boxes[i][:3:2]) // 2
y1 = sum(boxes[i][3:6:2]) // 2
x2 = sum(boxes[j][:3:2]) // 2
y2 = sum(boxes[j][3:6:2]) // 2
plt.plot((x1, x2), (y1, y2), 'r')
plt.annotate(
'{:.4f}'.format(edge_score),
((x1 + x2) // 2, (y1 + y2) // 2),
color='r')
def compute_relation(self, boxes):
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
dxs = (x1s[:, 0][None] - x1s) / self.norm
dys = (y1s[:, 0][None] - y1s) / self.norm
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
whs = ws / hs + np.zeros_like(xhhs)
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
return relations, bboxes
def ann_numpy(self, results):
boxes, texts = results['boxes'], results['texts']
boxes = np.array(boxes, np.int32)
if boxes[0, 1] > boxes[0, -1]:
boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
relations, bboxes = self.compute_relation(boxes)
labels = results.get('labels', None)
if labels is not None:
labels = np.array(labels, np.int32)
edges = results.get('edges', None)
if edges is not None:
labels = labels[:, None]
edges = np.array(edges)
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
if self.directed:
edges = (edges & labels == 1).astype(np.int32)
np.fill_diagonal(edges, -1)
labels = np.concatenate([labels, edges], -1)
return dict(
bboxes=bboxes,
relations=relations,
texts=self.pad_text(texts),
labels=labels)
def image_size(self, filename):
img_path = osp.join(self.img_prefix, filename)
img = Image.open(img_path)
return img.size
def load_annotations(self, ann_file):
self.anns, data_infos = [], []
self.gts = dict()
img_infos = mmcv.list_from_file(ann_file)
for img_info in img_infos:
filename, annname = img_info.split()
results = self.parse_lines(annname)
width, height = self.image_size(filename)
data_infos.append(
dict(filename=filename, width=width, height=height))
ann = self.ann_numpy(results)
self.anns.append(ann)
return data_infos
def pad_text(self, texts):
max_len = max([len(text) for text in texts])
padded_texts = -np.ones((len(texts), max_len), np.int32)
for idx, text in enumerate(texts):
padded_texts[idx, :len(text)] = np.array(text)
return padded_texts
def get_ann_info(self, idx):
return self.anns[idx]
def prepare_test_img(self, idx):
return self.prepare_train_img(idx)
def evaluate(self,
results,
metric='macro_f1',
metric_options=dict(macro_f1=dict(ignores=[])),
**kwargs):
# allow some kwargs to pass through
assert set(kwargs).issubset(['logger'])
# Protect ``metric_options`` since it uses mutable value as default
metric_options = copy.deepcopy(metric_options)
metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['macro_f1']
for m in metrics:
if m not in allowed_metrics:
raise KeyError(f'metric {m} is not supported')
return self.compute_macro_f1(results, **metric_options['macro_f1'])
def compute_macro_f1(self, results, ignores=[]):
node_preds = []
for result in results:
node_preds.append(result['nodes'])
node_preds = torch.cat(node_preds)
node_gts = [
torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
for ann in self.anns
]
node_gts = torch.cat(node_gts)
node_f1s = compute_f1_score(node_preds, node_gts, ignores)
return {
'macro_f1': node_f1s.mean(),
}
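
For reference, a hypothetical annotation line in the layout that `parse_lines()`/`_split_edge()` above expect: 8 polygon coordinates, a transcription (optionally prefixed with an edge id and `;`), and a node class label. The field values below are made up for illustration:

```python
# Assumes the KIEDataset class above is importable/in scope.
line = '84,61,270,61,270,81,84,81,SAFEWAY,1'.strip().split(',')
box = list(map(int, line[:8]))              # 4-point polygon
edge, text = KIEDataset._split_edge(line)   # edge id (0 when absent) and text
label = int(line[-1])                       # node class index
print(box, edge, text, label)
# A transcription like '2;TOTAL' would instead yield edge=2 and text='TOTAL'.
```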

View File

@@ -0,0 +1,55 @@
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor
@PIPELINES.register_module()
class KIEFormatBundle(DefaultFormatBundle):
"""Key information extraction formatting bundle.
Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
"gt_masks", "gt_semantic_seg", "relations" and "texts".
These fields are formatted as follows.
- img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
- proposals: (1) to tensor, (2) to DataContainer
- gt_bboxes: (1) to tensor, (2) to DataContainer
- gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
- gt_labels: (1) to tensor, (2) to DataContainer
- gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
- gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \
(3) to DataContainer (stack=True)
- relations: (1) scale, (2) to tensor, (3) to DataContainer
- texts: (1) to tensor, (2) to DataContainer
"""
def __call__(self, results):
"""Call function to transform and format common fields in results.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data that is formatted with \
default bundle.
"""
super().__call__(results)
if 'ann_info' in results:
for key in ['relations', 'texts']:
value = results['ann_info'][key]
if key == 'relations' and 'scale_factor' in results:
scale_factor = results['scale_factor']
if isinstance(scale_factor, float):
sx = sy = scale_factor
else:
sx, sy = results['scale_factor'][:2]
r = sx / sy
value = value * np.array([sx, sy, r, 1, r])[None, None]
results[key] = DC(to_tensor(value))
return results
def __repr__(self):
return self.__class__.__name__
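
The scaling applied to `relations` above follows from the layout produced by `KIEDataset.compute_relation` ([dx, dy, w/h, h'/h, w'/h]): horizontal offsets scale by `sx`, vertical offsets by `sy`, width-over-height style ratios by `sx / sy`, and the height ratio is scale-invariant. A tiny numeric check with toy values:

```python
import numpy as np

sx, sy = 0.5, 0.25                                    # hypothetical Resize scale factors
r = sx / sy
relation = np.array([[[10.0, 4.0, 2.0, 1.0, 3.0]]])   # toy (1, 1, 5) relation
print(relation * np.array([sx, sy, r, 1, r])[None, None])
# [[[5. 1. 4. 1. 6.]]]
```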

View File

@@ -0,0 +1,3 @@
from .unet import UNet
__all__ = ['UNet']

View File

@@ -0,0 +1,528 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, build_upsample_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmdet.models.builder import BACKBONES
from mmdet.utils import get_root_logger
class UpConvBlock(nn.Module):
"""Upsample convolution block in decoder for UNet.
This upsample convolution block consists of one upsample module
followed by one convolution block. The upsample module expands the
high-level low-resolution feature map and the convolution block fuses
the upsampled high-level low-resolution feature map and the low-level
high-resolution feature map from encoder.
Args:
conv_block (nn.Sequential): Sequential of convolutional layers.
in_channels (int): Number of input channels of the high-level
(low-resolution) feature map to be upsampled.
skip_channels (int): Number of input channels of the low-level
high-resolution feature map from encoder.
out_channels (int): Number of output channels.
num_convs (int): Number of convolutional layers in the conv_block.
Default: 2.
stride (int): Stride of convolutional layer in conv_block. Default: 1.
dilation (int): Dilation rate of convolutional layer in conv_block.
Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
upsample_cfg (dict): The upsample config of the upsample module in
decoder. Default: dict(type='InterpConv'). If the size of
high-level feature map is the same as that of skip feature map
(low-level feature map from encoder), the high-level feature map does not
need to be upsampled, and upsample_cfg should be None.
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
"""
def __init__(self,
conv_block,
in_channels,
skip_channels,
out_channels,
num_convs=2,
stride=1,
dilation=1,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
upsample_cfg=dict(type='InterpConv'),
dcn=None,
plugins=None):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.conv_block = conv_block(
in_channels=2 * skip_channels,
out_channels=out_channels,
num_convs=num_convs,
stride=stride,
dilation=dilation,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
dcn=None,
plugins=None)
if upsample_cfg is not None:
self.upsample = build_upsample_layer(
cfg=upsample_cfg,
in_channels=in_channels,
out_channels=skip_channels,
with_cp=with_cp,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
self.upsample = ConvModule(
in_channels,
skip_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, skip, x):
"""Forward function."""
x = self.upsample(x)
out = torch.cat([skip, x], dim=1)
out = self.conv_block(out)
return out
class BasicConvBlock(nn.Module):
"""Basic convolutional block for UNet.
This module consists of several plain convolutional layers.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
num_convs (int): Number of convolutional layers. Default: 2.
stride (int): Whether to use stride convolution to downsample
the input feature map. If stride=2, it only uses stride convolution
in the first convolutional layer to downsample the input feature
map. Options are 1 or 2. Default: 1.
dilation (int): Whether to use dilated convolution to expand the
receptive field. The dilation rate is set for every convolutional layer
except the first one, whose dilation rate is always 1.
Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
"""
def __init__(self,
in_channels,
out_channels,
num_convs=2,
stride=1,
dilation=1,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
dcn=None,
plugins=None):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.with_cp = with_cp
convs = []
for i in range(num_convs):
convs.append(
ConvModule(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride if i == 0 else 1,
dilation=1 if i == 0 else dilation,
padding=1 if i == 0 else dilation,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.convs = nn.Sequential(*convs)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.convs, x)
else:
out = self.convs(x)
return out
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
"""Deconvolution upsample module in decoder for UNet (2X upsample).
This module uses deconvolution to upsample feature map in the decoder
of UNet.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
"""
def __init__(self,
in_channels,
out_channels,
with_cp=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
*,
kernel_size=4,
scale_factor=2):
super().__init__()
assert (kernel_size - scale_factor >= 0) and\
(kernel_size - scale_factor) % 2 == 0,\
f'kernel_size should be greater than or equal to scale_factor '\
f'and (kernel_size - scale_factor) should be even numbers, '\
f'while the kernel size is {kernel_size} and scale_factor is '\
f'{scale_factor}.'
stride = scale_factor
padding = (kernel_size - scale_factor) // 2
self.with_cp = with_cp
deconv = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
_, norm = build_norm_layer(norm_cfg, out_channels)
activate = build_activation_layer(act_cfg)
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.deconv_upsamping, x)
else:
out = self.deconv_upsamping(x)
return out
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
"""Interpolation upsample module in decoder for UNet.
This module uses interpolation to upsample feature map in the decoder
of UNet. It consists of one interpolation upsample layer and one
convolutional layer. It can be one interpolation upsample layer followed
by one convolutional layer (conv_first=False) or one convolutional layer
followed by one interpolation upsample layer (conv_first=True).
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
conv_first (bool): Whether to apply the convolutional layer before the
interpolation upsample layer. Default: False, i.e. the interpolation
upsample layer is followed by one convolutional layer.
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
stride (int): Stride of the convolutional layer. Default: 1.
padding (int): Padding of the convolutional layer. Default: 0.
upsample_cfg (dict): Interpolation config of the upsample layer.
Default: dict(
scale_factor=2, mode='bilinear', align_corners=False).
"""
def __init__(self,
in_channels,
out_channels,
with_cp=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
*,
conv_cfg=None,
conv_first=False,
kernel_size=1,
stride=1,
padding=0,
upsample_cfg=dict(
scale_factor=2, mode='bilinear', align_corners=False)):
super().__init__()
self.with_cp = with_cp
conv = ConvModule(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
upsample = nn.Upsample(**upsample_cfg)
if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample)
else:
self.interp_upsample = nn.Sequential(upsample, conv)
def forward(self, x):
"""Forward function."""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.interp_upsample, x)
else:
out = self.interp_upsample(x)
return out
@BACKBONES.register_module()
class UNet(nn.Module):
"""UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/pdf/1505.04597.pdf
Args:
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Number of base channels of each stage.
The output channels of the first stage. Default: 64.
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
len(strides) is equal to num_stages. Normally the stride of the
first stage in encoder is 1. If strides[i]=2, it uses stride
convolution to downsample in the corresponding encoder stage.
Default: (1, 1, 1, 1, 1).
enc_num_convs (Sequence[int]): Number of convolutional layers in the
convolution block of the corresponding encoder stage.
Default: (2, 2, 2, 2, 2).
dec_num_convs (Sequence[int]): Number of convolutional layers in the
convolution block of the corresponding decoder stage.
Default: (2, 2, 2, 2).
downsamples (Sequence[bool]): Whether to use MaxPool to downsample the
feature map after the first encoder stage (stages [1, num_stages)). If the
corresponding encoder stage uses stride convolution (strides[i]=2), MaxPool
is never used for downsampling, even if downsamples[i-1]=True.
Default: (True, True, True, True).
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
Default: (1, 1, 1, 1, 1).
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
Default: (1, 1, 1, 1).
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict | None): Config dict for convolution layer.
Default: None.
norm_cfg (dict | None): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict | None): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU').
upsample_cfg (dict): The upsample config of the upsample module in
decoder. Default: dict(type='InterpConv').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None.
plugins (dict): plugins for convolutional layers. Default: None.
Notice:
The input image size should be divisible by the whole downsample rate
of the encoder. More detail of the whole downsample rate can be found
in UNet._check_input_divisible.
"""
def __init__(self,
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1),
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
upsample_cfg=dict(type='InterpConv'),
norm_eval=False,
dcn=None,
plugins=None):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \
'The length of strides should be equal to num_stages, '\
f'while the strides is {strides}, the length of '\
f'strides is {len(strides)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_num_convs) == num_stages, \
'The length of enc_num_convs should be equal to num_stages, '\
f'while the enc_num_convs is {enc_num_convs}, the length of '\
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_num_convs) == (num_stages-1), \
'The length of dec_num_convs should be equal to (num_stages-1), '\
f'while the dec_num_convs is {dec_num_convs}, the length of '\
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(downsamples) == (num_stages-1), \
'The length of downsamples should be equal to (num_stages-1), '\
f'while the downsamples is {downsamples}, the length of '\
f'downsamples is {len(downsamples)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_dilations) == num_stages, \
'The length of enc_dilations should be equal to num_stages, '\
f'while the enc_dilations is {enc_dilations}, the length of '\
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_dilations) == (num_stages-1), \
'The length of dec_dilations should be equal to (num_stages-1), '\
f'while the dec_dilations is {dec_dilations}, the length of '\
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
f'{num_stages}.'
self.num_stages = num_stages
self.strides = strides
self.downsamples = downsamples
self.norm_eval = norm_eval
self.base_channels = base_channels
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for i in range(num_stages):
enc_conv_block = []
if i != 0:
if strides[i] == 1 and downsamples[i - 1]:
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
upsample = (strides[i] != 1 or downsamples[i - 1])
self.decoder.append(
UpConvBlock(
conv_block=BasicConvBlock,
in_channels=base_channels * 2**i,
skip_channels=base_channels * 2**(i - 1),
out_channels=base_channels * 2**(i - 1),
num_convs=dec_num_convs[i - 1],
stride=1,
dilation=dec_dilations[i - 1],
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
upsample_cfg=upsample_cfg if upsample else None,
dcn=None,
plugins=None))
enc_conv_block.append(
BasicConvBlock(
in_channels=in_channels,
out_channels=base_channels * 2**i,
num_convs=enc_num_convs[i],
stride=strides[i],
dilation=enc_dilations[i],
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
dcn=None,
plugins=None))
self.encoder.append((nn.Sequential(*enc_conv_block)))
in_channels = base_channels * 2**i
def forward(self, x):
self._check_input_divisible(x)
enc_outs = []
for enc in self.encoder:
x = enc(x)
enc_outs.append(x)
dec_outs = [x]
for i in reversed(range(len(self.decoder))):
x = self.decoder[i](enc_outs[i], x)
dec_outs.append(x)
return dec_outs
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super().train(mode)
if mode and self.norm_eval:
for m in self.modules():
# trick: eval() has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _check_input_divisible(self, x):
h, w = x.shape[-2:]
whole_downsample_rate = 1
for i in range(1, self.num_stages):
if self.strides[i] == 2 or self.downsamples[i - 1]:
whole_downsample_rate *= 2
assert (h % whole_downsample_rate == 0) \
and (w % whole_downsample_rate == 0),\
f'The input image size {(h, w)} should be divisible by the whole '\
f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.'
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
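
A quick shape check for the backbone as configured in the KIE configs (base_channels=16); a sketch assuming `torch` is available and the `UNet` class above is in scope:

```python
import torch

net = UNet(base_channels=16)   # defaults: 5 stages, 4 MaxPool downsamples
x = torch.randn(1, 3, 64, 64)  # 64 is divisible by the 16x whole downsample rate
outs = net(x)
print([tuple(o.shape) for o in outs])
# [(1, 256, 4, 4), (1, 128, 8, 8), (1, 64, 16, 16), (1, 32, 32, 32), (1, 16, 64, 64)]
# The last, full-resolution 16-channel map is what SDMGR's RoI extractor pools from.
```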

View File

@@ -0,0 +1,3 @@
from .extractors import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403

View File

@@ -0,0 +1,3 @@
from .sdmgr import SDMGR
__all__ = ['SDMGR']

View File

@@ -0,0 +1,87 @@
from torch import nn
from torch.nn import functional as F
from mmdet.core import bbox2roi
from mmdet.models.builder import DETECTORS, build_roi_extractor
from mmdet.models.detectors import SingleStageDetector
@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
"""The implementation of the paper: Spatial Dual-Modality Graph Reasoning
for Key Information Extraction. https://arxiv.org/abs/2103.14470.
Args:
visual_modality (bool): Whether to use the visual modality.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7),
featmap_strides=[1]),
visual_modality=False,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)
self.visual_modality = visual_modality
if visual_modality:
self.extractor = build_roi_extractor({
**extractor, 'out_channels':
self.backbone.base_channels
})
self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
else:
self.extractor = None
def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
gt_labels):
"""
Args:
img (tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
contains: 'img_shape', 'scale_factor', 'flip', and may also
contain 'filename', 'ori_shape', 'pad_shape', and
'img_norm_cfg'. For details of the values of these keys,
please see :class:`mmdet.datasets.pipelines.Collect`.
relations (list[tensor]): Relations between bboxes.
texts (list[tensor]): Texts in bboxes.
gt_bboxes (list[tensor]): Each item is the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[tensor]): Class indices corresponding to each box.
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img, gt_bboxes)
node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
def forward_test(self,
img,
img_metas,
relations,
texts,
gt_bboxes,
rescale=False):
x = self.extract_feat(img, gt_bboxes)
node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
return [
dict(
img_metas=img_metas,
nodes=F.softmax(node_preds, -1),
edges=F.softmax(edge_preds, -1))
]
def extract_feat(self, img, gt_bboxes):
if self.visual_modality:
x = super().extract_feat(img)[-1]
feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
return feats.view(feats.size(0), -1)
return None

View File

@@ -0,0 +1,3 @@
from .sdmgr_head import SDMGRHead
__all__ = ['SDMGRHead']

View File

@@ -0,0 +1,193 @@
import torch
from mmcv.cnn import normal_init
from torch import nn
from torch.nn import functional as F
from mmdet.models.builder import HEADS, build_loss
@HEADS.register_module()
class SDMGRHead(nn.Module):
def __init__(self,
num_chars=92,
visual_dim=64,
fusion_dim=1024,
node_input=32,
node_embed=256,
edge_input=5,
edge_embed=256,
num_gnn=2,
num_classes=26,
loss=dict(type='SDMGRLoss'),
bidirectional=False,
train_cfg=None,
test_cfg=None):
super().__init__()
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
self.node_embed = nn.Embedding(num_chars, node_input, 0)
hidden = node_embed // 2 if bidirectional else node_embed
self.rnn = nn.LSTM(
input_size=node_input,
hidden_size=hidden,
num_layers=1,
batch_first=True,
bidirectional=bidirectional)
self.edge_embed = nn.Linear(edge_input, edge_embed)
self.gnn_layers = nn.ModuleList(
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
self.loss = build_loss(loss)
def init_weights(self, pretrained=False):
normal_init(self.edge_embed, mean=0, std=0.01)
def forward(self, relations, texts, x=None):
node_nums, char_nums = [], []
for text in texts:
node_nums.append(text.size(0))
char_nums.append((text > 0).sum(-1))
max_num = max([char_num.max() for char_num in char_nums])
all_nodes = torch.cat([
torch.cat(
[text,
text.new_zeros(text.size(0), max_num - text.size(1))], -1)
for text in texts
])
embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
rnn_nodes, _ = self.rnn(embed_nodes)
nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
all_nums = torch.cat(char_nums)
valid = all_nums > 0
nodes[valid] = rnn_nodes[valid].gather(
1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
-1, -1, rnn_nodes.size(-1))).squeeze(1)
if x is not None:
nodes = self.fusion([x, nodes])
all_edges = torch.cat(
[rel.view(-1, rel.size(-1)) for rel in relations])
embed_edges = self.edge_embed(all_edges.float())
embed_edges = F.normalize(embed_edges)
for gnn_layer in self.gnn_layers:
nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
return node_cls, edge_cls
class GNNLayer(nn.Module):
def __init__(self, node_dim=256, edge_dim=256):
super().__init__()
self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
self.coef_fc = nn.Linear(node_dim, 1)
self.out_fc = nn.Linear(node_dim, node_dim)
self.relu = nn.ReLU()
def forward(self, nodes, edges, nums):
start, cat_nodes = 0, []
for num in nums:
sample_nodes = nodes[start:start + num]
cat_nodes.append(
torch.cat([
sample_nodes.unsqueeze(1).expand(-1, num, -1),
sample_nodes.unsqueeze(0).expand(num, -1, -1)
], -1).view(num**2, -1))
start += num
cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
cat_nodes = self.relu(self.in_fc(cat_nodes))
coefs = self.coef_fc(cat_nodes)
start, residuals = 0, []
for num in nums:
residual = F.softmax(
-torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
coefs[start:start + num**2].view(num, num, -1), 1)
residuals.append(
(residual *
cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
start += num**2
nodes += self.relu(self.out_fc(torch.cat(residuals)))
return nodes, cat_nodes
class Block(nn.Module):
def __init__(self,
input_dims,
output_dim,
mm_dim=1600,
chunks=20,
rank=15,
shared=False,
dropout_input=0.,
dropout_pre_lin=0.,
dropout_output=0.,
pos_norm='before_cat'):
super().__init__()
self.rank = rank
self.dropout_input = dropout_input
self.dropout_pre_lin = dropout_pre_lin
self.dropout_output = dropout_output
assert (pos_norm in ['before_cat', 'after_cat'])
self.pos_norm = pos_norm
# Modules
self.linear0 = nn.Linear(input_dims[0], mm_dim)
self.linear1 = self.linear0 if shared \
else nn.Linear(input_dims[1], mm_dim)
self.merge_linears0, self.merge_linears1 =\
nn.ModuleList(), nn.ModuleList()
self.chunks = self.chunk_sizes(mm_dim, chunks)
for size in self.chunks:
ml0 = nn.Linear(size, size * rank)
self.merge_linears0.append(ml0)
ml1 = ml0 if shared else nn.Linear(size, size * rank)
self.merge_linears1.append(ml1)
self.linear_out = nn.Linear(mm_dim, output_dim)
def forward(self, x):
x0 = self.linear0(x[0])
x1 = self.linear1(x[1])
bs = x1.size(0)
if self.dropout_input > 0:
x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
x0_chunks = torch.split(x0, self.chunks, -1)
x1_chunks = torch.split(x1, self.chunks, -1)
zs = []
for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
self.merge_linears0,
self.merge_linears1):
m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
m = m.view(bs, self.rank, -1)
z = torch.sum(m, 1)
if self.pos_norm == 'before_cat':
z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
z = F.normalize(z)
zs.append(z)
z = torch.cat(zs, 1)
if self.pos_norm == 'after_cat':
z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
z = F.normalize(z)
if self.dropout_pre_lin > 0:
z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
z = self.linear_out(z)
if self.dropout_output > 0:
z = F.dropout(z, p=self.dropout_output, training=self.training)
return z
@staticmethod
def chunk_sizes(dim, chunks):
split_size = (dim + chunks - 1) // chunks
sizes_list = [split_size] * chunks
sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
return sizes_list
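
A toy forward pass through the head above for the text-only branch (`x=None`); a sketch that assumes MMOCR is installed and its KIE modules have been imported so that `SDMGRLoss` is registered for `build_loss` (the import paths are assumptions based on the `__init__` files in this PR):

```python
import torch

import mmocr.models.kie  # noqa: F401 -- registers SDMGRLoss (assumed path)
from mmocr.models.kie.heads import SDMGRHead  # assumed path

head = SDMGRHead()                             # defaults: 92 chars, 26 node classes
texts = [torch.tensor([[3, 5], [7, 0]])]       # 2 boxes, zero-padded char indices
relations = [torch.randn(2, 2, 5)]             # pairwise spatial relations
node_cls, edge_cls = head(relations, texts)    # x=None -> no visual fusion
print(node_cls.shape, edge_cls.shape)          # torch.Size([2, 26]) torch.Size([4, 2])
```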

View File

@@ -0,0 +1,3 @@
from .sdmgr_loss import SDMGRLoss
__all__ = ['SDMGRLoss']

View File

@@ -0,0 +1,39 @@
import torch
from torch import nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses import accuracy
@LOSSES.register_module()
class SDMGRLoss(nn.Module):
"""The implementation the loss of key information extraction proposed in
the paper: Spatial Dual-Modality Graph Reasoning for Key Information
Extraction.
https://arxiv.org/abs/2103.14470.
"""
def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
super().__init__()
self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
self.node_weight = node_weight
self.edge_weight = edge_weight
self.ignore = ignore
def forward(self, node_preds, edge_preds, gts):
node_gts, edge_gts = [], []
for gt in gts:
node_gts.append(gt[:, 0])
edge_gts.append(gt[:, 1:].contiguous().view(-1))
node_gts = torch.cat(node_gts).long()
edge_gts = torch.cat(edge_gts).long()
node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
edge_valids = torch.nonzero(edge_gts != -1).view(-1)
return dict(
loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))
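
A toy sanity check of the loss above, using the ground-truth layout produced by `KIEDataset.ann_numpy` (column 0 is the node label; the remaining columns are the pairwise edge labels with -1 on the diagonal). Assumes the `SDMGRLoss` class above is in scope:

```python
import torch

loss_fn = SDMGRLoss()
node_preds = torch.randn(2, 26)        # 2 boxes, 26 node classes
edge_preds = torch.randn(4, 2)         # 2*2 directed pairs, 2 edge classes
gts = [torch.tensor([[1, -1, 1],
                     [2, 1, -1]])]     # node labels 1, 2; the two boxes linked
print(loss_fn(node_preds, edge_preds, gts))
# dict with loss_node, loss_edge, acc_node, acc_edge
```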