mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
parent
961c450df7
commit
0e70f99f4d
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
from os import path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -47,14 +48,13 @@ class KIEDataset(BaseDataset):
|
||||
|
||||
self.norm = norm
|
||||
self.directed = directed
|
||||
|
||||
self.dict = dict({'': 0})
|
||||
with open(dict_file, 'r') as fr:
|
||||
idx = 1
|
||||
for line in fr:
|
||||
char = line.strip()
|
||||
self.dict[char] = idx
|
||||
idx += 1
|
||||
self.dict = {
|
||||
'': 0,
|
||||
**{
|
||||
line.rstrip('\r\n'): ind
|
||||
for ind, line in enumerate(mmcv.list_from_file(dict_file), 1)
|
||||
}
|
||||
}
|
||||
|
||||
def pre_pipeline(self, results):
|
||||
results['img_prefix'] = self.img_prefix
|
||||
@ -70,7 +70,7 @@ class KIEDataset(BaseDataset):
|
||||
dict: A dict containing the following keys:
|
||||
|
||||
- bboxes (np.ndarray): Bbox in one image with shape:
|
||||
box_num * 4.
|
||||
box_num * 4. They are sorted clockwise when loading.
|
||||
- relations (np.ndarray): Relations between bbox with shape:
|
||||
box_num * box_num * D.
|
||||
- texts (np.ndarray): Text index with shape:
|
||||
@ -93,7 +93,7 @@ class KIEDataset(BaseDataset):
|
||||
texts.append(ann['text'])
|
||||
text_ind = [self.dict[c] for c in text if c in self.dict]
|
||||
text_inds.append(text_ind)
|
||||
labels.append(ann['label'])
|
||||
labels.append(ann.get('label', 0))
|
||||
edges.append(ann.get('edge', 0))
|
||||
|
||||
ann_infos = dict(
|
||||
@ -201,13 +201,13 @@ class KIEDataset(BaseDataset):
|
||||
|
||||
def compute_relation(self, boxes):
|
||||
"""Compute relation between every two boxes."""
|
||||
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
|
||||
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
|
||||
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
|
||||
dxs = (x1s[:, 0][None] - x1s) / self.norm
|
||||
dys = (y1s[:, 0][None] - y1s) / self.norm
|
||||
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
|
||||
whs = ws / hs + np.zeros_like(xhhs)
|
||||
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
|
||||
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
|
||||
return relations, bboxes
|
||||
x1, y1 = boxes[:, 0:1], boxes[:, 1:2]
|
||||
x2, y2 = boxes[:, 4:5], boxes[:, 5:6]
|
||||
w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
|
||||
dx = (x1.T - x1) / self.norm
|
||||
dy = (y1.T - y1) / self.norm
|
||||
xhh, xwh = h.T / h, w.T / h
|
||||
whs = w / h + np.zeros_like(xhh)
|
||||
relation = np.stack([dx, dy, whs, xhh, xwh], -1)
|
||||
bboxes = np.concatenate([x1, y1, x2, y2], -1).astype(np.float32)
|
||||
return relation, bboxes
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.builder import LOADERS, build_parser
|
||||
|
||||
|
||||
@ -47,13 +49,7 @@ class HardDiskLoader(Loader):
|
||||
"""
|
||||
|
||||
def _load(self, ann_file):
|
||||
data_ret = []
|
||||
with open(ann_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
data_ret.append(line)
|
||||
|
||||
return data_ret
|
||||
return mmcv.list_from_file(ann_file)
|
||||
|
||||
|
||||
@LOADERS.register_module()
|
||||
|
Loading…
x
Reference in New Issue
Block a user