# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile

import torch

from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file


def _create_dummy_ann_file(ann_file):
    data = {
        'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
        'label': {
            'address': {
                '台湾': [[15, 16]]
            },
            'name': {
                '彭小军': [[0, 2]]
            }
        }
    }

    list_to_file(ann_file, [json.dumps(data, ensure_ascii=False)])
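
# Note: the helper above writes a single JSON object per line, which is the
# shape the LineJsonParser-based loader below expects. The [start, end] spans
# are inclusive character indices into 'text': '台湾' occupies characters
# 15-16 and '彭小军' characters 0-2.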


def _create_dummy_vocab_file(vocab_file):
    # Build a dummy character vocabulary ('a'-'z') for the convertor.
    for char in list(map(chr, range(ord('a'), ord('z') + 1))):
        list_to_file(vocab_file, [json.dumps(char + '\n', ensure_ascii=False)])


def _create_dummy_loader():
    # Read annotations from disk line by line, parsing each line as a JSON
    # object with 'text' and 'label' keys.
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineJsonParser', keys=['text', 'label']))
    return loader


def test_ner_dataset():
    # test initialization
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]

    # create dummy data
    tmp_dir = tempfile.TemporaryDirectory()
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_ann_file(ann_file)
    _create_dummy_vocab_file(vocab_file)

    max_len = 128
    ner_convertor = dict(
        type='NerConvertor',
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=max_len)

    test_pipeline = [
        dict(
            type='NerTransform',
            label_convertor=ner_convertor,
            max_len=max_len),
        dict(type='ToTensorNER')
    ]
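    # As the names suggest, NerTransform applies the label convertor to turn
    # the text and entity spans into fixed-length id sequences, and
    # ToTensorNER packs the results into tensors (a rough description, not
    # an exact contract).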
    dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)

    # test prepare_train_img
    dataset.prepare_train_img(0)

    # test evaluation: the predicted [category, start, end] triples mirror
    # the dummy annotation exactly
    result = [[['address', 15, 16], ['name', 0, 2]]]

    dataset.evaluate(result)

    # test pred convert2entity function
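    # The ids below emulate per-token label predictions for a 128-token
    # sequence; the exact id-to-tag mapping is defined by NerConvertor's
    # 'bio' scheme over the 10 categories above (a rough description; the
    # assertion at the end only checks the shape of the decoded output).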
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:128]]
    # Mark only the first 10 token positions as valid.
    mask = [0] * 128
    for i in range(10):
        mask[i] = 1
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])
    convertor = NerConvertor(
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=128)
    all_entities = convertor.convert_pred2entities(preds=preds, masks=masks)
    assert len(all_entities[0][0]) == 3

    tmp_dir.cleanup()
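

if __name__ == '__main__':
    # Convenience sketch, not part of the upstream suite: allows running the
    # test directly with plain Python as well as through pytest.
    test_ner_dataset()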