mirror of https://github.com/open-mmlab/mmocr.git
[feature]: add code for kie and textsnake config
parent 3ed6aaa4e4
commit b8156a3a77
@@ -0,0 +1,25 @@
# Spatial Dual-Modality Graph Reasoning for Key Information Extraction

## Introduction

[ALGORITHM]

```bibtex
@misc{sun2021spatial,
    title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
    author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
    year={2021},
    eprint={2103.14470},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

## Results and models

### WildReceipt

| Method | Modality | Macro F1-Score | Download |
| :----: | :------: | :------------: | :------: |
| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.880 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.871 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
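A minimal sketch of inspecting one of the configs listed above. It assumes the MMCV 1.x `Config.fromfile` API that these configs are written for and that the repository root is the working directory; the printed values are taken from the config file that follows.

```python
# Sketch only: assumes mmcv 1.x is installed and the repo root is the CWD.
from mmcv import Config

cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py')
print(cfg.model.type)            # 'SDMGR'
print(cfg.data.samples_per_gpu)  # 4
print(cfg.total_epochs)          # 60
```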
@@ -0,0 +1,99 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]

vocab_file = 'dict.txt'
class_file = 'class_list.txt'

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        ann_file='train.txt',
        pipeline=train_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file),
    val=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file),
    test=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file))

evaluation = dict(
    interval=1,
    metric='macro_f1',
    metric_options=dict(
        macro_f1=dict(
            ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))

model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
    visual_modality=False,
    train_cfg=None,
    test_cfg=None)

optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1,
    step=[40, 50])
total_epochs = 60

checkpoint_config = dict(interval=1)
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(
        #     type='PaviLoggerHook',
        #     add_last_ckpt=True,
        #     interval=5,
        #     init_kwargs=dict(project='kie')),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
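A small sketch of the learning-rate schedule implied by `lr_config` above, assuming mmcv's `StepLrUpdaterHook` semantics with its default `gamma=0.1`; with `warmup_iters=1` and `warmup_ratio=1` the linear warmup is effectively a no-op.

```python
def lr_at_epoch(base_lr, epoch, steps=(40, 50), gamma=0.1):
    """Piecewise-constant LR: decay by `gamma` at each epoch in `steps`."""
    return base_lr * gamma ** sum(epoch >= s for s in steps)

# e.g. lr_at_epoch(1e-3, 10) -> 1e-3, lr_at_epoch(1e-3, 45) -> 1e-4,
#      lr_at_epoch(1e-3, 55) -> 1e-5
```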
@@ -0,0 +1,99 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]

vocab_file = 'dict.txt'
class_file = 'class_list.txt'

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        ann_file='train.txt',
        pipeline=train_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file),
    val=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file),
    test=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file))

evaluation = dict(
    interval=1,
    metric='macro_f1',
    metric_options=dict(
        macro_f1=dict(
            ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))

model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
    visual_modality=True,
    train_cfg=None,
    test_cfg=None)

optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1,
    step=[40, 50])
total_epochs = 60

checkpoint_config = dict(interval=1)
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(
        #     type='PaviLoggerHook',
        #     add_last_ckpt=True,
        #     interval=5,
        #     init_kwargs=dict(project='kie')),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
@@ -0,0 +1,23 @@
# TextSnake

## Introduction

[ALGORITHM]

```bibtex
@inproceedings{long2018textsnake,
    title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes},
    author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong},
    booktitle={ECCV},
    pages={20-36},
    year={2018}
}
```

## Results and models

### CTW1500

| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :----: | :--------------: | :----------: | :------: | :-----: | :-------: | :----: | :-------: | :---: | :------: |
| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_600e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [config](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) |
@@ -0,0 +1,113 @@
_base_ = [
    '../../_base_/schedules/schedule_1200e.py',
    '../../_base_/default_runtime.py'
]
model = dict(
    type='TextSnake',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='caffe'),
    neck=dict(
        type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
    bbox_head=dict(
        type='TextSnakeHead',
        in_channels=32,
        text_repr_type='poly',
        loss=dict(type='TextSnakeLoss')),
    train_cfg=None,
    test_cfg=None)

dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadTextAnnotations',
        with_bbox=True,
        with_mask=True,
        poly2mask=False),
    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(
        type='RandomCropPolyInstances',
        instance_key='gt_masks',
        crop_ratio=0.65,
        min_side_ratio=0.3),
    dict(
        type='RandomRotatePolyInstances',
        rotate_ratio=0.5,
        max_angle=20,
        pad_with_fixed_color=False),
    dict(
        type='ScaleAspectJitter',
        img_scale=[(3000, 736)],  # unused
        ratio_range=(0.7, 1.3),
        aspect_ratio_range=(0.9, 1.1),
        multiscale_mode='value',
        long_size_bound=800,
        short_size_bound=480,
        resize_type='long_short_bound',
        keep_ratio=False),
    dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
    dict(type='TextSnakeTargets'),
    dict(type='Pad', size_divisor=32),
    dict(
        type='CustomFormatBundle',
        keys=[
            'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
        ],
        visualize=dict(flag=False, boundary_key='gt_text_mask')),
    dict(
        type='Collect',
        keys=[
            'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
        ])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 736),
        flip=False,
        transforms=[
            dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=data_root + '/instances_training.json',
        img_prefix=data_root + '/imgs',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + '/instances_test.json',
        img_prefix=data_root + '/imgs',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + '/instances_test.json',
        img_prefix=data_root + '/imgs',
        pipeline=test_pipeline))

evaluation = dict(interval=10, metric='hmean-iou')
@@ -0,0 +1,27 @@
import torch


def compute_f1_score(preds, gts, ignores=[]):
    """Compute the F1-score of prediction.

    Args:
        preds (Tensor): The predicted probability NxC map
            with N and C being the sample number and class
            number respectively.
        gts (Tensor): The ground truth vector of size N.
        ignores (list): The index set of classes that are ignored when
            reporting results.
            Note: all samples are involved in the computation.

    Returns:
        The numpy list of f1-scores of valid classes.
    """
    C = preds.size(1)
    classes = torch.LongTensor(sorted(set(range(C)) - set(ignores)))
    hist = torch.bincount(
        gts * C + preds.argmax(1), minlength=C**2).view(C, C).float()
    diag = torch.diag(hist)
    recalls = diag / hist.sum(1).clamp(min=1)
    precisions = diag / hist.sum(0).clamp(min=1)
    f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)
    return f1[classes].cpu().numpy()
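A quick usage sketch with synthetic tensors (not from the repo): `hist` is the confusion matrix with ground-truth classes as rows and predicted classes as columns, and the returned array holds one F1 value per class that is not in `ignores`.

```python
import torch

# 6 samples, 3 classes: preds are class scores, gts are class indices.
preds = torch.rand(6, 3)
gts = torch.tensor([0, 1, 2, 1, 0, 2])
f1_per_class = compute_f1_score(preds, gts, ignores=[0])  # drop class 0
print(f1_per_class.shape)   # (2,) -> F1 for classes 1 and 2
print(f1_per_class.mean())  # the macro-F1 used by KIEDataset.evaluate
```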
@@ -0,0 +1,295 @@
import copy
from os import path as osp

import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmocr.core import compute_f1_score


@DATASETS.register_module()
class KIEDataset(CustomDataset):

    def __init__(self,
                 ann_file,
                 pipeline=None,
                 data_root=None,
                 img_prefix='',
                 ann_prefix='',
                 vocab_file=None,
                 class_file=None,
                 norm=10.,
                 thresholds=dict(edge=0.5),
                 directed=False,
                 **kwargs):
        self.ann_prefix = ann_prefix
        self.norm = norm
        self.thresholds = thresholds
        self.directed = directed

        if data_root is not None:
            if not osp.isabs(ann_file):
                self.ann_file = osp.join(data_root, ann_file)
            if not (ann_prefix is None or osp.isabs(ann_prefix)):
                self.ann_prefix = osp.join(data_root, ann_prefix)

        self.vocab = dict({'': 0})
        vocab_file = osp.join(data_root, vocab_file)
        if osp.exists(vocab_file):
            with open(vocab_file, 'r') as fid:
                for idx, char in enumerate(fid.readlines(), 1):
                    self.vocab[char.strip('\n')] = idx
        else:
            self.construct_dict(self.ann_file)
            with open(vocab_file, 'w') as fid:
                for key in self.vocab:
                    if key:
                        fid.write('{}\n'.format(key))

        super().__init__(
            ann_file,
            pipeline,
            data_root=data_root,
            img_prefix=img_prefix,
            **kwargs)

        self.idx_to_cls = dict()
        with open(osp.join(data_root, class_file), 'r') as fid:
            for line in fid.readlines():
                idx, cls = line.split()
                self.idx_to_cls[int(idx)] = cls

    @staticmethod
    def _split_edge(line):
        text = ','.join(line[8:-1])
        if ';' in text and text.split(';')[0].isdecimal():
            edge, text = text.split(';', 1)
            edge = int(edge)
        else:
            edge = 0
        return edge, text

    def construct_dict(self, ann_file):
        img_infos = mmcv.list_from_file(ann_file)
        for img_info in img_infos:
            _, annname = img_info.split()
            if self.ann_prefix:
                annname = osp.join(self.ann_prefix, annname)
            with open(annname, 'r') as fid:
                lines = fid.readlines()

            for line in lines:
                line = line.strip().split(',')
                _, text = self._split_edge(line)
                for c in text:
                    if c not in self.vocab:
                        self.vocab[c] = len(self.vocab)
        self.vocab = dict(
            {k: idx
             for idx, k in enumerate(sorted(self.vocab.keys()))})

    def convert_text(self, text):
        return [self.vocab[c] for c in text if c in self.vocab]

    def parse_lines(self, annname):
        boxes, edges, texts, chars, labels = [], [], [], [], []

        if self.ann_prefix:
            annname = osp.join(self.ann_prefix, annname)

        with open(annname, 'r') as fid:
            for line in fid.readlines():
                line = line.strip().split(',')
                boxes.append(list(map(int, line[:8])))
                edge, text = self._split_edge(line)
                chars.append(text)
                text = self.convert_text(text)
                texts.append(text)
                edges.append(edge)
                labels.append(int(line[-1]))
        return dict(
            boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)

    def format_results(self, results):
        boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].cuda()

        if 'nodes' in results:
            nodes, edges = results['nodes'], results['edges']
            labels = nodes.argmax(-1)
            num_nodes = nodes.size(0)
            edges = edges[:, -1].view(num_nodes, num_nodes)
        else:
            labels = torch.Tensor(results['labels']).cuda()
            edges = torch.Tensor(results['edges']).cuda()
        boxes = torch.cat([boxes, labels[:, None].float()], -1)

        return {
            **{
                k: v
                for k, v in results.items() if k not in ['boxes', 'edges']
            }, 'boxes': boxes,
            'edges': edges,
            'points': results['boxes']
        }

    def plot(self, results):
        img_name = osp.join(self.img_prefix, results['filename'])
        img = plt.imread(img_name)
        plt.imshow(img)

        boxes, texts = results['points'], results['chars']
        num_nodes = len(boxes)
        if 'scores' in results:
            scores = results['scores']
        else:
            scores = np.ones(num_nodes)
        for box, text, score in zip(boxes, texts, scores):
            xs, ys = [], []
            for idx in range(0, 10, 2):
                xs.append(box[idx % 8])
                ys.append(box[(idx + 1) % 8])
            plt.plot(xs, ys, 'g')
            plt.annotate(
                '{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')

        if 'nodes' in results:
            nodes = results['nodes']
            inds = nodes.argmax(-1)
        else:
            nodes = np.ones((num_nodes, 3))
            inds = results['labels']
        for i in range(num_nodes):
            plt.annotate(
                '{}: {:.4f}'.format(
                    # idx_to_cls is a dict, so index it rather than call it
                    self.idx_to_cls[inds[i] - 1], nodes[i, inds[i]]),
                (boxes[i][6], boxes[i][7]),
                color='r' if inds[i] == 1 else 'b')
            edges = results['edges']
            if 'nodes' not in results:
                edges = (edges[:, None] == edges[None]).float()
            for j in range(i + 1, num_nodes):
                edge_score = max(edges[i][j], edges[j][i])
                if edge_score > self.thresholds['edge']:
                    x1 = sum(boxes[i][:3:2]) // 2
                    y1 = sum(boxes[i][3:6:2]) // 2
                    x2 = sum(boxes[j][:3:2]) // 2
                    y2 = sum(boxes[j][3:6:2]) // 2
                    plt.plot((x1, x2), (y1, y2), 'r')
                    plt.annotate(
                        '{:.4f}'.format(edge_score),
                        ((x1 + x2) // 2, (y1 + y2) // 2),
                        color='r')

    def compute_relation(self, boxes):
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def ann_numpy(self, results):
        boxes, texts = results['boxes'], results['texts']
        boxes = np.array(boxes, np.int32)
        if boxes[0, 1] > boxes[0, -1]:
            boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
        relations, bboxes = self.compute_relation(boxes)

        labels = results.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = results.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)
        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=self.pad_text(texts),
            labels=labels)

    def image_size(self, filename):
        img_path = osp.join(self.img_prefix, filename)
        img = Image.open(img_path)
        return img.size

    def load_annotations(self, ann_file):
        self.anns, data_infos = [], []

        self.gts = dict()
        img_infos = mmcv.list_from_file(ann_file)
        for img_info in img_infos:
            filename, annname = img_info.split()
            results = self.parse_lines(annname)
            width, height = self.image_size(filename)

            data_infos.append(
                dict(filename=filename, width=width, height=height))
            ann = self.ann_numpy(results)
            self.anns.append(ann)

        return data_infos

    def pad_text(self, texts):
        max_len = max([len(text) for text in texts])
        padded_texts = -np.ones((len(texts), max_len), np.int32)
        for idx, text in enumerate(texts):
            padded_texts[idx, :len(text)] = np.array(text)
        return padded_texts

    def get_ann_info(self, idx):
        return self.anns[idx]

    def prepare_test_img(self, idx):
        return self.prepare_train_img(idx)

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=dict(macro_f1=dict(ignores=[])),
                 **kwargs):
        # allow some kwargs to pass through
        assert set(kwargs).issubset(['logger'])

        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results, **metric_options['macro_f1'])

    def compute_macro_f1(self, results, ignores=[]):
        node_preds = []
        for result in results:
            node_preds.append(result['nodes'])
        node_preds = torch.cat(node_preds)

        node_gts = [
            torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
            for ann in self.anns
        ]
        node_gts = torch.cat(node_gts)

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }
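A toy sketch of the 5-dimensional pairwise features built by `KIEDataset.compute_relation`, mirrored here with plain NumPy on two synthetic boxes and the default `norm=10.`: normalized x/y offsets between box corners, the aspect ratio of each box, and relative height/width ratios.

```python
import numpy as np

# Two synthetic axis-aligned boxes as 8-point quads (x1, y1, ..., x4, y4).
boxes = np.array([[10, 10, 60, 10, 60, 30, 10, 30],
                  [20, 40, 90, 40, 90, 70, 20, 70]], np.float32)
norm = 10.  # the dataset default

x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
dxs = (x1s[:, 0][None] - x1s) / norm   # dxs[i, j] = (x1_j - x1_i) / norm
dys = (y1s[:, 0][None] - y1s) / norm
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs  # h_j/h_i and w_j/h_i
whs = ws / hs + np.zeros_like(xhhs)    # aspect ratio of box i, tiled per pair
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
print(relations.shape)  # (2, 2, 5): one geometric feature vector per box pair
```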
@@ -0,0 +1,55 @@
import numpy as np
from mmcv.parallel import DataContainer as DC

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor


@PIPELINES.register_module()
class KIEFormatBundle(DefaultFormatBundle):
    """Key information extraction formatting bundle.

    Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
    common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
    "gt_masks", "gt_semantic_seg", "relations" and "texts".
    These fields are formatted as follows.

    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
    - proposals: (1) to tensor, (2) to DataContainer
    - gt_bboxes: (1) to tensor, (2) to DataContainer
    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
    - gt_labels: (1) to tensor, (2) to DataContainer
    - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \
                       (3) to DataContainer (stack=True)
    - relations: (1) scale, (2) to tensor, (3) to DataContainer
    - texts: (1) to tensor, (2) to DataContainer
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        super().__call__(results)
        if 'ann_info' in results:
            for key in ['relations', 'texts']:
                value = results['ann_info'][key]
                if key == 'relations' and 'scale_factor' in results:
                    scale_factor = results['scale_factor']
                    if isinstance(scale_factor, float):
                        sx = sy = scale_factor
                    else:
                        sx, sy = results['scale_factor'][:2]
                    r = sx / sy
                    value = value * np.array([sx, sy, r, 1, r])[None, None]
                results[key] = DC(to_tensor(value))
        return results

    def __repr__(self):
        return self.__class__.__name__
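A small sketch, with synthetic values, of the relation rescaling performed above: after a resize with scale factors `(sx, sy)`, the dx/dy offsets are scaled by `sx` and `sy`, and the width-over-height terms by `sx / sy`, so the 5-d features stay consistent with the resized image geometry.

```python
import numpy as np

sx, sy = 0.5, 0.25                # hypothetical Resize scale factors
r = sx / sy
relations = np.ones((3, 3, 5))    # toy (N, N, 5) relation features
rescaled = relations * np.array([sx, sy, r, 1, r])[None, None]
print(rescaled[0, 0])             # [0.5, 0.25, 2., 1., 2.]
```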
@@ -0,0 +1,3 @@
from .unet import UNet

__all__ = ['UNet']
@@ -0,0 +1,528 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
                      build_norm_layer, build_upsample_layer, constant_init,
                      kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmdet.models.builder import BACKBONES
from mmdet.utils import get_root_logger


class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (nn.Sequential): Sequential of convolutional layers.
        in_channels (int): Number of input channels of the high-level
            (low-resolution) feature map from decoder.
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv'). If the size of
            high-level feature map is the same as that of skip feature map
            (low-level feature map from encoder), it does not need upsample the
            high-level feature map and the upsample_cfg is None.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function."""

        x = self.upsample(x)
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)

        return out


class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    This module consists of several plain convolutional layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Whether use stride convolution to downsample
            the input feature map. If stride=2, it only uses stride convolution
            in the first convolutional layer to downsample the input feature
            map. Options are 1 or 2. Default: 1.
        dilation (int): Whether use dilated convolution to expand the
            receptive field. Set dilation rate of each convolutional layer and
            the dilation rate of the first convolutional layer is always 1.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        convs = []
        for i in range(num_convs):
            convs.append(
                ConvModule(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    dilation=1 if i == 0 else dilation,
                    padding=1 if i == 0 else dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.convs, x)
        else:
            out = self.convs(x)
        return out


@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    This module uses deconvolution to upsample feature map in the decoder
    of UNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()

        assert (kernel_size - scale_factor >= 0) and\
            (kernel_size - scale_factor) % 2 == 0,\
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out


@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 1.
        upsample_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super().__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsample_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out


@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.

    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            in_channels = base_channels * 2**i

    def forward(self, x):
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _check_input_divisible(self, x):
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
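A minimal shape-check sketch for the backbone as configured in the SDMGR configs above (`base_channels=16`, default strides and downsamples, so the encoder downsamples by 2**4 = 16 and input sizes must be divisible by 16). This is a standalone assumption-laden example, not part of the repo, and supposes `UNet` is importable from the module above.

```python
import torch

net = UNet(base_channels=16)
net.init_weights()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 64, 64))
# forward() returns one decoder output per stage, coarsest first:
print([tuple(o.shape) for o in outs])
# [(1, 256, 4, 4), (1, 128, 8, 8), (1, 64, 16, 16),
#  (1, 32, 32, 32), (1, 16, 64, 64)]
```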
@@ -0,0 +1,3 @@
from .extractors import *  # noqa: F401,F403
from .heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
@@ -0,0 +1,3 @@
from .sdmgr import SDMGR

__all__ = ['SDMGR']
@@ -0,0 +1,87 @@
from torch import nn
from torch.nn import functional as F

from mmdet.core import bbox2roi
from mmdet.models.builder import DETECTORS, build_roi_extractor
from mmdet.models.detectors import SingleStageDetector


@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
    for Key Information Extraction. https://arxiv.org/abs/2103.14470.

    Args:
        visual_modality (bool): Whether to use the visual modality.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(type='RoIAlign', output_size=7),
                     featmap_strides=[1]),
                 visual_modality=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained)
        self.visual_modality = visual_modality
        if visual_modality:
            self.extractor = build_roi_extractor({
                **extractor, 'out_channels':
                self.backbone.base_channels
            })
            self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
        else:
            self.extractor = None

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details of the values of these keys,
                please see :class:`mmdet.datasets.pipelines.Collect`.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Each item is the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)

    def forward_test(self,
                     img,
                     img_metas,
                     relations,
                     texts,
                     gt_bboxes,
                     rescale=False):
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return [
            dict(
                img_metas=img_metas,
                nodes=F.softmax(node_preds, -1),
                edges=F.softmax(edge_preds, -1))
        ]

    def extract_feat(self, img, gt_bboxes):
        if self.visual_modality:
            x = super().extract_feat(img)[-1]
            feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
            return feats.view(feats.size(0), -1)
        return None
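A sketch of consuming the test output above, with hypothetical variable names and assuming a single image per batch: `forward_test` returns a one-element list whose `nodes`/`edges` entries are softmax scores, so per-box categories and linked box pairs can be read off directly.

```python
# `model`, `img`, `img_metas`, `relations`, `texts`, `gt_bboxes` are assumed
# to come from a test data loader; this mirrors KIEDataset.format_results.
result = model.forward_test(img, img_metas, relations, texts, gt_bboxes)[0]
node_scores, edge_scores = result['nodes'], result['edges']

pred_classes = node_scores.argmax(-1)       # one WildReceipt class per box
num_nodes = node_scores.size(0)
# edge_scores has one row per ordered box pair; the last column is the
# "linked" probability (only reshapeable like this for a single image).
link_probs = edge_scores[:, -1].view(num_nodes, num_nodes)
linked_pairs = (link_probs > 0.5).nonzero()
```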
@@ -0,0 +1,3 @@
from .sdmgr_head import SDMGRHead

__all__ = ['SDMGRHead']
@ -0,0 +1,193 @@
|
||||||
|
import torch
|
||||||
|
from mmcv.cnn import normal_init
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from mmdet.models.builder import HEADS, build_loss
|
||||||
|
|
||||||
|
|
||||||
|
@HEADS.register_module()
|
||||||
|
class SDMGRHead(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_chars=92,
|
||||||
|
visual_dim=64,
|
||||||
|
fusion_dim=1024,
|
||||||
|
node_input=32,
|
||||||
|
node_embed=256,
|
||||||
|
edge_input=5,
|
||||||
|
edge_embed=256,
|
||||||
|
num_gnn=2,
|
||||||
|
num_classes=26,
|
||||||
|
loss=dict(type='SDMGRLoss'),
|
||||||
|
bidirectional=False,
|
||||||
|
train_cfg=None,
|
||||||
|
test_cfg=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
|
||||||
|
self.node_embed = nn.Embedding(num_chars, node_input, 0)
|
||||||
|
hidden = node_embed // 2 if bidirectional else node_embed
|
||||||
|
self.rnn = nn.LSTM(
|
||||||
|
input_size=node_input,
|
||||||
|
hidden_size=hidden,
|
||||||
|
num_layers=1,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=bidirectional)
|
||||||
|
self.edge_embed = nn.Linear(edge_input, edge_embed)
|
||||||
|
self.gnn_layers = nn.ModuleList(
|
||||||
|
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
|
||||||
|
self.node_cls = nn.Linear(node_embed, num_classes)
|
||||||
|
self.edge_cls = nn.Linear(edge_embed, 2)
|
||||||
|
self.loss = build_loss(loss)
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=False):
|
||||||
|
normal_init(self.edge_embed, mean=0, std=0.01)
|
||||||
|
|
||||||
|

    def forward(self, relations, texts, x=None):
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.size(0))
            char_nums.append((text > 0).sum(-1))

        # Pad every text to a common width, then embed and encode the
        # characters with the LSTM.
        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = torch.cat([
            torch.cat(
                [text,
                 text.new_zeros(text.size(0), max_num - text.size(1))], -1)
            for text in texts
        ])
        embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
        rnn_nodes, _ = self.rnn(embed_nodes)

        # Use the LSTM output at the last valid (non-padding) character of
        # each text as that box's node feature.
        nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
        all_nums = torch.cat(char_nums)
        valid = all_nums > 0
        nodes[valid] = rnn_nodes[valid].gather(
            1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
                -1, -1, rnn_nodes.size(-1))).squeeze(1)

        # Fuse visual RoI features with the textual node features if given.
        if x is not None:
            nodes = self.fusion([x, nodes])

        all_edges = torch.cat(
            [rel.view(-1, rel.size(-1)) for rel in relations])
        embed_edges = self.edge_embed(all_edges.float())
        embed_edges = F.normalize(embed_edges)

        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls
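The least obvious step above is the reduction of the per-character LSTM outputs to one vector per text box: the output at the last non-padding character is gathered out along the time dimension. A tiny self-contained illustration of that indexing, with made-up dimensions that are not tied to the model's defaults:

```python
import torch

rnn_out = torch.arange(24, dtype=torch.float).view(2, 4, 3)  # (boxes, chars, feat)
char_nums = torch.tensor([2, 4])  # valid characters per box

idx = (char_nums - 1).view(-1, 1, 1).expand(-1, 1, rnn_out.size(-1))
last_valid = rnn_out.gather(1, idx).squeeze(1)  # (boxes, feat)
# last_valid[0] == rnn_out[0, 1] and last_valid[1] == rnn_out[1, 3]
print(last_valid)
```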
class GNNLayer(nn.Module):

    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        # Build one feature per ordered node pair (i, j) by concatenating the
        # two node features with the corresponding edge embedding.
        start, cat_nodes = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            cat_nodes.append(
                torch.cat([
                    sample_nodes.unsqueeze(1).expand(-1, num, -1),
                    sample_nodes.unsqueeze(0).expand(num, -1, -1)
                ], -1).view(num**2, -1))
            start += num
        cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        # Each node aggregates the pair features that involve it, weighted by
        # a softmax over the coefficients; self-pairs on the diagonal are
        # masked out with a large negative value before the softmax.
        start, residuals = 0, []
        for num in nums:
            residual = F.softmax(
                -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
                coefs[start:start + num**2].view(num, num, -1), 1)
            residuals.append(
                (residual *
                 cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
            start += num**2

        nodes += self.relu(self.out_fc(torch.cat(residuals)))
        return nodes, cat_nodes
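For a batch whose samples contain n_1, ..., n_k boxes, the layer expects the stacked node features to have sum(n_i) rows and the edge embeddings sum(n_i^2) rows, in the same per-sample, row-major (i, j) order used when `relations` are flattened in `SDMGRHead.forward`. A quick shape check under that assumption; it presumes the `GNNLayer` class above has been copied into the session, and the tiny dimensions are chosen only for the example:

```python
import torch

node_dim, edge_dim = 16, 16
nums = [2, 3]                                            # boxes per image
nodes = torch.rand(sum(nums), node_dim)                  # 5 x 16
edges = torch.rand(sum(n * n for n in nums), edge_dim)   # 13 x 16

layer = GNNLayer(node_dim, edge_dim)   # class defined above
new_nodes, pair_feats = layer(nodes, edges, nums)
print(new_nodes.shape, pair_feats.shape)  # (5, 16), (13, 16)
```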
class Block(nn.Module):

    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert (pos_norm in ['before_cat', 'after_cat'])
        self.pos_norm = pos_norm
        # Modules
        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        self.linear1 = self.linear0 if shared \
            else nn.Linear(input_dims[1], mm_dim)
        self.merge_linears0, self.merge_linears1 =\
            nn.ModuleList(), nn.ModuleList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for size in self.chunks:
            ml0 = nn.Linear(size, size * rank)
            self.merge_linears0.append(ml0)
            ml1 = ml0 if shared else nn.Linear(size, size * rank)
            self.merge_linears1.append(ml1)
        self.linear_out = nn.Linear(mm_dim, output_dim)

    def forward(self, x):
        x0 = self.linear0(x[0])
        x1 = self.linear1(x[1])
        bs = x1.size(0)
        if self.dropout_input > 0:
            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
        x0_chunks = torch.split(x0, self.chunks, -1)
        x1_chunks = torch.split(x1, self.chunks, -1)
        zs = []
        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
                                      self.merge_linears0,
                                      self.merge_linears1):
            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
            m = m.view(bs, self.rank, -1)
            z = torch.sum(m, 1)
            if self.pos_norm == 'before_cat':
                z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
                z = F.normalize(z)
            zs.append(z)
        z = torch.cat(zs, 1)
        if self.pos_norm == 'after_cat':
            z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
            z = F.normalize(z)

        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z

    @staticmethod
    def chunk_sizes(dim, chunks):
        split_size = (dim + chunks - 1) // chunks
        sizes_list = [split_size] * chunks
        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
        return sizes_list
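`Block` is the fusion module wired into `SDMGRHead` as `self.fusion`: each modality is projected to `mm_dim`, split into chunks, fused chunk-wise through a rank-`rank` product, signed-square-root normalized (with the default `pos_norm`), and projected to `output_dim`. A small sketch exercising it on its own; it assumes the `Block` class above is in scope and uses the head's default dimensions only for illustration:

```python
import torch

visual_dim, node_embed, fusion_dim = 64, 256, 1024
fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)  # as in SDMGRHead

n_boxes = 7
roi_feats = torch.rand(n_boxes, visual_dim)    # pooled visual features per box
text_feats = torch.rand(n_boxes, node_embed)   # LSTM node features per box
fused = fusion([roi_feats, text_feats])
print(fused.shape)  # torch.Size([7, 256])

# chunk_sizes splits mm_dim as evenly as possible, e.g. 1024 over 20 chunks:
print(Block.chunk_sizes(1024, 20))  # 19 chunks of 52 and one of 36
```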
@ -0,0 +1,3 @@
from .sdmgr_loss import SDMGRLoss

__all__ = ['SDMGRLoss']
@ -0,0 +1,39 @@
import torch
from torch import nn

from mmdet.models.builder import LOSSES
from mmdet.models.losses import accuracy


@LOSSES.register_module()
class SDMGRLoss(nn.Module):
    """The implementation of the loss for key information extraction proposed
    in the paper: Spatial Dual-Modality Graph Reasoning for Key Information
    Extraction.

    https://arxiv.org/abs/2103.14470.
    """

    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
        super().__init__()
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.node_weight = node_weight
        self.edge_weight = edge_weight
        self.ignore = ignore

    def forward(self, node_preds, edge_preds, gts):
        # Each gt tensor carries, per box, the node class in column 0 and the
        # edge labels to the other boxes in the remaining columns.
        node_gts, edge_gts = [], []
        for gt in gts:
            node_gts.append(gt[:, 0])
            edge_gts.append(gt[:, 1:].contiguous().view(-1))
        node_gts = torch.cat(node_gts).long()
        edge_gts = torch.cat(edge_gts).long()

        node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
        edge_valids = torch.nonzero(edge_gts != -1).view(-1)
        return dict(
            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))
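A minimal sketch of the expected shapes, assuming the `SDMGRLoss` class above (and its mmdet dependencies) can be imported or pasted into the session: for an image with `n` boxes, `node_preds` is `(n, num_classes)`, `edge_preds` is `(n * n, 2)`, and each ground-truth tensor is `(n, 1 + n)` with the node class in column 0 and the 0/1 link labels after it; with the default `ignore=0`, node class 0 is excluded from the node loss and accuracy.

```python
import torch

n, num_classes = 3, 26
node_preds = torch.rand(n, num_classes)
edge_preds = torch.rand(n * n, 2)
gt = torch.zeros(n, 1 + n, dtype=torch.long)
gt[:, 0] = torch.tensor([1, 5, 12])   # node classes (0 would be ignored)
gt[0, 2] = 1                          # mark one box pair as linked

criterion = SDMGRLoss()               # class defined above
losses = criterion(node_preds, edge_preds, [gt])
print(sorted(losses))  # ['acc_edge', 'acc_node', 'loss_edge', 'loss_node']
```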