mirror of https://github.com/open-mmlab/mmocr.git
Ner task (#148)
* update ner standard code format * add pytest * fix pre-commit * Annotate the dataset section * fix pre-commit for dataset * rm big files and add comments in dataset * rename configs for ner task * minor changes if metric * Note modification * fix pre-commit * detail modification * rm transform * rm magic number * fix warnings in pylint * fix pre-commit * correct help info * rename model files * rename err fixed * 428_tag * Adjust to more general pipline * update unit test rate * update * Unit test coverage over 90% and add Readme * modify details * fix precommit * update * fix pre-commit * update * update * update * update result * update readme * update baseline config * update config and small minor changes * minor changes in readme and etc. * back to original * update toy config * upload model and log * fix pytest * Modify the notes. * fix readme * Delete Chinese punctuation * add demo and fix some logic and naming problems * add To_tensor transformer for ner and load pretrained model in config * delete extra lines * split ner loss to MaskedCrossEntropyLoss and MaskedFocalLoss * update config * fix err * updata * modify noqa * update new model report * fix err in ner demo * Update ner_dataset.py * Update test_ner_dataset.py * Update ner_dataset.py * Update ner_transforms.py * rm toy config and data * add comment * add empty * fix conflict * fix precommit * fix pytest * fix pytest err * Update ner_dataset.py * change dataset name to cluener2020 * move the postprocess in metric to convertor * rm __init__ etc. * precommit * add discription in loss * add auto download * add http * update * remove some 'issert' * replace unsqueeze * update config * update doc and bert.py * update * update demo code

Co-authored-by: weihuaqiang <weihuaqiang@sensetime.com>
Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
parent 2414c65577
commit 24c590bb04
@@ -0,0 +1,6 @@
# optimizer
optimizer = dict(type='Adadelta', lr=0.5)
optimizer_config = dict(grad_clip=dict(max_norm=0.5))
# learning policy
lr_config = dict(policy='step', step=[8, 14, 16])
total_epochs = 18
@@ -0,0 +1,34 @@
# Chinese Named Entity Recognition using BERT + Softmax

## Introduction

[ALGORITHM]

```bibtex
@article{devlin2018bert,
  title={Bert: Pre-training of deep bidirectional transformers for language understanding},
  author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
  journal={arXiv preprint arXiv:1810.04805},
  year={2018}
}
```

## Dataset

### Train Dataset

|  trainset   | text_num | entity_num |
| :---------: | :------: | :--------: |
| CLUENER2020 |  10748   |   23338    |

### Test Dataset

|   testset   | text_num | entity_num |
| :---------: | :------: | :--------: |
| CLUENER2020 |   1343   |    2982    |

## Results and models

| Method | Pretrain | Precision | Recall | F1-Score | Download |
| :----: | :------: | :-------: | :----: | :------: | :------: |
| [bert_softmax](/configs/ner/bert_softmax/bert_softmax_cluener_18e.py) | [pretrain](https://download.openmmlab.com/mmocr/ner/bert_softmax/bert_pretrain.pth) | 0.7885 | 0.7998 | 0.7941 | [model](https://download.openmmlab.com/mmocr/ner/bert_softmax/bert_softmax_cluener-eea70ea2.pth) \| [log](https://download.openmmlab.com/mmocr/ner/bert_softmax/20210514_172645.log.json) |
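A minimal inference sketch based on the `demo/ner_demo.py` script added in this PR: it loads the config above with the released checkpoint and runs `text_model_inference` on one sentence. The local file names, the device string and the example sentence are assumptions; download the checkpoint from the table first.

```python
from mmdet.apis import init_detector

from mmocr.apis.inference import text_model_inference

# Assumed local paths: the config ships with this PR, the checkpoint is the
# one linked in the table above, downloaded beforehand.
config_file = 'configs/ner/bert_softmax/bert_softmax_cluener_18e.py'
checkpoint_file = 'bert_softmax_cluener-eea70ea2.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')

# Sample sentence taken from the unit test added in this PR.
sentence = '彭小军认为,国内银行现在走的是台湾的发卡模式'
result = text_model_inference(model, sentence)

# Each prediction is a [category, start, end] triple with inclusive indices.
for pred_entities in result:
    for entity in pred_entities:
        print(f'{entity[0]}: {sentence[entity[1]:entity[2] + 1]}')
```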
@@ -0,0 +1,66 @@
_base_ = [
    '../../_base_/schedules/schedule_adadelta_18e.py',
    '../../_base_/default_runtime.py'
]

categories = [
    'address', 'book', 'company', 'game', 'government', 'movie', 'name',
    'organization', 'position', 'scene'
]

test_ann_file = 'data/cluener2020/dev.json'
train_ann_file = 'data/cluener2020/train.json'
vocab_file = 'data/cluener2020/vocab.txt'

max_len = 128
loader = dict(
    type='HardDiskLoader',
    repeat=1,
    parser=dict(type='LineJsonParser', keys=['text', 'label']))

ner_convertor = dict(
    type='NerConvertor',
    annotation_type='bio',
    vocab_file=vocab_file,
    categories=categories,
    max_len=max_len)

test_pipeline = [
    dict(type='NerTransform', label_convertor=ner_convertor, max_len=max_len),
    dict(type='ToTensorNER')
]

train_pipeline = [
    dict(type='NerTransform', label_convertor=ner_convertor, max_len=max_len),
    dict(type='ToTensorNER')
]
dataset_type = 'NerDataset'

train = dict(
    type=dataset_type,
    ann_file=train_ann_file,
    loader=loader,
    pipeline=train_pipeline,
    test_mode=False)

test = dict(
    type=dataset_type,
    ann_file=test_ann_file,
    loader=loader,
    pipeline=test_pipeline,
    test_mode=True)
data = dict(
    samples_per_gpu=8, workers_per_gpu=2, train=train, val=test, test=test)

evaluation = dict(interval=1, metric='f1-score')

model = dict(
    type='NerClassifier',
    pretrained='https://download.openmmlab.com/mmocr/ner/'
    'bert_softmax/bert_pretrain.pth',
    encoder=dict(type='BertEncoder', max_position_embeddings=512),
    decoder=dict(type='FCDecoder'),
    loss=dict(type='MaskedCrossEntropyLoss'),
    label_convertor=ner_convertor)

test_cfg = None
@@ -0,0 +1,35 @@
from argparse import ArgumentParser

from mmdet.apis import init_detector
from mmocr.apis.inference import text_model_inference

from mmocr.datasets import build_dataset  # NOQA
from mmocr.models import build_detector  # NOQA


def main():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file.')
    parser.add_argument('checkpoint', help='Checkpoint file.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # test a single text
    input_sentence = input('Please enter a sentence you want to test: ')
    result = text_model_inference(model, input_sentence)

    # show the results
    for pred_entities in result:
        for entity in pred_entities:
            print(f'{entity[0]}: {input_sentence[entity[1]:entity[2] + 1]}')


if __name__ == '__main__':
    main()
@@ -234,3 +234,24 @@ The structure of the key information extraction dataset directory is organized as follows.
```

- Download [wildreceipt.tar](https://download.openmmlab.com/mmocr/data/wildreceipt.tar)

## Named Entity Recognition

### CLUENER2020

The structure of the named entity recognition dataset directory is organized as follows.

```text
└── cluener2020
  ├── cluener_predict.json
  ├── dev.json
  ├── README.md
  ├── test.json
  ├── train.json
  └── vocab.txt
```

- Download [cluener_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip)

- Download [vocab.txt](https://download.openmmlab.com/mmocr/data/cluener2020/vocab.txt) and move `vocab.txt` to `cluener2020`
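For reference, each line of `train.json` and `dev.json` is a standalone JSON object with `text` and `label` keys, which is exactly what `LineJsonParser(keys=['text', 'label'])` in the config expects. A small sketch that writes and reads one such line; the sample sentence is taken from the unit test added in this PR, and the output path is only illustrative.

```python
import json

# One annotation: the raw sentence plus, per category, a mapping from the
# entity string to its [start, end] character spans (inclusive indices).
sample = {
    'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
    'label': {
        'address': {
            '台湾': [[15, 16]]
        },
        'name': {
            '彭小军': [[0, 2]]
        }
    }
}

# The dataset files store one JSON object per line.
with open('train.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')

with open('train.json', encoding='utf-8') as f:
    for line in f:
        ann = json.loads(line)
        print(ann['text'], list(ann['label']))
```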
@@ -3,9 +3,11 @@
sed -i '$a\\n' ../configs/kie/*/*.md
sed -i '$a\\n' ../configs/textdet/*/*.md
sed -i '$a\\n' ../configs/textrecog/*/*.md
sed -i '$a\\n' ../configs/ner/*/*.md

# gather models
cat ../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Key Information Extraction Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md
cat ../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md
cat ../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md
cat ../configs/ner/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Named Entity Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >ner_models.md
cat ../demo/docs/*_demo.md | sed "s/#/#&/" | sed "s/md###t/html#t/g" | sed '1i\# Demo' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >demo.md
@@ -128,3 +128,40 @@ def model_inference(model, imgs, batch_mode=False):
        return results[0]
    else:
        return results


def text_model_inference(model, input_sentence):
    """Inference text(s) with the entity recognizer.

    Args:
        model (nn.Module): The loaded recognizer.
        input_sentence (str): A text entered by the user.

    Returns:
        result (dict): Predicted results.
    """

    assert isinstance(input_sentence, str)

    cfg = model.cfg
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = {'text': input_sentence, 'label': {}}

    # build the data pipeline
    data = test_pipeline(data)
    if isinstance(data['img_metas'], dict):
        img_metas = data['img_metas']
    else:
        img_metas = data['img_metas'].data

    assert isinstance(img_metas, dict)
    img_metas = {
        'input_ids': img_metas['input_ids'].unsqueeze(0),
        'attention_masks': img_metas['attention_masks'].unsqueeze(0),
        'token_type_ids': img_metas['token_type_ids'].unsqueeze(0),
        'labels': img_metas['labels'].unsqueeze(0)
    }
    # forward the model
    with torch.no_grad():
        result = model(None, img_metas, return_loss=False)
    return result
@@ -2,9 +2,10 @@ from .hmean import eval_hmean
from .hmean_ic13 import eval_hmean_ic13
from .hmean_iou import eval_hmean_iou
from .kie_metric import compute_f1_score
from .ner_metric import eval_ner_f1
from .ocr_metric import eval_ocr_metric

__all__ = [
    'eval_hmean_ic13', 'eval_hmean_iou', 'eval_ocr_metric', 'eval_hmean',
    'compute_f1_score', 'eval_ner_f1'
]
@@ -0,0 +1,113 @@
from collections import Counter


def gt_label2entity(gt_infos):
    """Get all entities from ground truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information that contains
            text and label.
    Returns:
        gt_entities (list[list]): Original labeled entities in ground truth.
            [[category, start_position, end_position]]
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        label = gt_info['label']
        for key, value in label.items():
            for _, places in value.items():
                for place in places:
                    line_entities.append([key, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities


def _compute_f1(origin, found, right):
    """Calculate recall, precision and f1-score.

    Args:
        origin (int): Original entities in ground truth.
        found (int): Predicted entities from model.
        right (int): Predicted entities that
            can match to the original annotation.
    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1


def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and f1-score for all categories.

    Args:
        pred_entities: The predicted entities from model.
        gt_entities: The entities of ground truth file.
    Returns:
        class_info (dict): Precision, recall and f1-score in total
            and for each category.
    """
    origins = []
    founds = []
    rights = []
    for i, _ in enumerate(pred_entities):
        origins.extend(gt_entities[i])
        founds.extend(pred_entities[i])
        rights.extend([
            pre_entity for pre_entity in pred_entities[i]
            if pre_entity in gt_entities[i]
        ])

    class_info = {}
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    for type_, count in origin_counter.items():
        origin = count
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }
    origin = len(origins)
    found = len(founds)
    right = len(rights)
    recall, precision, f1 = _compute_f1(origin, found, right)
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info


def eval_ner_f1(results, gt_infos):
    """Evaluate for ner task.

    Args:
        results (list): Predicted results of entities.
        gt_infos (list[dict]): Ground-truth information which contains
            text and label.
    Returns:
        class_info (dict): Precision, recall and f1-score of total
            and each category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    pred_entities = []
    for i, gt_info in enumerate(gt_infos):
        line_entities = []
        for result in results[i]:
            line_entities.append(result)
        pred_entities.append(line_entities)
    assert len(pred_entities) == len(gt_entities)
    class_info = compute_f1_all(pred_entities, gt_entities)

    return class_info
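A small, self-contained sketch of how `eval_ner_f1` above is meant to be called; the prediction and ground truth below are made up for illustration. `results` holds, per sentence, the predicted `[category, start, end]` triples, while `gt_infos` uses the dataset's `text`/`label` annotation format.

```python
from mmocr.core.evaluation.ner_metric import eval_ner_f1

# Ground truth in the dataset's annotation format (one dict per sentence).
gt_infos = [{
    'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
    'label': {
        'address': {
            '台湾': [[15, 16]]
        },
        'name': {
            '彭小军': [[0, 2]]
        }
    }
}]

# Predicted entities for the same sentence: [category, start, end] triples.
results = [[['name', 0, 2], ['address', 15, 16]]]

class_info = eval_ner_f1(results, gt_infos)
print(class_info['all'])  # {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0}
```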
@@ -3,6 +3,7 @@ from . import utils
from .base_dataset import BaseDataset
from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ner_dataset import NerDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
@@ -13,7 +14,8 @@ from .utils import *  # NOQA
__all__ = [
    'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
    'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
    'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
    'NerDataset'
]

__all__ += utils.__all__
@@ -0,0 +1,47 @@
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ner_metric import eval_ner_f1
from mmocr.datasets.base_dataset import BaseDataset


@DATASETS.register_module()
class NerDataset(BaseDataset):
    """Custom dataset for named entity recognition tasks.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        pipeline (list[dict]): Processing pipeline.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
    """

    def prepare_train_img(self, index):
        """Get training data and annotations after pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys \
                introduced by pipeline.
        """
        ann_info = self.data_infos[index]

        return self.pipeline(ann_info)

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
        Returns:
            info (dict): A dict containing the following keys:
                'acc', 'recall', 'f1-score'.
        """
        gt_infos = list(self.data_infos)
        eval_results = eval_ner_f1(results, gt_infos)
        return eval_results
@@ -3,6 +3,7 @@ from .custom_format_bundle import CustomFormatBundle
from .dbnet_transforms import EastRandomCrop, ImgAug
from .kie_transforms import KIEFormatBundle
from .loading import LoadImageFromNdarray, LoadTextAnnotations
from .ner_transforms import NerTransform, ToTensorNER
from .ocr_seg_targets import OCRSegTargets
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                             OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -24,5 +25,5 @@ __all__ = [
    'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
    'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
    'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
    'RandomScaling', 'RandomCropFlip', 'NerTransform', 'ToTensorNER'
]
@@ -0,0 +1,62 @@
import torch

from mmdet.datasets.builder import PIPELINES
from mmocr.models.builder import build_convertor


@PIPELINES.register_module()
class NerTransform:
    """Convert text to ID and entity in ground truth to label ID. The masks
    and tokens are generated at the same time. The four parameters will be
    used as input to the model.

    Args:
        label_convertor: Convert text to ID and entity
            in ground truth to label ID.
        max_len (int): Limited maximum input length.
    """

    def __init__(self, label_convertor, max_len):
        self.label_convertor = build_convertor(label_convertor)
        self.max_len = max_len

    def __call__(self, results):
        texts = results['text']
        input_ids = self.label_convertor.convert_text2id(texts)
        labels = self.label_convertor.convert_entity2label(
            results['label'], len(texts))

        attention_mask = [0] * self.max_len
        token_type_ids = [0] * self.max_len
        # The beginning and end IDs are added to the ID,
        # so the mask length is increased by 2.
        for i in range(len(texts) + 2):
            attention_mask[i] = 1
        results = dict(
            labels=labels,
            texts=texts,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)
        return results


@PIPELINES.register_module()
class ToTensorNER:
    """Convert data with ``list`` type to tensor."""

    def __call__(self, results):

        input_ids = torch.tensor(results['input_ids'])
        labels = torch.tensor(results['labels'])
        attention_masks = torch.tensor(results['attention_mask'])
        token_type_ids = torch.tensor(results['token_type_ids'])

        results = dict(
            img=[],
            img_metas=dict(
                input_ids=input_ids,
                attention_masks=attention_masks,
                labels=labels,
                token_type_ids=token_type_ids))
        return results
@@ -7,6 +7,7 @@ from .builder import (CONVERTORS, DECODERS, ENCODERS, PREPROCESSOR,

from .common import *  # NOQA
from .kie import *  # NOQA
from .ner import *  # NOQA
from .textdet import *  # NOQA
from .textrecog import *  # NOQA
@@ -1,3 +1,4 @@
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss

__all__ = ['DiceLoss', 'FocalLoss']
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation.

    Args:
        gamma (float): The larger the gamma, the smaller
            the loss weight of easier samples.
        weight (float): A manual rescaling weight given to each
            class.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        logit = F.log_softmax(input, dim=1)
        pt = torch.exp(logit)
        logit = (1 - pt)**self.gamma * logit
        loss = F.nll_loss(
            logit, target, self.weight, ignore_index=self.ignore_index)
        return loss
@@ -0,0 +1,5 @@
from .classifer import *  # noqa: F401,F403
from .convertors import *  # noqa: F401,F403
from .decoder import *  # noqa: F401,F403
from .encoder import *  # noqa: F401,F403
from .loss import *  # noqa: F401,F403
@@ -0,0 +1,3 @@
from .ner_classifier import NerClassifier

__all__ = ['NerClassifier']
@@ -0,0 +1,52 @@
from mmdet.models.builder import DETECTORS, build_loss
from mmocr.models.builder import build_convertor, build_decoder, build_encoder
from mmocr.models.textrecog.recognizer.base import BaseRecognizer


@DETECTORS.register_module()
class NerClassifier(BaseRecognizer):
    """Base class for NER classifier."""

    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()
        self.label_convertor = build_convertor(label_convertor)

        encoder.update(pretrained=pretrained)
        self.encoder = build_encoder(encoder)

        decoder.update(num_labels=self.label_convertor.num_labels)
        self.decoder = build_decoder(decoder)

        loss.update(num_labels=self.label_convertor.num_labels)
        self.loss = build_loss(loss)

    def extract_feat(self, imgs):
        """Extract features from images."""
        raise NotImplementedError(
            'Extract feature module is not implemented yet.')

    def forward_train(self, imgs, img_metas, **kwargs):
        encode_out = self.encoder(img_metas)
        logits, _ = self.decoder(encode_out)
        loss = self.loss(logits, img_metas)
        return loss

    def forward_test(self, imgs, img_metas, **kwargs):
        encode_out = self.encoder(img_metas)
        _, preds = self.decoder(encode_out)
        pred_entities = self.label_convertor.convert_pred2entities(
            preds, img_metas['attention_masks'])
        return pred_entities

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('Augmentation test is not implemented yet.')

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('Simple test is not implemented yet.')
@@ -0,0 +1,3 @@
from .ner_convertor import NerConvertor

__all__ = ['NerConvertor']
@@ -0,0 +1,171 @@
import numpy as np

from mmocr.models.builder import CONVERTORS


@CONVERTORS.register_module()
class NerConvertor:
    """Convert between text, index and tensor for NER pipeline.

    Args:
        annotation_type (str): BIO (B-begin, I-inside, O-outside) or
            BIOES (B-begin, I-inside, O-outside, E-end, S-single).
        vocab_file (str): File to convert words to ids.
        categories (list[str]): All entity categories supported by the model.
        max_len (int): The maximum length of the input text.
        unknown_id (int): For words that do not appear in vocab.txt.
        start_id (int): ID prepended to the start of each input.
        end_id (int): ID appended to the end of each input.
    """

    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        self.annotation_type = annotation_type
        self.categories = categories
        self.word2ids = {}
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        assert self.max_len > 2
        assert self.annotation_type in ['bio', 'bioes']

        lines = open(vocab_file, encoding='utf-8').readlines()
        self.vocab_size = len(lines)
        for i in range(len(lines)):
            self.word2ids.update({lines[i].rstrip(): i})

        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        elif self.annotation_type == 'bioes':
            raise NotImplementedError('BIOES format is not supported yet!')

        assert self.ignore_id is not None
        assert self.id2label is not None
        self.num_labels = len(self.id2label)

    def _generate_labelid_dict(self):
        """Generate a dictionary that maps input to ID and ID to output."""
        num_classes = len(self.categories)
        label2id_dict = {}
        ignore_id = 2 * num_classes + 1
        id2label_dict = {
            0: 'X',
            ignore_id: 'O',
            2 * num_classes + 2: '[START]',
            2 * num_classes + 3: '[END]'
        }

        for index, category in enumerate(self.categories):
            start_label = index + 1
            end_label = index + 1 + num_classes
            label2id_dict.update({category: [start_label, end_label]})
            id2label_dict.update({start_label: 'B-' + category})
            id2label_dict.update({end_label: 'I-' + category})

        return label2id_dict, id2label_dict, ignore_id

    def convert_text2id(self, text):
        """Convert characters to ids.

        If the input is uppercase,
        convert to lowercase first.
        Args:
            text (list[char]): Annotations of one paragraph.
        Returns:
            input_ids (list): Corresponding IDs after conversion.
        """
        ids = []
        for word in text.lower():
            if word in self.word2ids:
                ids.append(self.word2ids[word])
            else:
                ids.append(self.unknown_id)
        # Text that exceeds the maximum length is truncated.
        valid_len = min(len(text), self.max_len)
        input_ids = [0] * self.max_len
        input_ids[0] = self.start_id
        for i in range(1, valid_len + 1):
            input_ids[i] = ids[i - 1]
            input_ids[i + 1] = self.end_id

        return input_ids

    def convert_entity2label(self, label, text_len):
        """Convert labeled entities to ids.

        Args:
            label (dict): Labels of entities.
            text_len (int): The length of input text.
        Returns:
            labels (list): Label ids of an input text.
        """
        labels = [0] * self.max_len
        for j in range(min(text_len + 2, self.max_len)):
            labels[j] = self.ignore_id
        categorys = label
        for key in categorys:
            for text in categorys[key]:
                for place in categorys[key][text]:
                    # Remove the label position beyond the maximum length.
                    if place[0] + 1 < len(labels):
                        labels[place[0] + 1] = self.label2id_dict[key][0]
                        for i in range(place[0] + 1, place[1] + 1):
                            if i + 1 < len(labels):
                                labels[i + 1] = self.label2id_dict[key][1]
        return labels

    def convert_pred2entities(self, preds, masks):
        """Gets entities from preds.

        Args:
            preds (list): Sequence of preds.
            masks (tensor): The valid part is 1 and the invalid part is 0.
        Returns:
            pred_entities (list): List of [[[entity_type,
                entity_start, entity_end]]].
        """

        masks = masks.detach().cpu().numpy()
        pred_entities = []
        assert isinstance(preds, list)
        for index, pred in enumerate(preds):
            entities = []
            entity = [-1, -1, -1]
            results = (masks[index][1:] * np.array(pred[1:])).tolist()
            for index, tag in enumerate(results):
                if not isinstance(tag, str):
                    tag = self.id2label[tag]
                if self.annotation_type == 'bio':
                    if tag.startswith('B-'):
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                        entity[1] = index
                        entity[0] = tag.split('-')[1]
                        entity[2] = index
                        if index == len(results) - 1 and entity[1] < entity[2]:
                            entities.append(entity)
                    elif tag.startswith('I-') and entity[1] != -1:
                        _type = tag.split('-')[1]
                        if _type == entity[0]:
                            entity[2] = index

                        if index == len(results) - 1 and entity[1] < entity[2]:
                            entities.append(entity)
                    else:
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                else:
                    raise NotImplementedError(
                        'The data format is not supported yet!')
            pred_entities.append(entities)
        return pred_entities
@@ -0,0 +1,3 @@
from .fc_decoder import FCDecoder

__all__ = ['FCDecoder']
@@ -0,0 +1,45 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import uniform_init, xavier_init

from mmocr.models.builder import DECODERS


@DECODERS.register_module()
class FCDecoder(nn.Module):
    """FC Decoder class for NER.

    Args:
        num_labels (int): Number of categories mapped by entity label.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        hidden_size (int): Hidden layer output layer channels.
    """

    def __init__(self,
                 num_labels=None,
                 hidden_dropout_prob=0.1,
                 hidden_size=768):
        super().__init__()
        self.num_labels = num_labels

        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, outputs):
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        softmax = F.softmax(logits, dim=2)
        preds = softmax.detach().cpu().numpy()
        preds = np.argmax(preds, axis=2).tolist()
        return logits, preds

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                uniform_init(m)
@@ -0,0 +1,3 @@
from .bert_encoder import BertEncoder

__all__ = ['BertEncoder']
@@ -0,0 +1,87 @@
import torch.nn as nn
from mmcv.cnn import uniform_init, xavier_init
from mmcv.runner import load_checkpoint

from mmdet.utils import get_root_logger
from mmocr.models.builder import ENCODERS
from mmocr.models.ner.utils.bert import BertModel


@ENCODERS.register_module()
class BertEncoder(nn.Module):
    """Bert encoder.

    Args:
        num_hidden_layers (int): The number of hidden layers.
        initializer_range (float): The standard deviation used to
            initialize the weights.
        vocab_size (int): Number of words supported.
        hidden_size (int): Hidden size.
        max_position_embeddings (int): Max position embeddings size.
        type_vocab_size (int): The size of type_vocab.
        layer_norm_eps (float): Epsilon of layer norm.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        output_attentions (bool): Whether use the attentions in output.
        output_hidden_states (bool): Whether use the hidden_states in output.
        num_attention_heads (int): The number of attention heads.
        attention_probs_dropout_prob (float): The dropout probability
            of attention.
        intermediate_size (int): The size of intermediate layer.
        hidden_act (str): Hidden layer activation.
    """

    def __init__(self,
                 num_hidden_layers=12,
                 initializer_range=0.02,
                 vocab_size=21128,
                 hidden_size=768,
                 max_position_embeddings=128,
                 type_vocab_size=2,
                 layer_norm_eps=1e-12,
                 hidden_dropout_prob=0.1,
                 output_attentions=False,
                 output_hidden_states=False,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 intermediate_size=3072,
                 hidden_act='gelu_new',
                 pretrained=None):
        super().__init__()
        self.bert = BertModel(
            num_hidden_layers=num_hidden_layers,
            initializer_range=initializer_range,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout_prob=hidden_dropout_prob,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act)
        self.init_weights(pretrained=pretrained)

    def forward(self, results):

        device = next(self.bert.parameters()).device
        input_ids = results['input_ids'].to(device)
        attention_masks = results['attention_masks'].to(device)
        token_type_ids = results['token_type_ids'].to(device)

        outputs = self.bert(
            input_ids=input_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        return outputs

    def init_weights(self, pretrained=None):
        if pretrained is not None:
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    uniform_init(m)
@@ -0,0 +1,4 @@
from .masked_cross_entropy_loss import MaskedCrossEntropyLoss
from .masked_focal_loss import MaskedFocalLoss

__all__ = ['MaskedCrossEntropyLoss', 'MaskedFocalLoss']
@@ -0,0 +1,55 @@
from torch import nn
from torch.nn import CrossEntropyLoss

from mmdet.models.builder import LOSSES


@LOSSES.register_module()
class MaskedCrossEntropyLoss(nn.Module):
    """The implementation of masked cross entropy loss.

    The mask has 1 for real tokens and 0 for padding tokens,
    which only keeps active parts of the cross entropy loss.
    Args:
        num_labels (int): Number of classes in labels.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Loss forward.

        Args:
            logits: Model output with shape [N, C].
            img_metas (dict): A dict containing the following keys:
                - img (list): This parameter is reserved.
                - labels (list[int]): The labels for each word
                    of the sequence.
                - texts (list): The words of the sequence.
                - input_ids (list): The ids for each word of
                    the sequence.
                - attention_mask (list): The mask for each word
                    of the sequence. The mask has 1 for real tokens
                    and 0 for padding tokens. Only real tokens are
                    attended to.
                - token_type_ids (list): The tokens for each word
                    of the sequence.
        """

        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # Only keep active parts of the loss.
        if attention_masks is not None:
            active_loss = attention_masks.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            loss = self.criterion(active_logits, active_labels)
        else:
            loss = self.criterion(
                logits.view(-1, self.num_labels), labels.view(-1))
        return {'loss_cls': loss}
@@ -0,0 +1,55 @@
from torch import nn

from mmdet.models.builder import LOSSES
from mmocr.models.common.losses.focal_loss import FocalLoss


@LOSSES.register_module()
class MaskedFocalLoss(nn.Module):
    """The implementation of masked focal loss.

    The mask has 1 for real tokens and 0 for padding tokens,
    which only keeps active parts of the focal loss.
    Args:
        num_labels (int): Number of classes in labels.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = FocalLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Loss forward.

        Args:
            logits: Model output with shape [N, C].
            img_metas (dict): A dict containing the following keys:
                - img (list): This parameter is reserved.
                - labels (list[int]): The labels for each word
                    of the sequence.
                - texts (list): The words of the sequence.
                - input_ids (list): The ids for each word of
                    the sequence.
                - attention_mask (list): The mask for each word
                    of the sequence. The mask has 1 for real tokens
                    and 0 for padding tokens. Only real tokens are
                    attended to.
                - token_type_ids (list): The tokens for each word
                    of the sequence.
        """

        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # Only keep active parts of the loss.
        if attention_masks is not None:
            active_loss = attention_masks.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            loss = self.criterion(active_logits, active_labels)
        else:
            loss = self.criterion(
                logits.view(-1, self.num_labels), labels.view(-1))
        return {'loss_cls': loss}
@@ -0,0 +1,4 @@
from .activations import ACT2FN
from .bert import BertModel

__all__ = ['BertModel', 'ACT2FN']
@@ -0,0 +1,39 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/lonePatient/BERT-NER-Pytorch
# Original licence: Copyright (c) 2020 Weitang Liu, under the MIT License.
# ------------------------------------------------------------------------------

import math

import torch
from mmcv.cnn import Swish


def gelu(x):
    """Original Implementation of the gelu activation function in Google Bert
    repo when initially created. For information: OpenAI GPT's gelu is slightly
    different (and gives slightly different results):

    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
        (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_new(x):
    """Implementation of the gelu activation function currently in Google Bert
    repo (identical to OpenAI GPT).

    Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


ACT2FN = {
    'gelu': gelu,
    'relu': torch.nn.functional.relu,
    'swish': Swish,
    'gelu_new': gelu_new
}
@ -0,0 +1,486 @@
|
|||
# ------------------------------------------------------------------------------
|
||||
# Adapted from https://github.com/lonePatient/BERT-NER-Pytorch
|
||||
# Original licence: Copyright (c) 2020 Weitang Liu, under the MIT License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmocr.models.ner.utils.activations import ACT2FN
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
"""Implement Bert model for named entity recognition task.
|
||||
|
||||
The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch
|
||||
Args:
|
||||
num_hidden_layers (int): The number of hidden layers.
|
||||
initializer_range (float):
|
||||
vocab_size (int): Number of words supported.
|
||||
hidden_size (int): Hidden size.
|
||||
max_position_embeddings (int): Max positionsembedding size.
|
||||
type_vocab_size (int): The size of type_vocab.
|
||||
layer_norm_eps (float): eps.
|
||||
hidden_dropout_prob (float): The dropout probability of hidden layer.
|
||||
output_attentions (bool): Whether use the attentions in output
|
||||
output_hidden_states (bool): Whether use the hidden_states in output.
|
||||
num_attention_heads (int): The number of attention heads.
|
||||
attention_probs_dropout_prob (float): The dropout probability
|
||||
for the attention probabilities normalized from
|
||||
the attention scores.
|
||||
intermediate_size (int): The size of intermediate layer.
|
||||
hidden_act (str): hidden layer activation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_hidden_layers=12,
|
||||
initializer_range=0.02,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
max_position_embeddings=128,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12,
|
||||
hidden_dropout_prob=0.1,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
num_attention_heads=12,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
intermediate_size=3072,
|
||||
hidden_act='gelu_new'):
|
||||
super().__init__()
|
||||
self.embeddings = BertEmbeddings(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
type_vocab_size=type_vocab_size,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob)
|
||||
self.encoder = BertEncoder(
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=hidden_act)
|
||||
self.pooler = BertPooler(hidden_size=hidden_size)
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.initializer_range = initializer_range
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings,
|
||||
new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def forward(self,
|
||||
input_ids,
|
||||
attention_masks=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None):
|
||||
if attention_masks is None:
|
||||
attention_masks = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
attention_masks = attention_masks[:, None, None]
|
||||
attention_masks = attention_masks.to(
|
||||
dtype=next(self.parameters()).dtype)
|
||||
attention_masks = (1.0 - attention_masks) * -10000.0
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask[None, None, :, None, None]
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask[None, :, None, None]
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
sequence_output, *encoder_outputs = self.encoder(
|
||||
embedding_output, attention_masks, head_mask=head_mask)
|
||||
# sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
# add hidden_states and attentions if they are here
|
||||
# sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + tuple(encoder_outputs)
|
||||
return outputs
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which
|
||||
# uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize and prunes weights if needed."""
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
|
||||
The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
|
||||
Args:
|
||||
vocab_size (int): Number of words supported.
|
||||
hidden_size (int): Hidden size.
|
||||
max_position_embeddings (int): Max positions embedding size.
|
||||
type_vocab_size (int): The size of type_vocab.
|
||||
layer_norm_eps (float): eps.
|
||||
hidden_dropout_prob (float): The dropout probability of hidden layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
max_position_embeddings=128,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12,
|
||||
hidden_dropout_prob=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.word_embeddings = nn.Embedding(
|
||||
vocab_size, hidden_size, padding_idx=0)
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings,
|
||||
hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with
|
||||
# TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_emb = self.word_embeddings(input_ids)
|
||||
position_emb = self.position_embeddings(position_ids)
|
||||
token_type_emb = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = words_emb + position_emb + token_type_emb
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
"""The code is adapted from https://github.com/lonePatient/BERT-NER-
|
||||
Pytorch."""
|
||||
|
||||
def __init__(self,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
num_hidden_layers=12,
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
layer_norm_eps=1e-12,
|
||||
hidden_dropout_prob=0.1,
|
||||
intermediate_size=3072,
|
||||
hidden_act='gelu_new'):
|
||||
super().__init__()
|
||||
self.output_attentions = output_attentions
|
||||
self.output_hidden_states = output_hidden_states
|
||||
self.layer = nn.ModuleList([
|
||||
BertLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
output_attentions=output_attentions,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=hidden_act) for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask,
|
||||
head_mask[i])
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1], )
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
outputs = (hidden_states, )
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states, )
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions, )
|
||||
# last-layer hidden state, (all hidden states), (all attentions)
|
||||
return outputs
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=768):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
"""Bert layer.
|
||||
|
||||
The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
output_attentions=False,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
layer_norm_eps=1e-12,
|
||||
hidden_dropout_prob=0.1,
|
||||
intermediate_size=3072,
|
||||
hidden_act='gelu_new'):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
output_attentions=output_attentions,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob)
|
||||
self.intermediate = BertIntermediate(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=hidden_act)
|
||||
self.output = BertOutput(
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_size=hidden_size,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask,
|
||||
head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output, ) + attention_outputs[
|
||||
1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
"""Bert self attention module.
|
||||
|
||||
The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
                 num_attention_heads=12,
                 output_attentions=False,
                 attention_probs_dropout_prob=0.1):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError('The hidden size (%d) is not a multiple of '
                             'the number of attention heads (%d)' %
                             (hidden_size, num_attention_heads))
        self.output_attentions = output_attentions

        self.num_attention_heads = num_attention_heads
        self.att_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.att_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Split the last dim into (num_heads, head_size) and move the head
        # dim forward: (batch, seq, hidden) -> (batch, heads, seq, head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.att_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key"
        # to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.att_head_size)
        if attention_mask is not None:
            # Apply the attention mask precomputed for all layers in the
            # BertModel forward() function.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if self.output_attentions else (
                       context_layer, )
        return outputs


class BertSelfOutput(nn.Module):
    """Bert self output.

    The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
    """

    def __init__(self,
                 hidden_size=768,
                 layer_norm_eps=1e-12,
                 hidden_dropout_prob=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    """Bert attention module implementation.

    The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
    """

    def __init__(self,
                 hidden_size=768,
                 num_attention_heads=12,
                 output_attentions=False,
                 attention_probs_dropout_prob=0.1,
                 layer_norm_eps=1e-12,
                 hidden_dropout_prob=0.1):
        super().__init__()
        self.self = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            output_attentions=output_attentions,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout_prob=hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask=None, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        # Add attentions to the outputs if we output them.
        outputs = (attention_output, ) + self_outputs[1:]
        return outputs


class BertIntermediate(nn.Module):
    """Bert intermediate module implementation.

    The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
    """

    def __init__(self,
                 hidden_size=768,
                 intermediate_size=3072,
                 hidden_act='gelu_new'):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        if isinstance(hidden_act, str):
            self.intermediate_act_fn = ACT2FN[hidden_act]
        else:
            self.intermediate_act_fn = hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    """Bert output module.

    The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch.
    """

    def __init__(self,
                 intermediate_size=3072,
                 hidden_size=768,
                 layer_norm_eps=1e-12,
                 hidden_dropout_prob=0.1):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
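The modules above chain together in the usual BERT encoder-layer order: multi-head self-attention, a residual projection, an intermediate expansion, and a projection back to `hidden_size`. Below is a minimal shape-check sketch of that composition. It assumes it runs in the same scope as the classes above (class names and default sizes come from this file; the random input is illustrative only):

```python
import torch

# Dummy batch: 2 sequences of 8 tokens with hidden size 768.
hidden_states = torch.rand(2, 8, 768)

attention = BertAttention(hidden_size=768, num_attention_heads=12)
intermediate = BertIntermediate(hidden_size=768, intermediate_size=3072)
output = BertOutput(intermediate_size=3072, hidden_size=768)

# Self-attention followed by the residual projection in BertSelfOutput.
attention_output = attention(hidden_states)[0]
# Position-wise feed-forward with a residual connection back to the
# attention output.
layer_output = output(intermediate(attention_output), attention_output)

assert layer_output.shape == (2, 8, 768)
```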
@@ -0,0 +1,114 @@
import json
import os.path as osp
import tempfile

import torch

from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor


def _create_dummy_ann_file(ann_file):
    data = {
        'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
        'label': {
            'address': {
                '台湾': [[15, 16]]
            },
            'name': {
                '彭小军': [[0, 2]]
            }
        }
    }

    with open(ann_file, 'w') as fw:
        fw.write(json.dumps(data, ensure_ascii=False) + '\n')


def _create_dummy_vocab_file(vocab_file):
    with open(vocab_file, 'w') as fw:
        for char in list(map(chr, range(ord('a'), ord('z') + 1))):
            fw.write(char + '\n')


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineJsonParser', keys=['text', 'label']))
    return loader


def test_ner_dataset():
    # test initialization
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]

    # create dummy data
    tmp_dir = tempfile.TemporaryDirectory()
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_ann_file(ann_file)
    _create_dummy_vocab_file(vocab_file)

    max_len = 128
    ner_convertor = dict(
        type='NerConvertor',
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=max_len)

    test_pipeline = [
        dict(
            type='NerTransform',
            label_convertor=ner_convertor,
            max_len=max_len),
        dict(type='ToTensorNER')
    ]
    dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)

    # test prepare_train_img
    dataset.prepare_train_img(0)

    # test evaluation
    result = [[['address', 15, 16], ['name', 0, 2]]]

    dataset.evaluate(result)

    # test convert_pred2entities
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:128]]
    mask = [0] * 128
    for i in range(10):
        mask[i] = 1
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])
    convertor = NerConvertor(
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=128)
    all_entities = convertor.convert_pred2entities(preds=preds, masks=masks)
    assert len(all_entities[0][0]) == 3

    tmp_dir.cleanup()
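For reference, `annotation_type='bio'` means each character gets a `B-<category>` tag at an entity start, `I-<category>` inside an entity, and `O` elsewhere. The self-contained sketch below illustrates that tagging for the dummy CLUENER-style annotation used in the test; it mirrors the scheme conceptually and does not reproduce the exact tag-to-index mapping inside `NerConvertor`:

```python
def to_bio(text, label):
    """Generic BIO tagging for a CLUENER-style label dict (illustrative)."""
    tags = ['O'] * len(text)
    for category, entities in label.items():
        for spans in entities.values():
            for start, end in spans:  # inclusive character indices
                tags[start] = f'B-{category}'
                for i in range(start + 1, end + 1):
                    tags[i] = f'I-{category}'
    return tags


text = '彭小军认为,国内银行现在走的是台湾的发卡模式'
label = {'address': {'台湾': [[15, 16]]}, 'name': {'彭小军': [[0, 2]]}}
tags = to_bio(text, label)
assert tags[0:3] == ['B-name', 'I-name', 'I-name']
assert tags[15:17] == ['B-address', 'I-address']
```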
@@ -0,0 +1,83 @@
import copy
import os.path as osp
import tempfile

import pytest
import torch

from mmocr.models import build_detector
from mmocr.models.ner.utils.activations import gelu, gelu_new


def _create_dummy_vocab_file(vocab_file):
    with open(vocab_file, 'w') as fw:
        for char in list(map(chr, range(ord('a'), ord('z') + 1))):
            fw.write(char + '\n')


def _get_config_module(fname):
    """Load a configuration as a python module."""
    from mmcv import Config
    config_mod = Config.fromfile(fname)
    return config_mod


def _get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    These are deep copied to allow for safe modification of parameters without
    influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)
    return model


@pytest.mark.parametrize(
    'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py'])
def test_encoder_decoder_pipeline(cfg_file):
    # prepare data
    texts = ['中'] * 47
    img = [31] * 47
    labels = [31] * 128
    input_ids = [0] * 128
    attention_mask = [0] * 128
    token_type_ids = [0] * 128
    img_metas = {
        'texts': texts,
        'labels': torch.tensor(labels).unsqueeze(0),
        'img': img,
        'input_ids': torch.tensor(input_ids).unsqueeze(0),
        'attention_masks': torch.tensor(attention_mask).unsqueeze(0),
        'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0)
    }

    # create dummy data
    tmp_dir = tempfile.TemporaryDirectory()
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_vocab_file(vocab_file)

    model = _get_detector_cfg(cfg_file)
    model['label_convertor']['vocab_file'] = vocab_file
    model['pretrained'] = None

    detector = build_detector(model)
    losses = detector.forward(img, img_metas)
    assert isinstance(losses, dict)

    model['loss']['type'] = 'MaskedFocalLoss'
    detector = build_detector(model)
    losses = detector.forward(img, img_metas)
    assert isinstance(losses, dict)

    tmp_dir.cleanup()

    # Test forward test
    with torch.no_grad():
        batch_results = []
        result = detector.forward(None, img_metas, return_loss=False)
        batch_results.append(result)

    # Test activations
    gelu(torch.tensor(0.5))
    gelu_new(torch.tensor(0.5))
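The activation helpers exercised at the end of the test are the two GELU variants commonly used in BERT implementations. As a point of reference, the conventional formulations are sketched below; this is stated as an assumption about what `gelu` and `gelu_new` typically compute, not a copy of `mmocr.models.ner.utils.activations`:

```python
import math

import torch


def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_tanh_approx(x):
    # Tanh approximation popularised by the original BERT/GPT code
    # (often exposed under a name like `gelu_new`).
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.tensor(0.5)
assert torch.isclose(gelu_exact(x), gelu_tanh_approx(x), atol=1e-3)
```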
@@ -22,6 +22,8 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector.')
    parser.add_argument('config', help='Train config file path.')
    parser.add_argument('--work-dir', help='The dir to save logs and models.')
    parser.add_argument(
        '--load-from', help='The checkpoint file to load from.')
    parser.add_argument(
        '--resume-from', help='The checkpoint file to resume from.')
    parser.add_argument(
@@ -123,6 +125,8 @@ def main():
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None: