mirror of https://github.com/open-mmlab/mmocr.git
Merge pull request #4 from cuhk-hbsun/feature/kie
[feature]: add code for kie and textsnake configpull/2/head
commit
50ab4ef23d
|
@ -0,0 +1,25 @@
|
|||
# Spatial Dual-Modality Graph Reasoning for Key Information Extraction
|
||||
|
||||
## Introduction
|
||||
|
||||
[ALGORITHM]
|
||||
|
||||
```bibtex
|
||||
@misc{sun2021spatial,
|
||||
title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
|
||||
author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
|
||||
year={2021},
|
||||
eprint={2103.14470},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### WildReceipt
|
||||
|
||||
| Method | Modality | Macro F1-Score | Download |
|
||||
| :--------------------------------------------------------------------: | :--------------: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.880 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
|
||||
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.871 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) |
|
|
@ -0,0 +1,99 @@
|
|||
# SDMGR key information extraction config for the WildReceipt dataset.
# ``visual_modality`` is switched off in this variant.

# ------------------------------- dataset -------------------------------
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
vocab_file = 'dict.txt'
class_file = 'class_list.txt'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)
max_scale = 1024
min_scale = 512

# The train and test pipelines share every step except the keys that the
# final ``Collect`` step gathers.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(max_scale, min_scale),
        keep_ratio=True,
    ),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'],
    ),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(max_scale, min_scale),
        keep_ratio=True,
    ),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes'],
    ),
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        ann_file='train.txt',
        pipeline=train_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
    # Validation and testing both read the ``test.txt`` split.
    val=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
    test=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
)

# Class ids left out of the macro F1: every even id up to 24, plus 25.
evaluation = dict(
    interval=1,
    metric='macro_f1',
    metric_options=dict(
        macro_f1=dict(ignores=list(range(0, 25, 2)) + [25]),
    ),
)

# -------------------------------- model --------------------------------
model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        type='SDMGRHead',
        visual_dim=16,
        num_chars=92,
        num_classes=26,
    ),
    visual_modality=False,
    train_cfg=None,
    test_cfg=None,
)

# ------------------------------ schedule -------------------------------
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1,
    step=[40, 50],
)
total_epochs = 60

# ------------------------------- runtime -------------------------------
checkpoint_config = dict(interval=1)
log_config = dict(
    interval=50,
    # Only the text logger is enabled; further logger hooks (for example
    # a ``PaviLoggerHook``) can be appended to this list.
    hooks=[
        dict(type='TextLoggerHook'),
    ],
)
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
|
|
@ -0,0 +1,99 @@
|
|||
# SDMGR key information extraction config for the WildReceipt dataset.
# ``visual_modality`` is switched on in this variant.

# -------------------------------- model --------------------------------
model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        type='SDMGRHead',
        visual_dim=16,
        num_chars=92,
        num_classes=26,
    ),
    visual_modality=True,
    train_cfg=None,
    test_cfg=None,
)

# ------------------------------- dataset -------------------------------
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
vocab_file = 'dict.txt'
class_file = 'class_list.txt'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)
max_scale = 1024
min_scale = 512

# Both pipelines run the same steps; only the keys gathered by the last
# ``Collect`` step differ (``gt_labels`` is collected for training only).
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(max_scale, min_scale),
        keep_ratio=True,
    ),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'],
    ),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(max_scale, min_scale),
        keep_ratio=True,
    ),
    dict(type='RandomFlip', flip_ratio=0.),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes'],
    ),
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        ann_file='train.txt',
        pipeline=train_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
    # ``val`` and ``test`` are configured identically on ``test.txt``.
    val=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
    test=dict(
        type=dataset_type,
        ann_file='test.txt',
        pipeline=test_pipeline,
        data_root=data_root,
        vocab_file=vocab_file,
        class_file=class_file,
    ),
)

# Class ids excluded from the macro F1: even ids 0..24 and id 25.
evaluation = dict(
    interval=1,
    metric='macro_f1',
    metric_options=dict(
        macro_f1=dict(ignores=list(range(0, 25, 2)) + [25]),
    ),
)

# ------------------------------ schedule -------------------------------
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1,
    step=[40, 50],
)
total_epochs = 60

# ------------------------------- runtime -------------------------------
checkpoint_config = dict(interval=1)
log_config = dict(
    interval=50,
    # Only the text logger is active; other logger hooks (for example a
    # ``PaviLoggerHook``) may be added to this list.
    hooks=[
        dict(type='TextLoggerHook'),
    ],
)
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
|
|
@ -0,0 +1,23 @@
|
|||
# TextSnake
|
||||
|
||||
## Introduction
|
||||
|
||||
[ALGORITHM]
|
||||
|
||||
```bibtex
|
||||
@article{long2018textsnake,
|
||||
title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes},
|
||||
author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong},
|
||||
booktitle={ECCV},
|
||||
pages={20-36},
|
||||
year={2018}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### CTW1500
|
||||
|
||||
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
||||
| :----------------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------: |
|
||||
| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_600e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [config](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) |
|
|
@ -0,0 +1,113 @@
|
|||
# TextSnake text detection config for CTW1500 (ResNet-50 backbone with an
# FPN_UNET neck). The schedule and runtime settings come from the base
# files listed in ``_base_``.
_base_ = [
    '../../_base_/schedules/schedule_1200e.py',
    '../../_base_/default_runtime.py'
]
model = dict(
    type='TextSnake',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='caffe'),
    neck=dict(
        type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
    bbox_head=dict(
        type='TextSnakeHead',
        in_channels=32,
        text_repr_type='poly',
        loss=dict(type='TextSnakeLoss')),
    train_cfg=None,
    test_cfg=None)

dataset_type = 'IcdarDataset'
# ``data_root`` keeps its trailing slash, so the paths built from it below
# are concatenated without a leading '/' (the previous form produced
# 'data/ctw1500//...').
data_root = 'data/ctw1500/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadTextAnnotations',
        with_bbox=True,
        with_mask=True,
        poly2mask=False),
    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(
        type='RandomCropPolyInstances',
        instance_key='gt_masks',
        crop_ratio=0.65,
        min_side_ratio=0.3),
    dict(
        type='RandomRotatePolyInstances',
        rotate_ratio=0.5,
        max_angle=20,
        pad_with_fixed_color=False),
    dict(
        type='ScaleAspectJitter',
        img_scale=[(3000, 736)],  # unused
        ratio_range=(0.7, 1.3),
        aspect_ratio_range=(0.9, 1.1),
        multiscale_mode='value',
        long_size_bound=800,
        short_size_bound=480,
        resize_type='long_short_bound',
        keep_ratio=False),
    dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
    dict(type='TextSnakeTargets'),
    dict(type='Pad', size_divisor=32),
    dict(
        type='CustomFormatBundle',
        keys=[
            'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
        ],
        visualize=dict(flag=False, boundary_key='gt_text_mask')),
    # Collect the image plus the same target maps that were bundled above.
    dict(
        type='Collect',
        keys=[
            'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
        ])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 736),
        flip=False,
        transforms=[
            dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_training.json',
        img_prefix=data_root + 'imgs',
        pipeline=train_pipeline),
    # Validation and testing both use the test annotations.
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_test.json',
        img_prefix=data_root + 'imgs',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_test.json',
        img_prefix=data_root + 'imgs',
        pipeline=test_pipeline))

evaluation = dict(interval=10, metric='hmean-iou')
|
|
@ -0,0 +1,27 @@
|
|||
import torch
|
||||
|
||||
|
||||
def compute_f1_score(preds, gts, ignores=None):
    """Compute the per-class F1-score of predictions.

    A ``C x C`` confusion matrix is accumulated over all samples (rows are
    ground-truth classes, columns are predicted classes). Per-class recall,
    precision and F1 are derived from it, and only the classes that are not
    ignored are reported.

    Args:
        preds (Tensor): The predicted probability NxC map
            with N and C being the sample number and class
            number respectively.
        gts (Tensor): The ground truth vector of size N.
        ignores (list | None): The index set of classes that are ignored
            when reporting results. ``None`` (the default) ignores nothing.
            Note: all samples are participated in computing.

    Returns:
        numpy.ndarray: The f1-scores of valid (non-ignored) classes, ordered
        by ascending class index.
    """
    # ``None`` replaces the former mutable ``[]`` default.
    ignored = set() if ignores is None else set(ignores)
    C = preds.size(1)
    classes = torch.LongTensor(sorted(set(range(C)) - ignored))
    # Encode each (gt, pred) pair as one flat index so that a single
    # bincount builds the confusion matrix. ``gts`` is cast to an integer
    # dtype because bincount only accepts integer input.
    hist = torch.bincount(
        gts.long() * C + preds.argmax(1), minlength=C**2).view(C, C).float()
    diag = torch.diag(hist)
    # Clamp the denominators so classes with no samples / no predictions
    # yield 0 instead of NaN.
    recalls = diag / hist.sum(1).clamp(min=1)
    precisions = diag / hist.sum(0).clamp(min=1)
    f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)
    return f1[classes].cpu().numpy()
|
|
@ -0,0 +1,295 @@
|
|||
import copy
|
||||
from os import path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
from mmdet.datasets.custom import CustomDataset
|
||||
from mmocr.core import compute_f1_score
|
||||
|
||||
|
||||
@DATASETS.register_module()
class KIEDataset(CustomDataset):
    """Dataset for key information extraction (KIE).

    Every line of ``ann_file`` is split on whitespace into an image file
    name and an annotation file name. Every line of an annotation file is
    comma separated: eight integer box coordinates, a text field that may
    be prefixed with ``<edge id>;``, and a trailing integer label.

    Args:
        ann_file (str): Index file listing image / annotation file pairs.
        pipeline (list | None): Processing pipeline passed to the parent
            class.
        data_root (str | None): Root that relative ``ann_file``,
            ``ann_prefix``, ``vocab_file`` and ``class_file`` are joined to.
        img_prefix (str): Image prefix passed to the parent class.
        ann_prefix (str | None): Prefix joined to annotation file names.
        vocab_file (str | None): Character vocabulary file, one character
            per line. If it does not exist, the vocabulary is built from
            the annotations and written to this path. If ``None``, the
            vocabulary is built but not saved.
        class_file (str | None): File whose lines are ``<index> <name>``.
            If ``None``, ``idx_to_cls`` stays empty.
        norm (float): Divisor applied to box offsets in
            :meth:`compute_relation`.
        thresholds (dict | None): Thresholds used by :meth:`plot`; the
            ``edge`` entry is read there. Defaults to ``dict(edge=0.5)``.
        directed (bool): Whether :meth:`ann_numpy` builds the directed
            variant of the edge matrix.
    """

    def __init__(self,
                 ann_file,
                 pipeline=None,
                 data_root=None,
                 img_prefix='',
                 ann_prefix='',
                 vocab_file=None,
                 class_file=None,
                 norm=10.,
                 thresholds=None,
                 directed=False,
                 **kwargs):
        self.ann_prefix = ann_prefix
        self.norm = norm
        # ``None`` replaces the former mutable dict default.
        self.thresholds = dict(edge=0.5) if thresholds is None else thresholds
        self.directed = directed

        # ``ann_file`` must be resolved here (not only by the parent
        # constructor) because ``construct_dict`` may need it first. It is
        # always set, including for absolute paths and ``data_root=None``.
        self.ann_file = ann_file
        if data_root is not None:
            if not osp.isabs(ann_file):
                self.ann_file = osp.join(data_root, ann_file)
            if not (ann_prefix is None or osp.isabs(ann_prefix)):
                self.ann_prefix = osp.join(data_root, ann_prefix)

        # Index 0 is reserved for the empty string.
        self.vocab = dict({'': 0})
        if vocab_file is not None and data_root is not None:
            vocab_file = osp.join(data_root, vocab_file)
        if vocab_file is not None and osp.exists(vocab_file):
            with open(vocab_file, 'r') as fid:
                for idx, char in enumerate(fid.readlines(), 1):
                    self.vocab[char.strip('\n')] = idx
        else:
            self.construct_dict(self.ann_file)
            if vocab_file is not None:
                with open(vocab_file, 'w') as fid:
                    for key in self.vocab:
                        if key:
                            fid.write('{}\n'.format(key))

        # The vocabulary and ``ann_prefix`` must be ready before this call:
        # the parent constructor is what invokes ``load_annotations``.
        super().__init__(
            ann_file,
            pipeline,
            data_root=data_root,
            img_prefix=img_prefix,
            **kwargs)

        self.idx_to_cls = dict()
        if class_file is not None:
            if data_root is not None:
                class_file = osp.join(data_root, class_file)
            with open(class_file, 'r') as fid:
                for line in fid:
                    if not line.strip():
                        continue  # tolerate blank lines
                    idx, cls = line.split()
                    self.idx_to_cls[int(idx)] = cls

    @staticmethod
    def _split_edge(line):
        """Split the text field of an annotation line into edge and text.

        Args:
            line (list[str]): Comma-split annotation line. Fields 8 to -2
                are re-joined with ',' to restore commas inside the text.

        Returns:
            tuple(int, str): The edge id (0 when the text has no decimal
            ``<edge>;`` prefix) and the remaining text.
        """
        text = ','.join(line[8:-1])
        if ';' in text and text.split(';')[0].isdecimal():
            edge, text = text.split(';', 1)
            edge = int(edge)
        else:
            edge = 0
        return edge, text

    def construct_dict(self, ann_file):
        """Build ``self.vocab`` from every character in the annotations.

        Characters are indexed by their sorted order, so the empty string
        keeps index 0.

        Args:
            ann_file (str): Index file listing image / annotation pairs.
        """
        chars = set(self.vocab)
        for img_info in mmcv.list_from_file(ann_file):
            _, annname = img_info.split()
            if self.ann_prefix:
                annname = osp.join(self.ann_prefix, annname)
            with open(annname, 'r') as fid:
                for line in fid:
                    _, text = self._split_edge(line.strip().split(','))
                    chars.update(text)
        self.vocab = {char: idx for idx, char in enumerate(sorted(chars))}

    def convert_text(self, text):
        """Map a string to vocabulary indices, dropping unknown chars."""
        return [self.vocab[c] for c in text if c in self.vocab]

    def parse_lines(self, annname):
        """Parse one annotation file.

        Args:
            annname (str): Annotation file name; ``self.ann_prefix`` is
                prepended when it is non-empty.

        Returns:
            dict: Lists ``boxes`` (eight ints each), ``edges`` (int ids),
            ``texts`` (vocabulary indices), ``chars`` (raw text) and
            ``labels`` (ints), one entry per annotation line.
        """
        boxes, edges, texts, chars, labels = [], [], [], [], []

        if self.ann_prefix:
            annname = osp.join(self.ann_prefix, annname)

        with open(annname, 'r') as fid:
            for line in fid:
                line = line.strip().split(',')
                boxes.append(list(map(int, line[:8])))
                edge, text = self._split_edge(line)
                chars.append(text)
                texts.append(self.convert_text(text))
                edges.append(edge)
                labels.append(int(line[-1]))
        return dict(
            boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)

    def format_results(self, results):
        """Convert raw boxes, labels and edges into tensors for plotting.

        Tensors are created on the device of ``results['nodes']`` when
        predictions are present; otherwise on CUDA if available, else on
        CPU. (CUDA used to be hard-coded.)

        Args:
            results (dict): Holds ``boxes`` and ``edges`` plus either
                ``nodes`` (predictions) or ``labels``.

        Returns:
            dict: ``results`` with ``boxes`` replaced by an N x 5 tensor
            (box columns 0, 1, 4, 5 and the label), ``edges`` replaced by a
            tensor, and the original boxes stored under ``points``.
        """
        if 'nodes' in results:
            nodes, edges = results['nodes'], results['edges']
            device = nodes.device
            labels = nodes.argmax(-1)
            num_nodes = nodes.size(0)
            edges = edges[:, -1].view(num_nodes, num_nodes)
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            labels = torch.Tensor(results['labels']).to(device)
            edges = torch.Tensor(results['edges']).to(device)
        boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].to(device)
        boxes = torch.cat([boxes, labels[:, None].float()], -1)

        return {
            **{
                k: v
                for k, v in results.items() if k not in ['boxes', 'edges']
            }, 'boxes': boxes,
            'edges': edges,
            'points': results['boxes']
        }

    def plot(self, results):
        """Draw boxes, texts, node classes and edges on the image.

        Args:
            results (dict): Output of :meth:`format_results`; reads
                ``filename``, ``points``, ``chars``, ``edges``, optionally
                ``scores``, and either ``nodes`` or ``labels``.
        """
        img_name = osp.join(self.img_prefix, results['filename'])
        img = plt.imread(img_name)
        plt.imshow(img)

        boxes, texts = results['points'], results['chars']
        num_nodes = len(boxes)
        if 'scores' in results:
            scores = results['scores']
        else:
            scores = np.ones(num_nodes)
        for box, text, score in zip(boxes, texts, scores):
            # Visit the four corners and return to the first one so that
            # the outline is closed.
            xs, ys = [], []
            for idx in range(0, 10, 2):
                xs.append(box[idx % 8])
                ys.append(box[(idx + 1) % 8])
            plt.plot(xs, ys, 'g')
            plt.annotate(
                '{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')

        has_nodes = 'nodes' in results
        if has_nodes:
            nodes = results['nodes']
            inds = nodes.argmax(-1)
        else:
            nodes = None
            inds = results['labels']

        # The edge matrix does not depend on the node being drawn, so it
        # is built once instead of once per node.
        edges = results['edges']
        if not has_nodes:
            edges = (edges[:, None] == edges[None]).float()

        for i in range(num_nodes):
            label = int(inds[i])
            # Without predictions every node gets a score of 1.
            node_score = float(nodes[i, label]) if has_nodes else 1.0
            # ``idx_to_cls`` is a dict and must be indexed, not called.
            # NOTE(review): the ``- 1`` offset is kept from the original
            # code; confirm it matches the indexing of ``class_file``.
            cls_name = self.idx_to_cls.get(label - 1, str(label))
            plt.annotate(
                '{}: {:.4f}'.format(cls_name, node_score),
                (boxes[i][6], boxes[i][7]),
                color='r' if label == 1 else 'b')
            for j in range(i + 1, num_nodes):
                edge_score = max(edges[i][j], edges[j][i])
                if edge_score > self.thresholds['edge']:
                    x1 = sum(boxes[i][:3:2]) // 2
                    y1 = sum(boxes[i][3:6:2]) // 2
                    x2 = sum(boxes[j][:3:2]) // 2
                    y2 = sum(boxes[j][3:6:2]) // 2
                    plt.plot((x1, x2), (y1, y2), 'r')
                    plt.annotate(
                        '{:.4f}'.format(float(edge_score)),
                        ((x1 + x2) // 2, (y1 + y2) // 2),
                        color='r')

    def compute_relation(self, boxes):
        """Compute pairwise spatial relations between boxes.

        Args:
            boxes (np.ndarray): N x 8 array; columns 0, 1 and 4, 5 are used
                as the two opposite corners ``(x1, y1)`` and ``(x2, y2)``.

        Returns:
            tuple(np.ndarray, np.ndarray): ``relations`` of shape N x N x 5
            where entry ``[i, j]`` holds ``(x1_j - x1_i) / norm``,
            ``(y1_j - y1_i) / norm``, ``w_i / h_i``, ``h_j / h_i`` and
            ``w_j / h_i``; and ``bboxes`` of shape N x 4 (float32) holding
            ``x1, y1, x2, y2``.
        """
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        # Heights are floored at 1 because they are used as divisors.
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        # Adding zeros broadcasts the N x 1 aspect ratios to N x N.
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def ann_numpy(self, results):
        """Convert parsed annotations into numpy arrays.

        Args:
            results (dict): Output of :meth:`parse_lines`.

        Returns:
            dict: ``bboxes`` and ``relations`` from
            :meth:`compute_relation`, padded ``texts``, and ``labels``.
            When edges are available, ``labels`` is N x (1 + N): the node
            label followed by the edge matrix row (diagonal set to -1).
        """
        boxes, texts = results['boxes'], results['texts']
        boxes = np.array(boxes, np.int32)
        # Reverse the corner order of all boxes when the first box has its
        # first corner below its last corner.
        if boxes[0, 1] > boxes[0, -1]:
            boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
        relations, bboxes = self.compute_relation(boxes)

        labels = results.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = results.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Two nodes are linked when they share the same edge id.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # ``&`` binds tighter than ``==``; the parentheses only
                    # spell out the original evaluation order.
                    # NOTE(review): confirm ``edges & (labels == 1)`` was
                    # not the intent.
                    edges = ((edges & labels) == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)
        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=self.pad_text(texts),
            labels=labels)

    def image_size(self, filename):
        """Return the ``(width, height)`` of an image under ``img_prefix``.

        The image is opened in a ``with`` block so that the file handle is
        closed again instead of being left open for every dataset image.
        """
        img_path = osp.join(self.img_prefix, filename)
        with Image.open(img_path) as img:
            return img.size

    def load_annotations(self, ann_file):
        """Load image infos and fill ``self.anns`` from the index file.

        Args:
            ann_file (str): Index file listing image / annotation pairs.

        Returns:
            list[dict]: One dict per image with ``filename``, ``width`` and
            ``height``.
        """
        self.anns, data_infos = [], []

        self.gts = dict()
        img_infos = mmcv.list_from_file(ann_file)
        for img_info in img_infos:
            filename, annname = img_info.split()
            results = self.parse_lines(annname)
            width, height = self.image_size(filename)

            data_infos.append(
                dict(filename=filename, width=width, height=height))
            ann = self.ann_numpy(results)
            self.anns.append(ann)

        return data_infos

    def pad_text(self, texts):
        """Pad index sequences with -1 into a 2-D int32 array.

        Args:
            texts (list[list[int]]): Variable-length index sequences.

        Returns:
            np.ndarray: Array of shape ``(len(texts), max_len)``.
        """
        # ``default=0`` keeps an empty ``texts`` list from raising.
        max_len = max((len(text) for text in texts), default=0)
        padded_texts = -np.ones((len(texts), max_len), np.int32)
        for idx, text in enumerate(texts):
            padded_texts[idx, :len(text)] = np.array(text)
        return padded_texts

    def get_ann_info(self, idx):
        """Return the annotation dict built by :meth:`ann_numpy`."""
        return self.anns[idx]

    def prepare_test_img(self, idx):
        """Prepare test data exactly like training data."""
        return self.prepare_train_img(idx)

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=None,
                 **kwargs):
        """Evaluate predictions with the macro F1 metric.

        Args:
            results (list[dict]): Per-image outputs containing ``nodes``.
            metric (str | list[str]): Only ``'macro_f1'`` is supported.
            metric_options (dict | None): Options keyed by metric name.
                Defaults to ``dict(macro_f1=dict(ignores=[]))``; a missing
                ``macro_f1`` entry means no extra options.
            **kwargs: Only ``logger`` is accepted.

        Returns:
            dict: ``{'macro_f1': value}``.

        Raises:
            KeyError: If an unsupported metric is requested.
        """
        # allow some kwargs to pass through
        assert set(kwargs).issubset(['logger'])

        if metric_options is None:
            metric_options = dict(macro_f1=dict(ignores=[]))
        else:
            # Copy so that the caller's dict is never modified.
            metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results,
                                     **metric_options.get('macro_f1', {}))

    def compute_macro_f1(self, results, ignores=None):
        """Average the per-class F1 over all non-ignored classes.

        Args:
            results (list[dict]): Per-image outputs containing ``nodes``.
            ignores (list | None): Class indices left out of the average.

        Returns:
            dict: ``{'macro_f1': value}``.
        """
        ignores = [] if ignores is None else ignores
        node_preds = torch.cat([result['nodes'] for result in results])

        # Column 0 of the stored labels is the node label; the remaining
        # columns hold the edge matrix.
        node_gts = torch.cat([
            torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
            for ann in self.anns
        ])

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }
|
|
@ -0,0 +1,55 @@
|
|||
import numpy as np
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class KIEFormatBundle(DefaultFormatBundle):
    """Key information extraction formatting bundle.

    Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
    common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
    "gt_masks", "gt_semantic_seg", "relations" and "texts".
    These fields are formatted as follows.

    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
    - proposals: (1) to tensor, (2) to DataContainer
    - gt_bboxes: (1) to tensor, (2) to DataContainer
    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
    - gt_labels: (1) to tensor, (2) to DataContainer
    - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \
                       (3) to DataContainer (stack=True)
    - relations: (1) scale, (2) to tensor, (3) to DataContainer
    - texts: (1) to tensor, (2) to DataContainer
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        On top of the parent formatting, "relations" and "texts" are read
        from ``results['ann_info']`` (when present) and stored at the top
        level of ``results`` as DataContainers. If ``scale_factor`` is
        present, the relations are rescaled first.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        super().__call__(results)
        if 'ann_info' in results:
            for key in ['relations', 'texts']:
                value = results['ann_info'][key]
                if key == 'relations' and 'scale_factor' in results:
                    scale_factor = results['scale_factor']
                    # ``np.isscalar`` also accepts ints and NumPy scalar
                    # types, which a plain ``float`` check sent to the
                    # sequence branch.
                    if np.isscalar(scale_factor):
                        sx = sy = scale_factor
                    else:
                        sx, sy = scale_factor[:2]
                    r = sx / sy
                    # One factor per entry of the last relation axis.
                    # NOTE(review): assumes the order is (dx, dy, w/h,
                    # h ratio, w/h cross ratio) as produced by the dataset;
                    # confirm against the relation builder.
                    value = value * np.array([sx, sy, r, 1, r])[None, None]
                results[key] = DC(to_tensor(value))
        return results

    def __repr__(self):
        return self.__class__.__name__
|
|
@ -0,0 +1,3 @@
|
|||
from .unet import UNet
|
||||
|
||||
__all__ = ['UNet']
|
|
@ -0,0 +1,528 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
||||
build_norm_layer, build_upsample_layer, constant_init,
|
||||
kaiming_init)
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmdet.models.builder import BACKBONES
|
||||
from mmdet.utils import get_root_logger
|
||||
|
||||
|
||||
class UpConvBlock(nn.Module):
    """Decoder stage of UNet: upsample, concatenate, convolve.

    The high-level (low-resolution) feature map first goes through an
    upsample module that outputs ``skip_channels`` channels. The result is
    concatenated along the channel axis with the skip feature map from the
    encoder, and the concatenation is fused by a convolution block.

    Args:
        conv_block (nn.Sequential): Sequential of convolutional layers. It
            is called as a factory with the keyword arguments listed below.
        in_channels (int): Number of input channels of the high-level
            feature map fed to the upsample module.
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv'). When it is None, a
            1x1 ConvModule is used in place of an upsample layer, so only
            the channel count changes.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Must be None (not implemented). Default: None.
        plugins (dict): plugins for convolutional layers. Must be None (not
            implemented). Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # The fusion block sees the skip map and the upsampled map stacked
        # together, hence twice ``skip_channels`` input channels.
        fuse_kwargs = dict(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        self.conv_block = conv_block(**fuse_kwargs)

        if upsample_cfg is None:
            # No upsample config given: a 1x1 convolution only maps the
            # channels to ``skip_channels``.
            self.upsample = ConvModule(
                in_channels=in_channels,
                out_channels=skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function."""
        upsampled = self.upsample(x)
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv_block(fused)
|
||||
|
||||
|
||||
class BasicConvBlock(nn.Module):
|
||||
"""Basic convolutional block for UNet.
|
||||
|
||||
This module consists of several plain convolutional layers.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
num_convs (int): Number of convolutional layers. Default: 2.
|
||||
stride (int): Whether use stride convolution to downsample
|
||||
the input feature map. If stride=2, it only uses stride convolution
|
||||
in the first convolutional layer to downsample the input feature
|
||||
map. Options are 1 or 2. Default: 1.
|
||||
dilation (int): Whether use dilated convolution to expand the
|
||||
receptive field. Set dilation rate of each convolutional layer and
|
||||
the dilation rate of the first convolutional layer is always 1.
|
||||
Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
dcn (bool): Use deformable convolution in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_convs=2,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super().__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.with_cp = with_cp
|
||||
convs = []
|
||||
for i in range(num_convs):
|
||||
convs.append(
|
||||
ConvModule(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride if i == 0 else 1,
|
||||
dilation=1 if i == 0 else dilation,
|
||||
padding=1 if i == 0 else dilation,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
def forward(self, x):
    """Apply the convolution stack to ``x``.

    The stack is run through ``cp.checkpoint`` only when ``with_cp`` is set
    and ``x`` requires grad; otherwise it is called directly.
    """
    checkpointed = self.with_cp and x.requires_grad
    if not checkpointed:
        return self.convs(x)
    return cp.checkpoint(self.convs, x)
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Transposed-convolution upsample module for the UNet decoder.

    The module chains a ``nn.ConvTranspose2d``, a normalization layer and an
    activation layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the transposed convolution.
            Default: 4.
        scale_factor (int): Used as the stride of the transposed convolution.
            ``kernel_size - scale_factor`` must be non-negative and even.
            Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()

        # Half of this margin becomes the padding, so it must split evenly.
        margin = kernel_size - scale_factor
        assert margin >= 0 and margin % 2 == 0, (
            f'kernel_size should be greater than or equal to scale_factor '
            f'and (kernel_size - scale_factor) should be even numbers, '
            f'while the kernel size is {kernel_size} and scale_factor is '
            f'{scale_factor}.')

        self.with_cp = with_cp
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=scale_factor,
                padding=margin // 2),
            # build_norm_layer returns a pair; only the layer itself is kept.
            build_norm_layer(norm_cfg, out_channels)[1],
            build_activation_layer(act_cfg),
        ]
        # NOTE: the attribute name (including its spelling) is part of the
        # parameter names of this module, so it is kept as is.
        self.deconv_upsamping = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        checkpointed = self.with_cp and x.requires_grad
        if not checkpointed:
            return self.deconv_upsamping(x)
        return cp.checkpoint(self.deconv_upsamping, x)
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module for the UNet decoder.

    The module pairs one ``nn.Upsample`` layer with one ``ConvModule``. With
    ``conv_first=False`` the input is upsampled and then convolved; with
    ``conv_first=True`` the convolution runs before the upsampling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether the convolutional layer comes before the
            interpolation upsample layer. Default: False.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Keyword arguments passed to ``nn.Upsample``.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super().__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsample_cfg)
        # Only the order of the two layers depends on ``conv_first``.
        ordered = (conv, upsample) if conv_first else (upsample, conv)
        self.interp_upsample = nn.Sequential(*ordered)

    def forward(self, x):
        """Forward function."""
        checkpointed = self.with_cp and x.requires_grad
        if not checkpointed:
            return self.interp_upsample(x)
        return cp.checkpoint(self.interp_upsample, x)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.

    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of output channels of the first encoder
            stage; stage ``i`` outputs ``base_channels * 2**i`` channels.
            Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Stride of each encoder stage.
            len(strides) must equal num_stages. If strides[i]=2, the
            corresponding encoder stage downsamples with a strided
            convolution. Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of each encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of each decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[bool]): Whether to prepend a MaxPool to encoder
            stage ``i`` (for i in [1, num_stages)), read as
            downsamples[i-1]. The MaxPool is only added when strides[i] is 1.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config passed to a decoder stage
            whose encoder stage downsampled; other decoder stages receive
            None. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to put batch-norm layers into eval mode
            whenever the model is switched to training mode.
            Default: False.
        dcn (bool): Not implemented; must be None. Default: None.
        plugins (dict): Not implemented; must be None. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # (argument name, value, True when one entry per decoder stage).
        per_stage_args = (
            ('strides', strides, False),
            ('enc_num_convs', enc_num_convs, False),
            ('dec_num_convs', dec_num_convs, True),
            ('downsamples', downsamples, True),
            ('enc_dilations', enc_dilations, False),
            ('dec_dilations', dec_dilations, True),
        )
        for arg_name, arg_value, per_decoder in per_stage_args:
            if per_decoder:
                wanted_len, wanted_desc = num_stages - 1, '(num_stages-1)'
            else:
                wanted_len, wanted_desc = num_stages, 'num_stages'
            assert len(arg_value) == wanted_len, (
                f'The length of {arg_name} should be equal to {wanted_desc}, '
                f'while the {arg_name} is {arg_value}, the length of '
                f'{arg_name} is {len(arg_value)}, and the num_stages is '
                f'{num_stages}.')

        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        stage_in_channels = in_channels
        for stage in range(num_stages):
            stage_out_channels = base_channels * 2**stage
            stage_layers = []
            if stage != 0:
                pooled = downsamples[stage - 1]
                # A strided stage downsamples by itself, so MaxPool is only
                # added for stride-1 stages.
                if strides[stage] == 1 and pooled:
                    stage_layers.append(nn.MaxPool2d(kernel_size=2))
                downsampled = (strides[stage] != 1 or pooled)
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=stage_out_channels,
                        skip_channels=base_channels * 2**(stage - 1),
                        out_channels=base_channels * 2**(stage - 1),
                        num_convs=dec_num_convs[stage - 1],
                        stride=1,
                        dilation=dec_dilations[stage - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if downsampled else None,
                        dcn=None,
                        plugins=None))

            stage_layers.append(
                BasicConvBlock(
                    in_channels=stage_in_channels,
                    out_channels=stage_out_channels,
                    num_convs=enc_num_convs[stage],
                    stride=strides[stage],
                    dilation=enc_dilations[stage],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*stage_layers)))
            stage_in_channels = stage_out_channels

    def forward(self, x):
        """Run the encoder and decoder.

        Returns:
            list: The last encoder output followed by the output of every
                decoder stage, from the deepest stage to the shallowest.
        """
        self._check_input_divisible(x)
        skips = []
        feat = x
        for enc_stage in self.encoder:
            feat = enc_stage(feat)
            skips.append(feat)

        outputs = [feat]
        # Decoder stage ``idx`` consumes the encoder output with the same
        # index as its skip connection.
        for idx in range(len(self.decoder) - 1, -1, -1):
            feat = self.decoder[idx](skips[idx], feat)
            outputs.append(feat)

        return outputs

    def train(self, mode=True):
        """Switch the training mode; with ``norm_eval`` set, batch-norm
        layers are put back into eval mode when entering training mode."""
        super().train(mode)
        if not (mode and self.norm_eval):
            return
        for module in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(module, _BatchNorm):
                module.eval()

    def _check_input_divisible(self, x):
        """Assert that the last two dims of ``x`` are divisible by the whole
        downsample rate, i.e. 2 to the number of downsampling stages."""
        h, w = x.shape[-2:]
        num_downsamples = sum(
            1 for i in range(1, self.num_stages)
            if self.strides[i] == 2 or self.downsamples[i - 1])
        whole_downsample_rate = 2**num_downsamples
        assert not (h % whole_downsample_rate or w % whole_downsample_rate), (
            f'The input image size {(h, w)} should be divisible by the whole '
            f'downsample rate {whole_downsample_rate}, when num_stages is '
            f'{self.num_stages}, strides is {self.strides}, and downsamples '
            f'is {self.downsamples}.')

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            load_checkpoint(
                self, pretrained, strict=False, logger=get_root_logger())
            return
        if pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                constant_init(module, 1)
|
|
@ -0,0 +1,3 @@
|
|||
from .extractors import * # noqa: F401,F403
|
||||
from .heads import * # noqa: F401,F403
|
||||
from .losses import * # noqa: F401,F403
|
|
@ -0,0 +1,3 @@
|
|||
from .sdmgr import SDMGR
|
||||
|
||||
__all__ = ['SDMGR']
|
|
@ -0,0 +1,87 @@
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmdet.core import bbox2roi
|
||||
from mmdet.models.builder import DETECTORS, build_roi_extractor
|
||||
from mmdet.models.detectors import SingleStageDetector
|
||||
|
||||
|
||||
@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
    for Key Information Extraction. https://arxiv.org/abs/2103.14470.

    Args:
        backbone, neck, bbox_head, train_cfg, test_cfg, pretrained: Passed
            unchanged to ``SingleStageDetector``.
        extractor (dict): Config of the RoI extractor; only used when
            ``visual_modality`` is True, where ``out_channels`` is set to
            ``backbone.base_channels``.
        visual_modality (bool): Whether use the visual modality.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(type='RoIAlign', output_size=7),
                     featmap_strides=[1]),
                 visual_modality=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained)
        self.visual_modality = visual_modality
        if not visual_modality:
            self.extractor = None
            return

        # Copy the config so the (shared) default dict is never modified.
        extractor_cfg = dict(
            extractor, out_channels=self.backbone.base_channels)
        self.extractor = build_roi_extractor(extractor_cfg)
        # The pooling kernel equals the RoI output size.
        self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details of the values of these keys,
                please see :class:`mmdet.datasets.pipelines.Collect`.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Each item is the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        visual_feats = self.extract_feat(img, gt_bboxes)
        predictions = self.bbox_head.forward(relations, texts, visual_feats)
        node_preds, edge_preds = predictions
        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)

    def forward_test(self,
                     img,
                     img_metas,
                     relations,
                     texts,
                     gt_bboxes,
                     rescale=False):
        """Predict node and edge class probabilities.

        Returns:
            list[dict]: A single dict holding ``img_metas`` plus the softmax
                of the node and edge predictions over the last dim.
        """
        visual_feats = self.extract_feat(img, gt_bboxes)
        predictions = self.bbox_head.forward(relations, texts, visual_feats)
        node_preds, edge_preds = predictions
        result = dict(
            img_metas=img_metas,
            nodes=F.softmax(node_preds, -1),
            edges=F.softmax(edge_preds, -1))
        return [result]

    def extract_feat(self, img, gt_bboxes):
        """Return one flattened visual feature per box, or None when the
        visual modality is disabled."""
        if not self.visual_modality:
            return None
        # Only the last feature map of the parent extractor is used.
        feat_map = super().extract_feat(img)[-1]
        roi_feats = self.extractor([feat_map], bbox2roi(gt_bboxes))
        pooled = self.maxpool(roi_feats)
        return pooled.view(pooled.size(0), -1)
|
|
@ -0,0 +1,3 @@
|
|||
from .sdmgr_head import SDMGRHead
|
||||
|
||||
__all__ = ['SDMGRHead']
|
|
@ -0,0 +1,193 @@
|
|||
import torch
|
||||
from mmcv.cnn import normal_init
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
|
||||
|
||||
@HEADS.register_module()
class SDMGRHead(nn.Module):
    """Head that classifies text nodes and the edges between them.

    Text indices are embedded and run through an LSTM to get one vector per
    node, which is optionally fused with a visual feature. The node vectors
    and embedded relations then pass through ``num_gnn`` ``GNNLayer``s before
    two linear classifiers produce node and edge logits.
    """

    def __init__(self,
                 num_chars=92,
                 visual_dim=64,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 loss=dict(type='SDMGRLoss'),
                 bidirectional=False,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        # Index 0 is the embedding's padding index.
        self.node_embed = nn.Embedding(num_chars, node_input, 0)
        # A bidirectional LSTM concatenates both directions, so each one
        # gets half the width to keep the output at ``node_embed``.
        if bidirectional:
            lstm_hidden = node_embed // 2
        else:
            lstm_hidden = node_embed
        self.rnn = nn.LSTM(
            input_size=node_input,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional)
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        gnn_layers = [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]
        self.gnn_layers = nn.ModuleList(gnn_layers)
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)
        self.loss = build_loss(loss)

    def init_weights(self, pretrained=False):
        """Initialize the edge embedding with a normal distribution."""
        normal_init(self.edge_embed, mean=0, std=0.01)

    def forward(self, relations, texts, x=None):
        """Compute node and edge logits.

        Args:
            relations (list[tensor]): One relation tensor per sample; each is
                flattened to rows of its last dim.
            texts (list[tensor]): One tensor of character indices per sample,
                one row per node. Entries > 0 are counted as characters.
            x (tensor, optional): Visual features fused with the node
                features when given.

        Returns:
            tuple(tensor, tensor): Node logits and edge logits.
        """
        node_nums = [text.size(0) for text in texts]
        char_nums = [(text > 0).sum(-1) for text in texts]

        # Right-pad every sample with zeros up to the longest text.
        max_num = max([per_node.max() for per_node in char_nums])
        padded_texts = []
        for text in texts:
            filler = text.new_zeros(text.size(0), max_num - text.size(1))
            padded_texts.append(torch.cat([text, filler], -1))
        all_nodes = torch.cat(padded_texts)

        # Negative indices are clamped to 0 before the embedding lookup.
        embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
        rnn_nodes, _ = self.rnn(embed_nodes)

        # Keep the LSTM output at the last character of each node; nodes
        # without any character stay all-zero.
        feat_dim = rnn_nodes.size(2)
        nodes = rnn_nodes.new_zeros(rnn_nodes.size(0), feat_dim)
        all_nums = torch.cat(char_nums)
        valid = all_nums > 0
        last_char = (all_nums[valid] - 1)[:, None, None]
        last_char = last_char.expand(-1, -1, feat_dim)
        nodes[valid] = rnn_nodes[valid].gather(1, last_char).squeeze(1)

        if x is not None:
            nodes = self.fusion([x, nodes])

        flat_relations = [rel.view(-1, rel.size(-1)) for rel in relations]
        all_edges = torch.cat(flat_relations)
        embed_edges = F.normalize(self.edge_embed(all_edges.float()))

        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        node_cls = self.node_cls(nodes)
        edge_cls = self.edge_cls(cat_nodes)
        return node_cls, edge_cls
|
||||
|
||||
|
||||
class GNNLayer(nn.Module):
    """One graph reasoning layer over the fully connected node graph of
    each sample.

    Args:
        node_dim (int): Feature size of a node. Default: 256.
        edge_dim (int): Feature size of an edge. Default: 256.
    """

    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        """Update node features from pairwise node/edge features.

        Args:
            nodes (tensor): Node features of all samples stacked along dim 0.
            edges (tensor): Edge features, ``num**2`` rows per sample, in the
                same sample order as ``nodes``.
            nums (list[int]): Number of nodes of each sample.

        Returns:
            tuple(tensor, tensor): The updated node features (a new tensor;
                ``nodes`` itself is left untouched) and the pairwise features
                produced by ``in_fc`` + ReLU.
        """
        # Build, per sample, the concatenated features of every ordered
        # node pair (i, j), flattened to num**2 rows.
        start, pair_feats = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            pair_feats.append(
                torch.cat([
                    sample_nodes.unsqueeze(1).expand(-1, num, -1),
                    sample_nodes.unsqueeze(0).expand(num, -1, -1)
                ], -1).view(num**2, -1))
            start += num
        cat_nodes = torch.cat([torch.cat(pair_feats), edges], -1)
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        start, residuals = 0, []
        for num in nums:
            # A large negative value on the diagonal drives the softmax
            # weight of each node's pair with itself towards zero. The mask
            # is created directly on the device of ``coefs``.
            self_mask = -torch.eye(
                num, device=coefs.device).unsqueeze(-1) * 1e9
            weights = F.softmax(
                self_mask + coefs[start:start + num**2].view(num, num, -1),
                1)
            residuals.append(
                (weights *
                 cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
            start += num**2

        # Out-of-place add: the caller's ``nodes`` tensor must not be
        # modified, and an in-place add would fail on a leaf tensor that
        # requires grad.
        nodes = nodes + self.relu(self.out_fc(torch.cat(residuals)))
        return nodes, cat_nodes
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Chunked low-rank bilinear fusion of two inputs.

    Both inputs are projected to ``mm_dim`` and split into chunks. Each pair
    of chunks is expanded by ``rank``, multiplied element-wise, and summed
    over the rank dim. The merged chunks are concatenated and projected to
    ``output_dim``.

    Args:
        input_dims (Sequence[int]): Feature sizes of the two inputs.
        output_dim (int): Feature size of the output.
        mm_dim (int): Size of the shared projection space. Default: 1600.
        chunks (int): Number of chunks ``mm_dim`` is split into. Default: 20.
        rank (int): Expansion factor of every chunk. Default: 15.
        shared (bool): Reuse the first input's layers for the second input.
            Default: False.
        dropout_input (float): Dropout ratio after the input projections.
        dropout_pre_lin (float): Dropout ratio before the output projection.
        dropout_output (float): Dropout ratio after the output projection.
        pos_norm (str): Where the signed square root and normalization are
            applied, 'before_cat' or 'after_cat'. Default: 'before_cat'.
    """

    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert (pos_norm in ['before_cat', 'after_cat'])
        self.pos_norm = pos_norm

        # Input projections; with ``shared`` both inputs use the same layer.
        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        if shared:
            self.linear1 = self.linear0
        else:
            self.linear1 = nn.Linear(input_dims[1], mm_dim)

        self.merge_linears0 = nn.ModuleList()
        self.merge_linears1 = nn.ModuleList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for chunk_dim in self.chunks:
            first = nn.Linear(chunk_dim, chunk_dim * rank)
            self.merge_linears0.append(first)
            second = first if shared else nn.Linear(chunk_dim,
                                                    chunk_dim * rank)
            self.merge_linears1.append(second)
        self.linear_out = nn.Linear(mm_dim, output_dim)

    @staticmethod
    def _signed_sqrt_normalize(z):
        """Apply a sign-preserving square root followed by F.normalize."""
        z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
        return F.normalize(z)

    def forward(self, x):
        """Fuse ``x[0]`` and ``x[1]`` into one feature per row."""
        proj0 = self.linear0(x[0])
        proj1 = self.linear1(x[1])
        batch = proj1.size(0)
        if self.dropout_input > 0:
            proj0 = F.dropout(
                proj0, p=self.dropout_input, training=self.training)
            proj1 = F.dropout(
                proj1, p=self.dropout_input, training=self.training)

        normalize_each = self.pos_norm == 'before_cat'
        merged = []
        chunk_pairs = zip(
            torch.split(proj0, self.chunks, -1),
            torch.split(proj1, self.chunks, -1), self.merge_linears0,
            self.merge_linears1)
        for part0, part1, merge0, merge1 in chunk_pairs:
            # bs x split_size*rank, then summed over the rank dim.
            product = merge0(part0) * merge1(part1)
            z = torch.sum(product.view(batch, self.rank, -1), 1)
            if normalize_each:
                z = self._signed_sqrt_normalize(z)
            merged.append(z)

        z = torch.cat(merged, 1)
        if self.pos_norm == 'after_cat':
            z = self._signed_sqrt_normalize(z)

        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z

    @staticmethod
    def chunk_sizes(dim, chunks):
        """Split ``dim`` into ``chunks`` sizes of ceil(dim / chunks) each,
        with the last one shrunk so that the sizes sum to ``dim``."""
        per_chunk = (dim + chunks - 1) // chunks
        sizes = [per_chunk for _ in range(chunks)]
        overshoot = sum(sizes) - dim
        sizes[-1] -= overshoot
        return sizes
|
|
@ -0,0 +1,3 @@
|
|||
from .sdmgr_loss import SDMGRLoss
|
||||
|
||||
__all__ = ['SDMGRLoss']
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmdet.models.losses import accuracy
|
||||
|
||||
|
||||
@LOSSES.register_module()
class SDMGRLoss(nn.Module):
    """The implementation the loss of key information extraction proposed in
    the paper: Spatial Dual-Modality Graph Reasoning for Key Information
    Extraction.

    https://arxiv.org/abs/2103.14470.

    Args:
        node_weight (float): Multiplier of the node loss. Default: 1.0.
        edge_weight (float): Multiplier of the edge loss. Default: 1.0.
        ignore (int): Node label excluded from the node loss and from the
            node accuracy. Default: 0.
    """

    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
        super().__init__()
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
        # Edge label -1 is excluded from the edge loss.
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
        self.node_weight = node_weight
        self.edge_weight = edge_weight
        self.ignore = ignore

    def forward(self, node_preds, edge_preds, gts):
        """Compute the weighted node/edge losses and accuracies.

        Args:
            node_preds (tensor): Node predictions passed to the node loss.
            edge_preds (tensor): Edge predictions passed to the edge loss.
            gts (list[tensor]): One tensor per sample; column 0 holds the
                node labels and the remaining columns the edge labels.

        Returns:
            dict: ``loss_node``, ``loss_edge``, ``acc_node`` and ``acc_edge``.
        """
        # Column 0 carries node labels, the rest are flattened edge labels.
        node_gts = torch.cat([gt[:, 0] for gt in gts]).long()
        edge_gts = torch.cat(
            [gt[:, 1:].contiguous().view(-1) for gt in gts]).long()

        # Accuracies are measured on the non-ignored labels only.
        node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
        edge_valids = torch.nonzero(edge_gts != -1).view(-1)

        loss_node = self.node_weight * self.loss_node(node_preds, node_gts)
        loss_edge = self.edge_weight * self.loss_edge(edge_preds, edge_gts)
        acc_node = accuracy(node_preds[node_valids], node_gts[node_valids])
        acc_edge = accuracy(edge_preds[edge_valids], edge_gts[edge_valids])
        return dict(
            loss_node=loss_node,
            loss_edge=loss_edge,
            acc_node=acc_node,
            acc_edge=acc_edge)
|
Loading…
Reference in New Issue