mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
parent
961c450df7
commit
0e70f99f4d
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from os import path as osp
|
from os import path as osp
|
||||||
|
|
||||||
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -47,14 +48,13 @@ class KIEDataset(BaseDataset):
|
|||||||
|
|
||||||
self.norm = norm
|
self.norm = norm
|
||||||
self.directed = directed
|
self.directed = directed
|
||||||
|
self.dict = {
|
||||||
self.dict = dict({'': 0})
|
'': 0,
|
||||||
with open(dict_file, 'r') as fr:
|
**{
|
||||||
idx = 1
|
line.rstrip('\r\n'): ind
|
||||||
for line in fr:
|
for ind, line in enumerate(mmcv.list_from_file(dict_file), 1)
|
||||||
char = line.strip()
|
}
|
||||||
self.dict[char] = idx
|
}
|
||||||
idx += 1
|
|
||||||
|
|
||||||
def pre_pipeline(self, results):
|
def pre_pipeline(self, results):
|
||||||
results['img_prefix'] = self.img_prefix
|
results['img_prefix'] = self.img_prefix
|
||||||
@ -70,7 +70,7 @@ class KIEDataset(BaseDataset):
|
|||||||
dict: A dict containing the following keys:
|
dict: A dict containing the following keys:
|
||||||
|
|
||||||
- bboxes (np.ndarray): Bbox in one image with shape:
|
- bboxes (np.ndarray): Bbox in one image with shape:
|
||||||
box_num * 4.
|
box_num * 4. They are sorted clockwise when loading.
|
||||||
- relations (np.ndarray): Relations between bbox with shape:
|
- relations (np.ndarray): Relations between bbox with shape:
|
||||||
box_num * box_num * D.
|
box_num * box_num * D.
|
||||||
- texts (np.ndarray): Text index with shape:
|
- texts (np.ndarray): Text index with shape:
|
||||||
@ -93,7 +93,7 @@ class KIEDataset(BaseDataset):
|
|||||||
texts.append(ann['text'])
|
texts.append(ann['text'])
|
||||||
text_ind = [self.dict[c] for c in text if c in self.dict]
|
text_ind = [self.dict[c] for c in text if c in self.dict]
|
||||||
text_inds.append(text_ind)
|
text_inds.append(text_ind)
|
||||||
labels.append(ann['label'])
|
labels.append(ann.get('label', 0))
|
||||||
edges.append(ann.get('edge', 0))
|
edges.append(ann.get('edge', 0))
|
||||||
|
|
||||||
ann_infos = dict(
|
ann_infos = dict(
|
||||||
@ -201,13 +201,13 @@ class KIEDataset(BaseDataset):
|
|||||||
|
|
||||||
def compute_relation(self, boxes):
|
def compute_relation(self, boxes):
|
||||||
"""Compute relation between every two boxes."""
|
"""Compute relation between every two boxes."""
|
||||||
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
|
x1, y1 = boxes[:, 0:1], boxes[:, 1:2]
|
||||||
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
|
x2, y2 = boxes[:, 4:5], boxes[:, 5:6]
|
||||||
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
|
w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
|
||||||
dxs = (x1s[:, 0][None] - x1s) / self.norm
|
dx = (x1.T - x1) / self.norm
|
||||||
dys = (y1s[:, 0][None] - y1s) / self.norm
|
dy = (y1.T - y1) / self.norm
|
||||||
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
|
xhh, xwh = h.T / h, w.T / h
|
||||||
whs = ws / hs + np.zeros_like(xhhs)
|
whs = w / h + np.zeros_like(xhh)
|
||||||
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
|
relation = np.stack([dx, dy, whs, xhh, xwh], -1)
|
||||||
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
|
bboxes = np.concatenate([x1, y1, x2, y2], -1).astype(np.float32)
|
||||||
return relations, bboxes
|
return relation, bboxes
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
|
||||||
from mmocr.datasets.builder import LOADERS, build_parser
|
from mmocr.datasets.builder import LOADERS, build_parser
|
||||||
|
|
||||||
|
|
||||||
@ -47,13 +49,7 @@ class HardDiskLoader(Loader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _load(self, ann_file):
|
def _load(self, ann_file):
|
||||||
data_ret = []
|
return mmcv.list_from_file(ann_file)
|
||||||
with open(ann_file, 'r', encoding='utf-8') as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
data_ret.append(line)
|
|
||||||
|
|
||||||
return data_ret
|
|
||||||
|
|
||||||
|
|
||||||
@LOADERS.register_module()
|
@LOADERS.register_module()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user