From b8156a3a771486095595aba8398ffc3f6fe27a8f Mon Sep 17 00:00:00 2001 From: Hongbin Sun Date: Sat, 3 Apr 2021 00:44:12 +0800 Subject: [PATCH] [feature]: add code for kie and textsnake config --- configs/kie/sdmgr/README.md | 25 + .../sdmgr/sdmgr_novisual_60e_wildreceipt.py | 99 ++++ .../kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py | 99 ++++ configs/textdet/textsnake/README.md | 23 + .../textsnake_r50_fpn_unet_1200e_ctw1500.py | 113 ++++ mmocr/core/evaluation/kie_metric.py | 27 + mmocr/datasets/kie_dataset.py | 295 ++++++++++ mmocr/datasets/pipelines/kie_transforms.py | 55 ++ mmocr/models/common/backbones/__init__.py | 3 + mmocr/models/common/backbones/unet.py | 528 ++++++++++++++++++ mmocr/models/kie/__init__.py | 3 + mmocr/models/kie/extractors/__init__.py | 3 + mmocr/models/kie/extractors/sdmgr.py | 87 +++ mmocr/models/kie/heads/__init__.py | 3 + mmocr/models/kie/heads/sdmgr_head.py | 193 +++++++ mmocr/models/kie/losses/__init__.py | 3 + mmocr/models/kie/losses/sdmgr_loss.py | 39 ++ 17 files changed, 1598 insertions(+) create mode 100644 configs/kie/sdmgr/README.md create mode 100644 configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py create mode 100644 configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py create mode 100644 configs/textdet/textsnake/README.md create mode 100644 configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py create mode 100644 mmocr/core/evaluation/kie_metric.py create mode 100644 mmocr/datasets/kie_dataset.py create mode 100644 mmocr/datasets/pipelines/kie_transforms.py create mode 100644 mmocr/models/common/backbones/__init__.py create mode 100644 mmocr/models/common/backbones/unet.py create mode 100644 mmocr/models/kie/__init__.py create mode 100644 mmocr/models/kie/extractors/__init__.py create mode 100644 mmocr/models/kie/extractors/sdmgr.py create mode 100644 mmocr/models/kie/heads/__init__.py create mode 100644 mmocr/models/kie/heads/sdmgr_head.py create mode 100644 mmocr/models/kie/losses/__init__.py create mode 100644 mmocr/models/kie/losses/sdmgr_loss.py diff --git a/configs/kie/sdmgr/README.md b/configs/kie/sdmgr/README.md new file mode 100644 index 00000000..d8d7e878 --- /dev/null +++ b/configs/kie/sdmgr/README.md @@ -0,0 +1,25 @@ +# Spatial Dual-Modality Graph Reasoning for Key Information Extraction + +## Introduction + +[ALGORITHM] + +```bibtex +@misc{sun2021spatial, + title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction}, + author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang}, + year={2021}, + eprint={2103.14470}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## Results and models + +### WildReceipt + +| Method | Modality | Macro F1-Score | Download | +| :--------------------------------------------------------------------: | :--------------: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------: | +| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.880 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) | +| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.871 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/todo.log.json) | diff --git a/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py 
new file mode 100644 index 00000000..00807708 --- /dev/null +++ b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py @@ -0,0 +1,99 @@ +dataset_type = 'KIEDataset' +data_root = 'data/wildreceipt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +max_scale, min_scale = 1024, 512 + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict( + type='Collect', + keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes']) +] + +vocab_file = 'dict.txt' +class_file = 'class_list.txt' + +data = dict( + samples_per_gpu=4, + workers_per_gpu=0, + train=dict( + type=dataset_type, + ann_file='train.txt', + pipeline=train_pipeline, + data_root=data_root, + vocab_file=vocab_file, + class_file=class_file), + val=dict( + type=dataset_type, + ann_file='test.txt', + pipeline=test_pipeline, + data_root=data_root, + vocab_file=vocab_file, + class_file=class_file), + test=dict( + type=dataset_type, + ann_file='test.txt', + pipeline=test_pipeline, + data_root=data_root, + vocab_file=vocab_file, + class_file=class_file)) + +evaluation = dict( + interval=1, + metric='macro_f1', + metric_options=dict( + macro_f1=dict( + ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]))) + +model = dict( + type='SDMGR', + backbone=dict(type='UNet', base_channels=16), + bbox_head=dict( + type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26), + visual_modality=False, + train_cfg=None, + test_cfg=None) + +optimizer = dict(type='Adam', weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=1, + warmup_ratio=1, + step=[40, 50]) +total_epochs = 60 + +checkpoint_config = dict(interval=1) +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict( + # type='PaviLoggerHook', + # add_last_ckpt=True, + # interval=5, + # init_kwargs=dict(project='kie')), + ]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py new file mode 100644 index 00000000..05beec2a --- /dev/null +++ b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py @@ -0,0 +1,99 @@ +dataset_type = 'KIEDataset' +data_root = 'data/wildreceipt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +max_scale, min_scale = 1024, 512 + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict( + type='Collect', + keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + 
dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='KIEFormatBundle'),
+    dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
+]
+
+vocab_file = 'dict.txt'
+class_file = 'class_list.txt'
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=0,
+    train=dict(
+        type=dataset_type,
+        ann_file='train.txt',
+        pipeline=train_pipeline,
+        data_root=data_root,
+        vocab_file=vocab_file,
+        class_file=class_file),
+    val=dict(
+        type=dataset_type,
+        ann_file='test.txt',
+        pipeline=test_pipeline,
+        data_root=data_root,
+        vocab_file=vocab_file,
+        class_file=class_file),
+    test=dict(
+        type=dataset_type,
+        ann_file='test.txt',
+        pipeline=test_pipeline,
+        data_root=data_root,
+        vocab_file=vocab_file,
+        class_file=class_file))
+
+evaluation = dict(
+    interval=1,
+    metric='macro_f1',
+    metric_options=dict(
+        macro_f1=dict(
+            ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))
+
+model = dict(
+    type='SDMGR',
+    backbone=dict(type='UNet', base_channels=16),
+    bbox_head=dict(
+        type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
+    visual_modality=True,
+    train_cfg=None,
+    test_cfg=None)
+
+optimizer = dict(type='Adam', weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=1,
+    warmup_ratio=1,
+    step=[40, 50])
+total_epochs = 60
+
+checkpoint_config = dict(interval=1)
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(
+        #     type='PaviLoggerHook',
+        #     add_last_ckpt=True,
+        #     interval=5,
+        #     init_kwargs=dict(project='kie')),
+    ])
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/textdet/textsnake/README.md b/configs/textdet/textsnake/README.md
new file mode 100644
index 00000000..50812e37
--- /dev/null
+++ b/configs/textdet/textsnake/README.md
@@ -0,0 +1,23 @@
+# TextSnake
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{long2018textsnake,
+  title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes},
+  author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong},
+  booktitle={ECCV},
+  pages={20-36},
+  year={2018}
+}
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :----------------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------: |
+| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [config](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) |
diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
new file mode 100644
index 00000000..dba03cd1
--- /dev/null
+++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
@@ -0,0 +1,113 @@
+_base_ = [
+
'../../_base_/schedules/schedule_1200e.py', + '../../_base_/default_runtime.py' +] +model = dict( + type='TextSnake', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict( + type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32), + bbox_head=dict( + type='TextSnakeHead', + in_channels=32, + text_repr_type='poly', + loss=dict(type='TextSnakeLoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/ctw1500/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='RandomCropPolyInstances', + instance_key='gt_masks', + crop_ratio=0.65, + min_side_ratio=0.3), + dict( + type='RandomRotatePolyInstances', + rotate_ratio=0.5, + max_angle=20, + pad_with_fixed_color=False), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 736)], # unused + ratio_range=(0.7, 1.3), + aspect_ratio_range=(0.9, 1.1), + multiscale_mode='value', + long_size_bound=800, + short_size_bound=480, + resize_type='long_short_bound', + keep_ratio=False), + dict(type='SquareResizePad', target_size=800, pad_ratio=0.6), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='TextSnakeTargets'), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=[ + 'gt_text_mask', 'gt_center_region_mask', 'gt_mask', + 'gt_radius_map', 'gt_sin_map', 'gt_cos_map' + ], + visualize=dict(flag=False, boundary_key='gt_text_mask')), + dict( + type='Collect', + keys=[ + 'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask', + 'gt_radius_map', 'gt_sin_map', 'gt_cos_map' + ]) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 736), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(1333, 736), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/mmocr/core/evaluation/kie_metric.py b/mmocr/core/evaluation/kie_metric.py new file mode 100644 index 00000000..00dc2387 --- /dev/null +++ b/mmocr/core/evaluation/kie_metric.py @@ -0,0 +1,27 @@ +import torch + + +def compute_f1_score(preds, gts, ignores=[]): + """Compute the F1-score of prediction. + + Args: + preds (Tensor): The predicted probability NxC map + with N and C being the sample number and class + number respectively. + gts (Tensor): The ground truth vector of size N. 
+        ignores (list): The index set of classes that are ignored when
+            reporting results.
+            Note: all samples still participate in the computation.
+
+    Returns:
+        ndarray: The F1-scores of the valid (non-ignored) classes.
+    """
+    C = preds.size(1)
+    classes = torch.LongTensor(sorted(set(range(C)) - set(ignores)))
+    hist = torch.bincount(
+        gts * C + preds.argmax(1), minlength=C**2).view(C, C).float()
+    diag = torch.diag(hist)
+    recalls = diag / hist.sum(1).clamp(min=1)
+    precisions = diag / hist.sum(0).clamp(min=1)
+    f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)
+    return f1[classes].cpu().numpy()
diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py
new file mode 100644
index 00000000..931ea3d2
--- /dev/null
+++ b/mmocr/datasets/kie_dataset.py
@@ -0,0 +1,295 @@
+import copy
+from os import path as osp
+
+import mmcv
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from PIL import Image
+
+from mmdet.datasets.builder import DATASETS
+from mmdet.datasets.custom import CustomDataset
+from mmocr.core import compute_f1_score
+
+
+@DATASETS.register_module()
+class KIEDataset(CustomDataset):
+
+    def __init__(self,
+                 ann_file,
+                 pipeline=None,
+                 data_root=None,
+                 img_prefix='',
+                 ann_prefix='',
+                 vocab_file=None,
+                 class_file=None,
+                 norm=10.,
+                 thresholds=dict(edge=0.5),
+                 directed=False,
+                 **kwargs):
+        self.ann_prefix = ann_prefix
+        self.norm = norm
+        self.thresholds = thresholds
+        self.directed = directed
+
+        if data_root is not None:
+            if not osp.isabs(ann_file):
+                self.ann_file = osp.join(data_root, ann_file)
+            if not (ann_prefix is None or osp.isabs(ann_prefix)):
+                self.ann_prefix = osp.join(data_root, ann_prefix)
+
+        self.vocab = dict({'': 0})
+        vocab_file = osp.join(data_root, vocab_file)
+        if osp.exists(vocab_file):
+            with open(vocab_file, 'r') as fid:
+                for idx, char in enumerate(fid.readlines(), 1):
+                    self.vocab[char.strip('\n')] = idx
+        else:
+            self.construct_dict(self.ann_file)
+            with open(vocab_file, 'w') as fid:
+                for key in self.vocab:
+                    if key:
+                        fid.write('{}\n'.format(key))
+
+        super().__init__(
+            ann_file,
+            pipeline,
+            data_root=data_root,
+            img_prefix=img_prefix,
+            **kwargs)
+
+        self.idx_to_cls = dict()
+        with open(osp.join(data_root, class_file), 'r') as fid:
+            for line in fid.readlines():
+                idx, cls = line.split()
+                self.idx_to_cls[int(idx)] = cls
+
+    @staticmethod
+    def _split_edge(line):
+        text = ','.join(line[8:-1])
+        if ';' in text and text.split(';')[0].isdecimal():
+            edge, text = text.split(';', 1)
+            edge = int(edge)
+        else:
+            edge = 0
+        return edge, text
+
+    def construct_dict(self, ann_file):
+        img_infos = mmcv.list_from_file(ann_file)
+        for img_info in img_infos:
+            _, annname = img_info.split()
+            if self.ann_prefix:
+                annname = osp.join(self.ann_prefix, annname)
+            with open(annname, 'r') as fid:
+                lines = fid.readlines()
+
+            for line in lines:
+                line = line.strip().split(',')
+                _, text = self._split_edge(line)
+                for c in text:
+                    if c not in self.vocab:
+                        self.vocab[c] = len(self.vocab)
+        self.vocab = dict(
+            {k: idx
+             for idx, k in enumerate(sorted(self.vocab.keys()))})
+
+    def convert_text(self, text):
+        return [self.vocab[c] for c in text if c in self.vocab]
+
+    def parse_lines(self, annname):
+        boxes, edges, texts, chars, labels = [], [], [], [], []
+
+        if self.ann_prefix:
+            annname = osp.join(self.ann_prefix, annname)
+
+        with open(annname, 'r') as fid:
+            for line in fid.readlines():
+                line = line.strip().split(',')
+                boxes.append(list(map(int, line[:8])))
+                edge, text = self._split_edge(line)
+                chars.append(text)
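+                # Map the raw characters to vocab indices; characters that
+                # are missing from the vocab are silently dropped by
+                # convert_text().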
+                text = self.convert_text(text)
+                texts.append(text)
+                edges.append(edge)
+                labels.append(int(line[-1]))
+        return dict(
+            boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)
+
+    def format_results(self, results):
+        boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].cuda()
+
+        if 'nodes' in results:
+            nodes, edges = results['nodes'], results['edges']
+            labels = nodes.argmax(-1)
+            num_nodes = nodes.size(0)
+            edges = edges[:, -1].view(num_nodes, num_nodes)
+        else:
+            labels = torch.Tensor(results['labels']).cuda()
+            edges = torch.Tensor(results['edges']).cuda()
+        boxes = torch.cat([boxes, labels[:, None].float()], -1)
+
+        return {
+            **{
+                k: v
+                for k, v in results.items() if k not in ['boxes', 'edges']
+            }, 'boxes': boxes,
+            'edges': edges,
+            'points': results['boxes']
+        }
+
+    def plot(self, results):
+        img_name = osp.join(self.img_prefix, results['filename'])
+        img = plt.imread(img_name)
+        plt.imshow(img)
+
+        boxes, texts = results['points'], results['chars']
+        num_nodes = len(boxes)
+        if 'scores' in results:
+            scores = results['scores']
+        else:
+            scores = np.ones(num_nodes)
+        for box, text, score in zip(boxes, texts, scores):
+            xs, ys = [], []
+            for idx in range(0, 10, 2):
+                xs.append(box[idx % 8])
+                ys.append(box[(idx + 1) % 8])
+            plt.plot(xs, ys, 'g')
+            plt.annotate(
+                '{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')
+
+        if 'nodes' in results:
+            nodes = results['nodes']
+            inds = nodes.argmax(-1)
+        else:
+            nodes = np.ones((num_nodes, 3))
+            inds = results['labels']
+        for i in range(num_nodes):
+            plt.annotate(
+                '{}: {:.4f}'.format(
+                    self.idx_to_cls[inds[i] - 1], nodes[i, inds[i]]),
+                (boxes[i][6], boxes[i][7]),
+                color='r' if inds[i] == 1 else 'b')
+            edges = results['edges']
+            if 'nodes' not in results:
+                edges = (edges[:, None] == edges[None]).float()
+            for j in range(i + 1, num_nodes):
+                edge_score = max(edges[i][j], edges[j][i])
+                if edge_score > self.thresholds['edge']:
+                    x1 = sum(boxes[i][:3:2]) // 2
+                    y1 = sum(boxes[i][3:6:2]) // 2
+                    x2 = sum(boxes[j][:3:2]) // 2
+                    y2 = sum(boxes[j][3:6:2]) // 2
+                    plt.plot((x1, x2), (y1, y2), 'r')
+                    plt.annotate(
+                        '{:.4f}'.format(edge_score),
+                        ((x1 + x2) // 2, (y1 + y2) // 2),
+                        color='r')
+
+    def compute_relation(self, boxes):
+        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+        dxs = (x1s[:, 0][None] - x1s) / self.norm
+        dys = (y1s[:, 0][None] - y1s) / self.norm
+        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+        whs = ws / hs + np.zeros_like(xhhs)
+        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+        return relations, bboxes
+
+    def ann_numpy(self, results):
+        boxes, texts = results['boxes'], results['texts']
+        boxes = np.array(boxes, np.int32)
+        if boxes[0, 1] > boxes[0, -1]:
+            boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
+        relations, bboxes = self.compute_relation(boxes)
+
+        labels = results.get('labels', None)
+        if labels is not None:
+            labels = np.array(labels, np.int32)
+            edges = results.get('edges', None)
+            if edges is not None:
+                labels = labels[:, None]
+                edges = np.array(edges)
+                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+                if self.directed:
+                    edges = (edges & (labels == 1)).astype(np.int32)
+                np.fill_diagonal(edges, -1)
+                labels = np.concatenate([labels, edges], -1)
+        return dict(
+            bboxes=bboxes,
+            relations=relations,
+            texts=self.pad_text(texts),
+            labels=labels)
+
+    def image_size(self, filename):
+        img_path = osp.join(self.img_prefix, filename)
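+        # PIL's Image.open is lazy: only the header is parsed, so reading
+        # the size does not decode the full image.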
+        img = Image.open(img_path)
+        return img.size
+
+    def load_annotations(self, ann_file):
+        self.anns, data_infos = [], []
+
+        self.gts = dict()
+        img_infos = mmcv.list_from_file(ann_file)
+        for img_info in img_infos:
+            filename, annname = img_info.split()
+            results = self.parse_lines(annname)
+            width, height = self.image_size(filename)
+
+            data_infos.append(
+                dict(filename=filename, width=width, height=height))
+            ann = self.ann_numpy(results)
+            self.anns.append(ann)
+
+        return data_infos
+
+    def pad_text(self, texts):
+        max_len = max([len(text) for text in texts])
+        padded_texts = -np.ones((len(texts), max_len), np.int32)
+        for idx, text in enumerate(texts):
+            padded_texts[idx, :len(text)] = np.array(text)
+        return padded_texts
+
+    def get_ann_info(self, idx):
+        return self.anns[idx]
+
+    def prepare_test_img(self, idx):
+        return self.prepare_train_img(idx)
+
+    def evaluate(self,
+                 results,
+                 metric='macro_f1',
+                 metric_options=dict(macro_f1=dict(ignores=[])),
+                 **kwargs):
+        # allow some kwargs to pass through
+        assert set(kwargs).issubset(['logger'])
+
+        # Protect ``metric_options`` since it uses mutable value as default
+        metric_options = copy.deepcopy(metric_options)
+
+        metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['macro_f1']
+        for m in metrics:
+            if m not in allowed_metrics:
+                raise KeyError(f'metric {m} is not supported')
+
+        return self.compute_macro_f1(results, **metric_options['macro_f1'])
+
+    def compute_macro_f1(self, results, ignores=[]):
+        node_preds = []
+        for result in results:
+            node_preds.append(result['nodes'])
+        node_preds = torch.cat(node_preds)
+
+        node_gts = [
+            torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
+            for ann in self.anns
+        ]
+        node_gts = torch.cat(node_gts)
+
+        node_f1s = compute_f1_score(node_preds, node_gts, ignores)
+
+        return {
+            'macro_f1': node_f1s.mean(),
+        }
diff --git a/mmocr/datasets/pipelines/kie_transforms.py b/mmocr/datasets/pipelines/kie_transforms.py
new file mode 100644
index 00000000..4148c938
--- /dev/null
+++ b/mmocr/datasets/pipelines/kie_transforms.py
@@ -0,0 +1,55 @@
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor
+
+
+@PIPELINES.register_module()
+class KIEFormatBundle(DefaultFormatBundle):
+    """Key information extraction formatting bundle.
+
+    Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
+    common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
+    "gt_masks", "gt_semantic_seg", "relations" and "texts".
+    These fields are formatted as follows.
+
+    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
+    - proposals: (1) to tensor, (2) to DataContainer
+    - gt_bboxes: (1) to tensor, (2) to DataContainer
+    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
+    - gt_labels: (1) to tensor, (2) to DataContainer
+    - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
+    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \
+        (3) to DataContainer (stack=True)
+    - relations: (1) scale, (2) to tensor, (3) to DataContainer
+    - texts: (1) to tensor, (2) to DataContainer
+    """
+
+    def __call__(self, results):
+        """Call function to transform and format common fields in results.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The result dict contains the data that is formatted with \
+                default bundle.
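+
+        Example (illustrative sketch; assumes ``results`` was produced by
+        the KIE train pipeline above, so it carries ``ann_info`` and a
+        resize ``scale_factor``):
+
+            >>> bundle = KIEFormatBundle()
+            >>> results = bundle(results)
+            >>> # 'relations' is rescaled by (sx, sy) and, like 'texts',
+            >>> # wrapped in a DataContainer.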
+ """ + super().__call__(results) + if 'ann_info' in results: + for key in ['relations', 'texts']: + value = results['ann_info'][key] + if key == 'relations' and 'scale_factor' in results: + scale_factor = results['scale_factor'] + if isinstance(scale_factor, float): + sx = sy = scale_factor + else: + sx, sy = results['scale_factor'][:2] + r = sx / sy + value = value * np.array([sx, sy, r, 1, r])[None, None] + results[key] = DC(to_tensor(value)) + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/mmocr/models/common/backbones/__init__.py b/mmocr/models/common/backbones/__init__.py new file mode 100644 index 00000000..de67ca96 --- /dev/null +++ b/mmocr/models/common/backbones/__init__.py @@ -0,0 +1,3 @@ +from .unet import UNet + +__all__ = ['UNet'] diff --git a/mmocr/models/common/backbones/unet.py b/mmocr/models/common/backbones/unet.py new file mode 100644 index 00000000..39e718d4 --- /dev/null +++ b/mmocr/models/common/backbones/unet.py @@ -0,0 +1,528 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, + build_norm_layer, build_upsample_layer, constant_init, + kaiming_init) +from mmcv.runner import load_checkpoint +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. 
+ """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
+        self.with_cp = with_cp
+        convs = []
+        for i in range(num_convs):
+            convs.append(
+                ConvModule(
+                    in_channels=in_channels if i == 0 else out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=stride if i == 0 else 1,
+                    dilation=1 if i == 0 else dilation,
+                    padding=1 if i == 0 else dilation,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+        self.convs = nn.Sequential(*convs)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.convs, x)
+        else:
+            out = self.convs(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+    """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+    This module uses deconvolution to upsample feature map in the decoder
+    of UNet.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 kernel_size=4,
+                 scale_factor=2):
+        super().__init__()
+
+        assert (kernel_size - scale_factor >= 0) and\
+            (kernel_size - scale_factor) % 2 == 0,\
+            f'kernel_size should be greater than or equal to scale_factor '\
+            f'and (kernel_size - scale_factor) should be even numbers, '\
+            f'while the kernel size is {kernel_size} and scale_factor is '\
+            f'{scale_factor}.'
+
+        stride = scale_factor
+        padding = (kernel_size - scale_factor) // 2
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+
+        _, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsampling = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.deconv_upsampling, x)
+        else:
+            out = self.deconv_upsampling(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    This module uses interpolation to upsample feature map in the decoder
+    of UNet. It consists of one interpolation upsample layer and one
+    convolutional layer. It can be one interpolation upsample layer followed
+    by one convolutional layer (conv_first=False) or one convolutional layer
+    followed by one interpolation upsample layer (conv_first=True).
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether the convolutional layer comes before the
+            interpolation upsample layer. Default: False.
+            If False (the default), the interpolation upsample layer is
+            followed by the convolutional layer; if True, the order is
+            reversed.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Interpolation config of the upsample layer.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super().__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = nn.Upsample(**upsample_cfg)
+        if conv_first:
+            self.interp_upsample = nn.Sequential(conv, upsample)
+        else:
+            self.interp_upsample = nn.Sequential(upsample, conv)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.interp_upsample, x)
+        else:
+            out = self.interp_upsample(x)
+        return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+    """UNet backbone.
+    U-Net: Convolutional Networks for Biomedical Image Segmentation.
+    https://arxiv.org/pdf/1505.04597.pdf
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels of each stage.
+            The output channels of the first stage. Default: 64.
+        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+            len(strides) is equal to num_stages. Normally the stride of the
+            first stage in encoder is 1. If strides[i]=2, it uses stride
+            convolution to downsample in the corresponding encoder stage.
+            Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
+            feature map after the first stage of the encoder
+            (stages: [1, num_stages)). If the corresponding encoder stage
+            uses stride convolution (strides[i]=2), it will never use MaxPool
+            to downsample, even if downsamples[i-1] is True.
+            Default: (True, True, True, True).
+        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+            Default: (1, 1, 1, 1, 1).
+        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+            Default: (1, 1, 1, 1).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var).
Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, \ + 'The length of strides should be equal to num_stages, '\ + f'while the strides is {strides}, the length of '\ + f'strides is {len(strides)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_num_convs) == num_stages, \ + 'The length of enc_num_convs should be equal to num_stages, '\ + f'while the enc_num_convs is {enc_num_convs}, the length of '\ + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_num_convs) == (num_stages-1), \ + 'The length of dec_num_convs should be equal to (num_stages-1), '\ + f'while the dec_num_convs is {dec_num_convs}, the length of '\ + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(downsamples) == (num_stages-1), \ + 'The length of downsamples should be equal to (num_stages-1), '\ + f'while the downsamples is {downsamples}, the length of '\ + f'downsamples is {len(downsamples)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_dilations) == num_stages, \ + 'The length of enc_dilations should be equal to num_stages, '\ + f'while the enc_dilations is {enc_dilations}, the length of '\ + f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_dilations) == (num_stages-1), \ + 'The length of dec_dilations should be equal to (num_stages-1), '\ + f'while the dec_dilations is {dec_dilations}, the length of '\ + f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ + f'{num_stages}.' 
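+        # Every stage i > 0 gets a decoder UpConvBlock; it only upsamples
+        # when the matching encoder stage reduced resolution
+        # (strides[i] == 2 or downsamples[i - 1]); otherwise a 1x1 conv
+        # aligns channels without upsampling.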
+        self.num_stages = num_stages
+        self.strides = strides
+        self.downsamples = downsamples
+        self.norm_eval = norm_eval
+        self.base_channels = base_channels
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        for i in range(num_stages):
+            enc_conv_block = []
+            if i != 0:
+                if strides[i] == 1 and downsamples[i - 1]:
+                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+                upsample = (strides[i] != 1 or downsamples[i - 1])
+                self.decoder.append(
+                    UpConvBlock(
+                        conv_block=BasicConvBlock,
+                        in_channels=base_channels * 2**i,
+                        skip_channels=base_channels * 2**(i - 1),
+                        out_channels=base_channels * 2**(i - 1),
+                        num_convs=dec_num_convs[i - 1],
+                        stride=1,
+                        dilation=dec_dilations[i - 1],
+                        with_cp=with_cp,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg,
+                        upsample_cfg=upsample_cfg if upsample else None,
+                        dcn=None,
+                        plugins=None))
+
+            enc_conv_block.append(
+                BasicConvBlock(
+                    in_channels=in_channels,
+                    out_channels=base_channels * 2**i,
+                    num_convs=enc_num_convs[i],
+                    stride=strides[i],
+                    dilation=enc_dilations[i],
+                    with_cp=with_cp,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    dcn=None,
+                    plugins=None))
+            self.encoder.append(nn.Sequential(*enc_conv_block))
+            in_channels = base_channels * 2**i
+
+    def forward(self, x):
+        self._check_input_divisible(x)
+        enc_outs = []
+        for enc in self.encoder:
+            x = enc(x)
+            enc_outs.append(x)
+        dec_outs = [x]
+        for i in reversed(range(len(self.decoder))):
+            x = self.decoder[i](enc_outs[i], x)
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super().train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i - 1]:
+                whole_downsample_rate *= 2
+        assert (h % whole_downsample_rate == 0) \
+            and (w % whole_downsample_rate == 0),\
+            f'The input image size {(h, w)} should be divisible by the whole '\
+            f'downsample rate {whole_downsample_rate}, when num_stages is '\
+            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+            f'is {self.downsamples}.'
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
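+
+        Example (illustrative usage sketch; base_channels=16 mirrors the
+        SDMGR configs in this patch):
+            >>> net = UNet(base_channels=16)
+            >>> net.init_weights(pretrained=None)  # Kaiming / constant init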
+ """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') diff --git a/mmocr/models/kie/__init__.py b/mmocr/models/kie/__init__.py new file mode 100644 index 00000000..46d98163 --- /dev/null +++ b/mmocr/models/kie/__init__.py @@ -0,0 +1,3 @@ +from .extractors import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 diff --git a/mmocr/models/kie/extractors/__init__.py b/mmocr/models/kie/extractors/__init__.py new file mode 100644 index 00000000..f58541e6 --- /dev/null +++ b/mmocr/models/kie/extractors/__init__.py @@ -0,0 +1,3 @@ +from .sdmgr import SDMGR + +__all__ = ['SDMGR'] diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py new file mode 100644 index 00000000..c11df7ff --- /dev/null +++ b/mmocr/models/kie/extractors/sdmgr.py @@ -0,0 +1,87 @@ +from torch import nn +from torch.nn import functional as F + +from mmdet.core import bbox2roi +from mmdet.models.builder import DETECTORS, build_roi_extractor +from mmdet.models.detectors import SingleStageDetector + + +@DETECTORS.register_module() +class SDMGR(SingleStageDetector): + """The implementation of the paper: Spatial Dual-Modality Graph Reasoning + for Key Information Extraction. https://arxiv.org/abs/2103.14470. + + Args: + visual_modality (bool): Whether use the visual modality. + """ + + def __init__(self, + backbone, + neck=None, + bbox_head=None, + extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7), + featmap_strides=[1]), + visual_modality=False, + train_cfg=None, + test_cfg=None, + pretrained=None): + super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg, + pretrained) + self.visual_modality = visual_modality + if visual_modality: + self.extractor = build_roi_extractor({ + **extractor, 'out_channels': + self.backbone.base_channels + }) + self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size']) + else: + self.extractor = None + + def forward_train(self, img, img_metas, relations, texts, gt_bboxes, + gt_labels): + """ + Args: + img (tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A list of image info dict where each dict + contains: 'img_shape', 'scale_factor', 'flip', and may also + contain 'filename', 'ori_shape', 'pad_shape', and + 'img_norm_cfg'. For details of the values of these keys, + please see :class:`mmdet.datasets.pipelines.Collect`. + relations (list[tensor]): Relations between bboxes. + texts (list[tensor]): Texts in bboxes. + gt_bboxes (list[tensor]): Each item is the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[tensor]): Class indices corresponding to each box. + + Returns: + dict[str, tensor]: A dictionary of loss components. 
+ """ + x = self.extract_feat(img, gt_bboxes) + node_preds, edge_preds = self.bbox_head.forward(relations, texts, x) + return self.bbox_head.loss(node_preds, edge_preds, gt_labels) + + def forward_test(self, + img, + img_metas, + relations, + texts, + gt_bboxes, + rescale=False): + x = self.extract_feat(img, gt_bboxes) + node_preds, edge_preds = self.bbox_head.forward(relations, texts, x) + return [ + dict( + img_metas=img_metas, + nodes=F.softmax(node_preds, -1), + edges=F.softmax(edge_preds, -1)) + ] + + def extract_feat(self, img, gt_bboxes): + if self.visual_modality: + x = super().extract_feat(img)[-1] + feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes))) + return feats.view(feats.size(0), -1) + return None diff --git a/mmocr/models/kie/heads/__init__.py b/mmocr/models/kie/heads/__init__.py new file mode 100644 index 00000000..00a11469 --- /dev/null +++ b/mmocr/models/kie/heads/__init__.py @@ -0,0 +1,3 @@ +from .sdmgr_head import SDMGRHead + +__all__ = ['SDMGRHead'] diff --git a/mmocr/models/kie/heads/sdmgr_head.py b/mmocr/models/kie/heads/sdmgr_head.py new file mode 100644 index 00000000..92dbd130 --- /dev/null +++ b/mmocr/models/kie/heads/sdmgr_head.py @@ -0,0 +1,193 @@ +import torch +from mmcv.cnn import normal_init +from torch import nn +from torch.nn import functional as F + +from mmdet.models.builder import HEADS, build_loss + + +@HEADS.register_module() +class SDMGRHead(nn.Module): + + def __init__(self, + num_chars=92, + visual_dim=64, + fusion_dim=1024, + node_input=32, + node_embed=256, + edge_input=5, + edge_embed=256, + num_gnn=2, + num_classes=26, + loss=dict(type='SDMGRLoss'), + bidirectional=False, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim) + self.node_embed = nn.Embedding(num_chars, node_input, 0) + hidden = node_embed // 2 if bidirectional else node_embed + self.rnn = nn.LSTM( + input_size=node_input, + hidden_size=hidden, + num_layers=1, + batch_first=True, + bidirectional=bidirectional) + self.edge_embed = nn.Linear(edge_input, edge_embed) + self.gnn_layers = nn.ModuleList( + [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]) + self.node_cls = nn.Linear(node_embed, num_classes) + self.edge_cls = nn.Linear(edge_embed, 2) + self.loss = build_loss(loss) + + def init_weights(self, pretrained=False): + normal_init(self.edge_embed, mean=0, std=0.01) + + def forward(self, relations, texts, x=None): + node_nums, char_nums = [], [] + for text in texts: + node_nums.append(text.size(0)) + char_nums.append((text > 0).sum(-1)) + + max_num = max([char_num.max() for char_num in char_nums]) + all_nodes = torch.cat([ + torch.cat( + [text, + text.new_zeros(text.size(0), max_num - text.size(1))], -1) + for text in texts + ]) + embed_nodes = self.node_embed(all_nodes.clamp(min=0).long()) + rnn_nodes, _ = self.rnn(embed_nodes) + + nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2]) + all_nums = torch.cat(char_nums) + valid = all_nums > 0 + nodes[valid] = rnn_nodes[valid].gather( + 1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand( + -1, -1, rnn_nodes.size(-1))).squeeze(1) + + if x is not None: + nodes = self.fusion([x, nodes]) + + all_edges = torch.cat( + [rel.view(-1, rel.size(-1)) for rel in relations]) + embed_edges = self.edge_embed(all_edges.float()) + embed_edges = F.normalize(embed_edges) + + for gnn_layer in self.gnn_layers: + nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums) + + node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes) + 
return node_cls, edge_cls + + +class GNNLayer(nn.Module): + + def __init__(self, node_dim=256, edge_dim=256): + super().__init__() + self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim) + self.coef_fc = nn.Linear(node_dim, 1) + self.out_fc = nn.Linear(node_dim, node_dim) + self.relu = nn.ReLU() + + def forward(self, nodes, edges, nums): + start, cat_nodes = 0, [] + for num in nums: + sample_nodes = nodes[start:start + num] + cat_nodes.append( + torch.cat([ + sample_nodes.unsqueeze(1).expand(-1, num, -1), + sample_nodes.unsqueeze(0).expand(num, -1, -1) + ], -1).view(num**2, -1)) + start += num + cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1) + cat_nodes = self.relu(self.in_fc(cat_nodes)) + coefs = self.coef_fc(cat_nodes) + + start, residuals = 0, [] + for num in nums: + residual = F.softmax( + -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 + + coefs[start:start + num**2].view(num, num, -1), 1) + residuals.append( + (residual * + cat_nodes[start:start + num**2].view(num, num, -1)).sum(1)) + start += num**2 + + nodes += self.relu(self.out_fc(torch.cat(residuals))) + return nodes, cat_nodes + + +class Block(nn.Module): + + def __init__(self, + input_dims, + output_dim, + mm_dim=1600, + chunks=20, + rank=15, + shared=False, + dropout_input=0., + dropout_pre_lin=0., + dropout_output=0., + pos_norm='before_cat'): + super().__init__() + self.rank = rank + self.dropout_input = dropout_input + self.dropout_pre_lin = dropout_pre_lin + self.dropout_output = dropout_output + assert (pos_norm in ['before_cat', 'after_cat']) + self.pos_norm = pos_norm + # Modules + self.linear0 = nn.Linear(input_dims[0], mm_dim) + self.linear1 = self.linear0 if shared \ + else nn.Linear(input_dims[1], mm_dim) + self.merge_linears0, self.merge_linears1 =\ + nn.ModuleList(), nn.ModuleList() + self.chunks = self.chunk_sizes(mm_dim, chunks) + for size in self.chunks: + ml0 = nn.Linear(size, size * rank) + self.merge_linears0.append(ml0) + ml1 = ml0 if shared else nn.Linear(size, size * rank) + self.merge_linears1.append(ml1) + self.linear_out = nn.Linear(mm_dim, output_dim) + + def forward(self, x): + x0 = self.linear0(x[0]) + x1 = self.linear1(x[1]) + bs = x1.size(0) + if self.dropout_input > 0: + x0 = F.dropout(x0, p=self.dropout_input, training=self.training) + x1 = F.dropout(x1, p=self.dropout_input, training=self.training) + x0_chunks = torch.split(x0, self.chunks, -1) + x1_chunks = torch.split(x1, self.chunks, -1) + zs = [] + for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, + self.merge_linears0, + self.merge_linears1): + m = m0(x0_c) * m1(x1_c) # bs x split_size*rank + m = m.view(bs, self.rank, -1) + z = torch.sum(m, 1) + if self.pos_norm == 'before_cat': + z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) + z = F.normalize(z) + zs.append(z) + z = torch.cat(zs, 1) + if self.pos_norm == 'after_cat': + z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) + z = F.normalize(z) + + if self.dropout_pre_lin > 0: + z = F.dropout(z, p=self.dropout_pre_lin, training=self.training) + z = self.linear_out(z) + if self.dropout_output > 0: + z = F.dropout(z, p=self.dropout_output, training=self.training) + return z + + @staticmethod + def chunk_sizes(dim, chunks): + split_size = (dim + chunks - 1) // chunks + sizes_list = [split_size] * chunks + sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim) + return sizes_list diff --git a/mmocr/models/kie/losses/__init__.py b/mmocr/models/kie/losses/__init__.py new file mode 100644 index 00000000..96b4afde --- /dev/null +++ b/mmocr/models/kie/losses/__init__.py @@ 
-0,0 +1,3 @@
+from .sdmgr_loss import SDMGRLoss
+
+__all__ = ['SDMGRLoss']
diff --git a/mmocr/models/kie/losses/sdmgr_loss.py b/mmocr/models/kie/losses/sdmgr_loss.py
new file mode 100644
index 00000000..9e1d2312
--- /dev/null
+++ b/mmocr/models/kie/losses/sdmgr_loss.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+
+from mmdet.models.builder import LOSSES
+from mmdet.models.losses import accuracy
+
+
+@LOSSES.register_module()
+class SDMGRLoss(nn.Module):
+    """The implementation of the loss for key information extraction
+    proposed in the paper: Spatial Dual-Modality Graph Reasoning for Key
+    Information Extraction.
+
+    https://arxiv.org/abs/2103.14470.
+    """
+
+    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+        super().__init__()
+        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+        self.node_weight = node_weight
+        self.edge_weight = edge_weight
+        self.ignore = ignore
+
+    def forward(self, node_preds, edge_preds, gts):
+        node_gts, edge_gts = [], []
+        for gt in gts:
+            node_gts.append(gt[:, 0])
+            edge_gts.append(gt[:, 1:].contiguous().view(-1))
+        node_gts = torch.cat(node_gts).long()
+        edge_gts = torch.cat(edge_gts).long()
+
+        node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
+        edge_valids = torch.nonzero(edge_gts != -1).view(-1)
+        return dict(
+            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
+            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
+            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
+            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))