diff --git a/configs/_base_/schedules/schedule_adadelta_18e.py b/configs/_base_/schedules/schedule_adadelta_18e.py new file mode 100644 index 00000000..396e807d --- /dev/null +++ b/configs/_base_/schedules/schedule_adadelta_18e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='Adadelta', lr=0.5) +optimizer_config = dict(grad_clip=dict(max_norm=0.5)) +# learning policy +lr_config = dict(policy='step', step=[8, 14, 16]) +total_epochs = 18 diff --git a/configs/ner/bert_softmax/README.md b/configs/ner/bert_softmax/README.md new file mode 100644 index 00000000..fbe9378c --- /dev/null +++ b/configs/ner/bert_softmax/README.md @@ -0,0 +1,34 @@ +# Chinese Named Entity Recognition using BERT + Softmax + +## Introduction + +[ALGORITHM] +```bibtex +@article{devlin2018bert, + title={Bert: Pre-training of deep bidirectional transformers for language understanding}, + author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina}, + journal={arXiv preprint arXiv:1810.04805}, + year={2018} +} +``` + +## Dataset + +### Train Dataset + +| trainset | text_num | entity_num | +| :--------: | :----------: | :--------: | +| CLUENER2020 | 10748 | 23338 | + +### Test Dataset + +| testset | text_num | entity_num | +| :--------: | :----------: | :--------: | +| CLUENER2020 | 1343 | 2982 | + + +## Results and models + +| Method |Pretrain| Precision | Recall | F1-Score | Download | +| :--------------------------------------------------------------------: |:-----------:|:-----------:| :--------:| :-------: | :-------------------------------------: | +| [bert_softmax](/configs/ner/bert_softmax/bert_softmax_cluener_18e.py)| [pretrain](https://download.openmmlab.com/mmocr/ner/bert_softmax/bert_pretrain.pth) |0.7885 | 0.7998 | 0.7941 | [model](https://download.openmmlab.com/mmocr/ner/bert_softmax/bert_softmax_cluener-eea70ea2.pth) \| [log](https://download.openmmlab.com/mmocr/ner/bert_softmax/20210514_172645.log.json) | diff --git 
a/configs/ner/bert_softmax/bert_softmax_cluener_18e.py b/configs/ner/bert_softmax/bert_softmax_cluener_18e.py new file mode 100755 index 00000000..eb9a8997 --- /dev/null +++ b/configs/ner/bert_softmax/bert_softmax_cluener_18e.py @@ -0,0 +1,66 @@ +_base_ = [ + '../../_base_/schedules/schedule_adadelta_18e.py', + '../../_base_/default_runtime.py' +] + +categories = [ + 'address', 'book', 'company', 'game', 'government', 'movie', 'name', + 'organization', 'position', 'scene' +] + +test_ann_file = 'data/cluener2020/dev.json' +train_ann_file = 'data/cluener2020/train.json' +vocab_file = 'data/cluener2020/vocab.txt' + +max_len = 128 +loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict(type='LineJsonParser', keys=['text', 'label'])) + +ner_convertor = dict( + type='NerConvertor', + annotation_type='bio', + vocab_file=vocab_file, + categories=categories, + max_len=max_len) + +test_pipeline = [ + dict(type='NerTransform', label_convertor=ner_convertor, max_len=max_len), + dict(type='ToTensorNER') +] + +train_pipeline = [ + dict(type='NerTransform', label_convertor=ner_convertor, max_len=max_len), + dict(type='ToTensorNER') +] +dataset_type = 'NerDataset' + +train = dict( + type=dataset_type, + ann_file=train_ann_file, + loader=loader, + pipeline=train_pipeline, + test_mode=False) + +test = dict( + type=dataset_type, + ann_file=test_ann_file, + loader=loader, + pipeline=test_pipeline, + test_mode=True) +data = dict( + samples_per_gpu=8, workers_per_gpu=2, train=train, val=test, test=test) + +evaluation = dict(interval=1, metric='f1-score') + +model = dict( + type='NerClassifier', + pretrained='https://download.openmmlab.com/mmocr/ner/' + 'bert_softmax/bert_pretrain.pth', + encoder=dict(type='BertEncoder', max_position_embeddings=512), + decoder=dict(type='FCDecoder'), + loss=dict(type='MaskedCrossEntropyLoss'), + label_convertor=ner_convertor) + +test_cfg = None diff --git a/demo/ner_demo.py b/demo/ner_demo.py new file mode 100755 index 00000000..a7204935 --- 
def main():
    """Run an interactive, single-sentence NER demo.

    Reads ``config``, ``checkpoint`` and the optional ``--device`` from the
    command line, builds the model with ``init_detector``, asks the user for
    one sentence on standard input and prints every predicted entity as
    ``<category>: <covered text>``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument('config', help='Config file.')
    arg_parser.add_argument('checkpoint', help='Checkpoint file.')
    arg_parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    cli_args = arg_parser.parse_args()

    # Build the model from a config file and a checkpoint file.
    model = init_detector(
        cli_args.config, cli_args.checkpoint, device=cli_args.device)
    test_cfg = model.cfg.data.test
    if test_cfg['type'] == 'ConcatDataset':
        # A concatenated dataset carries its pipeline on its members, so the
        # first member's pipeline is used for inference.
        test_cfg.pipeline = test_cfg['datasets'][0].pipeline

    # Test a single text.
    input_sentence = input('Please enter a sentence you want to test: ')
    predictions = text_model_inference(model, input_sentence)

    # Each entity is [category, start, end] with an inclusive end index.
    for sentence_entities in predictions:
        for entity in sentence_entities:
            category, start, end = entity[0], entity[1], entity[2]
            print(f'{category}: {input_sentence[start:end + 1]}')


if __name__ == '__main__':
    main()
+ +```text +└── cluener2020 + ├── cluener_predict.json + ├── dev.json + ├── README.md + ├── test.json + ├── train.json + └── vocab.txt + +``` +- Download [cluener_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip) + +- Download [vocab.txt](https://download.openmmlab.com/mmocr/data/cluener2020/vocab.txt) and move `vocab.txt` to `cluener2020` diff --git a/docs/merge_docs.sh b/docs/merge_docs.sh index 23f113d6..700eb6a9 100755 --- a/docs/merge_docs.sh +++ b/docs/merge_docs.sh @@ -3,9 +3,11 @@ sed -i '$a\\n' ../configs/kie/*/*.md sed -i '$a\\n' ../configs/textdet/*/*.md sed -i '$a\\n' ../configs/textrecog/*/*.md +sed -i '$a\\n' ../configs/ner/*/*.md # gather models cat ../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Key Information Extraction Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md cat ../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md cat ../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md +cat ../configs/ner/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Named Entity Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >ner_models.md cat ../demo/docs/*_demo.md | sed "s/#/#&/" | sed "s/md###t/html#t/g" | sed '1i\# Demo' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >demo.md diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index 670ab1cf..3816ea72 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -128,3 +128,40 @@ def model_inference(model, imgs, 
def text_model_inference(model, input_sentence):
    """Run the entity recognizer on a single sentence.

    Args:
        model (nn.Module): The loaded recognizer. Its
            ``cfg.data.test.pipeline`` is used to preprocess the sentence.
        input_sentence (str): A text entered by the user.

    Returns:
        The value returned by
        ``model(None, img_metas, return_loss=False)``.
    """
    assert isinstance(input_sentence, str)

    # Build and run the test data pipeline on an unlabeled sample.
    pipeline = Compose(model.cfg.data.test.pipeline)
    sample = pipeline({'text': input_sentence, 'label': {}})

    # The pipeline may hand back either a plain dict or a container that
    # exposes the dict through ``.data``.
    metas = sample['img_metas']
    if not isinstance(metas, dict):
        metas = metas.data
    assert isinstance(metas, dict)

    # Add a leading batch dimension to every tensor the model consumes.
    batch_keys = ('input_ids', 'attention_masks', 'token_type_ids', 'labels')
    batched_metas = {key: metas[key].unsqueeze(0) for key in batch_keys}

    # Forward the model without tracking gradients.
    with torch.no_grad():
        return model(None, batched_metas, return_loss=False)
+ Args: + gt_infos (list[dict]): Groudtruth infomation contains text and label. + Returns: + gt_entities (list[list]): Original labeled entities in groundtruth. + [[category,start_position,end_position]] + """ + gt_entities = [] + for gt_info in gt_infos: + line_entities = [] + label = gt_info['label'] + for key, value in label.items(): + for _, places in value.items(): + for place in places: + line_entities.append([key, place[0], place[1]]) + gt_entities.append(line_entities) + return gt_entities + + +def _compute_f1(origin, found, right): + """Calculate recall, precision, f1-score. + + Args: + origin (int): Original entities in groundtruth. + found (int): Predicted entities from model. + right (int): Predicted entities that + can match to the original annotation. + Returns: + recall (float): Metric of recall. + precision (float): Metric of precision. + f1 (float): Metric of f1-score. + """ + recall = 0 if origin == 0 else (right / origin) + precision = 0 if found == 0 else (right / found) + f1 = 0. if recall + precision == 0 else (2 * precision * recall) / ( + precision + recall) + return recall, precision, f1 + + +def compute_f1_all(pred_entities, gt_entities): + """Calculate precision, recall and F1-score for all categories. + + Args: + pred_entities: The predicted entities from model. + gt_entities: The entities of ground truth file. + Returns: + class_info (dict): precision,recall, f1-score in total + and each catogories. 
+ """ + origins = [] + founds = [] + rights = [] + for i, _ in enumerate(pred_entities): + origins.extend(gt_entities[i]) + founds.extend(pred_entities[i]) + rights.extend([ + pre_entity for pre_entity in pred_entities[i] + if pre_entity in gt_entities[i] + ]) + + class_info = {} + origin_counter = Counter([x[0] for x in origins]) + found_counter = Counter([x[0] for x in founds]) + right_counter = Counter([x[0] for x in rights]) + for type_, count in origin_counter.items(): + origin = count + found = found_counter.get(type_, 0) + right = right_counter.get(type_, 0) + recall, precision, f1 = _compute_f1(origin, found, right) + class_info[type_] = { + 'precision': precision, + 'recall': recall, + 'f1-score': f1 + } + origin = len(origins) + found = len(founds) + right = len(rights) + recall, precision, f1 = _compute_f1(origin, found, right) + class_info['all'] = { + 'precision': precision, + 'recall': recall, + 'f1-score': f1 + } + return class_info + + +def eval_ner_f1(results, gt_infos): + """Evaluate for ner task. + + Args: + results (list): Predict results of entities. + gt_infos (list[dict]): Groudtruth infomation which contains + text and label. + Returns: + class_info (dict): precision,recall, f1-score of total + and each catogory. + """ + assert len(results) == len(gt_infos) + gt_entities = gt_label2entity(gt_infos) + pred_entities = [] + for i, gt_info in enumerate(gt_infos): + line_entities = [] + for result in results[i]: + line_entities.append(result) + pred_entities.append(line_entities) + assert len(pred_entities) == len(gt_entities) + class_info = compute_f1_all(pred_entities, gt_entities) + + return class_info diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py index 87baba26..881bccb2 100644 --- a/mmocr/datasets/__init__.py +++ b/mmocr/datasets/__init__.py @@ -3,6 +3,7 @@ from . 
@DATASETS.register_module()
class NerDataset(BaseDataset):
    """Custom dataset for named entity recognition tasks.

    Args:
        ann_file (txt): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        pipeline (list[dict]): Processing pipeline.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
    """

    def prepare_train_img(self, index):
        """Run the pipeline on the annotation stored at ``index``.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys \
                introduced by pipeline.
        """
        return self.pipeline(self.data_infos[index])

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Score ``results`` against this dataset's annotations.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Accepted for interface compatibility;
                not read by this method.
            logger (logging.Logger | str | None): Accepted for interface
                compatibility; not read by this method. Default: None.

        Returns:
            dict: The value returned by ``eval_ner_f1`` for ``results`` and
                a list copy of ``self.data_infos``.
        """
        return eval_ner_f1(results, list(self.data_infos))
@PIPELINES.register_module()
class NerTransform:
    """Convert text to ID and entity in ground truth to label ID. The masks and
    tokens are generated at the same time. The four parameters will be used as
    input to the model.

    Args:
        label_convertor: Config used to build the convertor that maps text
            to IDs and ground truth entities to label IDs.
        max_len (int): Limited maximum input length.
    """

    def __init__(self, label_convertor, max_len):
        self.label_convertor = build_convertor(label_convertor)
        self.max_len = max_len

    def __call__(self, results):
        """Build model inputs for one sample.

        Args:
            results (dict): Must contain ``'text'`` (the raw sentence) and
                ``'label'`` (the entity annotation dict).

        Returns:
            dict: ``labels``, ``texts``, ``input_ids``, ``attention_mask``
                and ``token_type_ids``; the two masks are lists of length
                ``max_len``.
        """
        texts = results['text']
        input_ids = self.label_convertor.convert_text2id(texts)
        labels = self.label_convertor.convert_entity2label(
            results['label'], len(texts))

        # The beginning and end IDs are added to the ID, so the mask length
        # is increased by 2. It is clamped to max_len so that text longer
        # than max_len - 2 no longer indexes past the end of the mask.
        valid_len = min(len(texts) + 2, self.max_len)
        attention_mask = [1] * valid_len + [0] * (self.max_len - valid_len)
        token_type_ids = [0] * self.max_len
        return dict(
            labels=labels,
            texts=texts,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)


@PIPELINES.register_module()
class ToTensorNER:
    """Convert data with ``list`` type to tensor."""

    def __call__(self, results):
        """Pack the list-valued fields into tensors under ``img_metas``.

        Args:
            results (dict): Must contain ``input_ids``, ``labels``,
                ``attention_mask`` and ``token_type_ids``.

        Returns:
            dict: ``img`` (an empty list) and ``img_metas`` holding the
                tensors ``input_ids``, ``attention_masks``, ``labels`` and
                ``token_type_ids``.
        """
        input_ids = torch.tensor(results['input_ids'])
        labels = torch.tensor(results['labels'])
        # Note the key rename: 'attention_mask' in, 'attention_masks' out.
        attention_masks = torch.tensor(results['attention_mask'])
        token_type_ids = torch.tensor(results['token_type_ids'])

        return dict(
            img=[],
            img_metas=dict(
                input_ids=input_ids,
                attention_masks=attention_masks,
                labels=labels,
                token_type_ids=token_type_ids))
class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation.

    Args:
        gamma (float): The larger the gamma, the smaller
            the loss weight of easier samples.
        weight (Tensor, optional): A manual rescaling weight given to each
            class, forwarded to ``F.nll_loss``.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input (Tensor): Raw class scores with classes on dim 1.
            target (Tensor): Class indices, as accepted by ``F.nll_loss``.

        Returns:
            Tensor: The loss as reduced by ``F.nll_loss`` defaults.
        """
        log_prob = F.log_softmax(input, dim=1)
        prob = torch.exp(log_prob)
        # Down-weight well-classified samples by (1 - p) ** gamma.
        focal_log_prob = (1 - prob)**self.gamma * log_prob

        weight = self.weight
        if isinstance(weight, torch.Tensor):
            # ``weight`` is a plain attribute rather than a buffer, so it
            # does not move with the module; align it with the input here.
            weight = weight.to(device=input.device, dtype=input.dtype)
        return F.nll_loss(
            focal_log_prob, target, weight, ignore_index=self.ignore_index)
@DETECTORS.register_module()
class NerClassifier(BaseRecognizer):
    """Base class for NER classifier.

    Args:
        encoder (dict): Config of the encoder; ``pretrained`` is added to it.
        decoder (dict): Config of the decoder; ``num_labels`` is added to it.
        loss (dict): Config of the loss; ``num_labels`` is added to it.
        label_convertor (dict): Config of the label convertor, which also
            provides ``num_labels``.
        train_cfg (dict, optional): Accepted but not read here.
        test_cfg (dict, optional): Accepted but not read here.
        pretrained (str, optional): Forwarded to the encoder config.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()
        self.label_convertor = build_convertor(label_convertor)
        num_labels = self.label_convertor.num_labels

        # Work on shallow copies so the caller's config dicts (which may be
        # shared or reused to build another model) are not modified.
        encoder = encoder.copy()
        encoder.update(pretrained=pretrained)
        self.encoder = build_encoder(encoder)

        decoder = decoder.copy()
        decoder.update(num_labels=num_labels)
        self.decoder = build_decoder(decoder)

        loss = loss.copy()
        loss.update(num_labels=num_labels)
        self.loss = build_loss(loss)

    def extract_feat(self, imgs):
        """Extract features from images."""
        raise NotImplementedError(
            'Extract feature module is not implemented yet.')

    def forward_train(self, imgs, img_metas, **kwargs):
        """Encode, decode and return the loss computed from the logits.

        ``imgs`` is not used; all inputs are read from ``img_metas``.
        """
        encode_out = self.encoder(img_metas)
        logits, _ = self.decoder(encode_out)
        loss = self.loss(logits, img_metas)
        return loss

    def forward_test(self, imgs, img_metas, **kwargs):
        """Encode, decode and convert the predictions to entities.

        ``imgs`` is not used; all inputs are read from ``img_metas``.
        """
        encode_out = self.encoder(img_metas)
        _, preds = self.decoder(encode_out)
        pred_entities = self.label_convertor.convert_pred2entities(
            preds, img_metas['attention_masks'])
        return pred_entities

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('Augmentation test is not implemented yet.')

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('Simple test is not implemented yet.')
@CONVERTORS.register_module()
class NerConvertor:
    """Convert between text, index and tensor for NER pipeline.

    Args:
        annotation_type (str): BIO((B-begin, I-inside, O-outside)),
            BIOES(B-begin, I-inside, O-outside, E-end, S-single)
        vocab_file (str): File to convert words to ids; one entry per line,
            the line number is the id.
        categories (list[str]): All entity categories supported by the model.
        max_len (int): The maximum length of the input text.
        unknown_id (int): For words that do not appear in vocab.txt.
        start_id (int): ID written at position 0 of every input.
        end_id (int): ID written right after the last kept character.
    """

    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        self.annotation_type = annotation_type
        self.categories = categories
        self.word2ids = {}
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        assert self.max_len > 2
        assert self.annotation_type in ['bio', 'bioes']

        # Close the vocabulary file deterministically instead of relying on
        # garbage collection.
        with open(vocab_file, encoding='utf-8') as vocab:
            lines = vocab.readlines()
        self.vocab_size = len(lines)
        for word_id, line in enumerate(lines):
            self.word2ids[line.rstrip()] = word_id

        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        elif self.annotation_type == 'bioes':
            raise NotImplementedError('Bioes format is not surpported yet!')

        assert self.ignore_id is not None
        assert self.id2label is not None
        self.num_labels = len(self.id2label)

    def _generate_labelid_dict(self):
        """Generate a dictionary that maps input to ID and ID to output.

        Returns:
            tuple: ``(label2id_dict, id2label_dict, ignore_id)`` where
                ``label2id_dict[category]`` is ``[begin_id, inside_id]``.
        """
        num_classes = len(self.categories)
        label2id_dict = {}
        ignore_id = 2 * num_classes + 1
        id2label_dict = {
            0: 'X',
            ignore_id: 'O',
            2 * num_classes + 2: '[START]',
            2 * num_classes + 3: '[END]'
        }

        for index, category in enumerate(self.categories):
            # B- ids occupy 1..num_classes, I- ids the next num_classes.
            start_label = index + 1
            end_label = index + 1 + num_classes
            label2id_dict[category] = [start_label, end_label]
            id2label_dict[start_label] = 'B-' + category
            id2label_dict[end_label] = 'I-' + category

        return label2id_dict, id2label_dict, ignore_id

    def convert_text2id(self, text):
        """Convert characters to ids.

        If the input is uppercase,
        convert to lowercase first.
        Args:
            text (str): Annotations of one paragraph.
        Returns:
            input_ids (list): ``max_len`` ids laid out as
                ``[start_id, char ids..., end_id, 0, ...]``.
        """
        ids = [
            self.word2ids.get(word, self.unknown_id) for word in text.lower()
        ]
        # Text that exceeds the maximum length is truncated. Two slots are
        # reserved for the start and end ids, so at most max_len - 2
        # characters fit; this also covers the empty string.
        valid_len = min(len(text), self.max_len - 2)
        input_ids = [0] * self.max_len
        input_ids[0] = self.start_id
        input_ids[1:valid_len + 1] = ids[:valid_len]
        input_ids[valid_len + 1] = self.end_id

        return input_ids

    def convert_entity2label(self, label, text_len):
        """Convert labeled entities to ids.

        Args:
            label (dict): Labels of entities, shaped like
                ``{category: {entity_text: [[start, end], ...]}}``.
            text_len (int): The length of input text.
        Returns:
            labels (list): ``max_len`` label ids of an input text.
        """
        labels = [0] * self.max_len
        for j in range(min(text_len + 2, self.max_len)):
            labels[j] = self.ignore_id
        for key, mentions in label.items():
            for places in mentions.values():
                for place in places:
                    # Positions are shifted by one for the start id. Remove
                    # the label position beyond the maximum length.
                    if place[0] + 1 < len(labels):
                        labels[place[0] + 1] = self.label2id_dict[key][0]
                    for i in range(place[0] + 1, place[1] + 1):
                        if i + 1 < len(labels):
                            labels[i + 1] = self.label2id_dict[key][1]
        return labels

    def convert_pred2entities(self, preds, masks):
        """Gets entities from preds.

        Args:
            preds (list): Sequence of preds.
            masks (tensor): The valid part is 1 and the invalid part is 0.
        Returns:
            pred_entities (list): List of [[[entity_type,
                                entity_start, entity_end]]].
        """
        # NOTE(review): block nesting below was reconstructed from a patch
        # with collapsed whitespace; confirm against the original file.
        masks = masks.detach().cpu().numpy()
        pred_entities = []
        assert isinstance(preds, list)
        for sample_index, pred in enumerate(preds):
            entities = []
            entity = [-1, -1, -1]
            # Drop the start position and zero out masked positions, which
            # then decode to the id-0 tag.
            results = (masks[sample_index][1:] * np.array(pred[1:])).tolist()
            last_pos = len(results) - 1
            for pos, tag in enumerate(results):
                if not isinstance(tag, str):
                    tag = self.id2label[tag]
                if self.annotation_type == 'bio':
                    if tag.startswith('B-'):
                        # NOTE(review): entities whose start equals their
                        # end are never emitted (strict '<') — confirm this
                        # is intended.
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                        entity[1] = pos
                        entity[0] = tag.split('-')[1]
                        entity[2] = pos
                        # A fresh entity has start == end here, so the
                        # original end-of-sequence append in this branch
                        # could never fire and is omitted.
                    elif tag.startswith('I-') and entity[1] != -1:
                        _type = tag.split('-')[1]
                        if _type == entity[0]:
                            entity[2] = pos

                        if pos == last_pos and entity[1] < entity[2]:
                            entities.append(entity)
                    else:
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                else:
                    raise NotImplementedError(
                        'The data format is not surpported yet!')
            pred_entities.append(entities)
        return pred_entities
@DECODERS.register_module()
class FCDecoder(nn.Module):
    """FC Decoder class for Ner.

    Args:
        num_labels (int): Number of categories mapped by entity label.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        hidden_size (int): Hidden layer output layer channels.
    """

    def __init__(self,
                 num_labels=None,
                 hidden_dropout_prob=0.1,
                 hidden_size=768):
        super().__init__()
        self.num_labels = num_labels

        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, outputs):
        """Classify every position of the first encoder output.

        Args:
            outputs: Indexable encoder output; only ``outputs[0]`` is read.

        Returns:
            tuple: ``(logits, preds)`` where ``preds`` is the nested list of
                arg-max indices over dim 2 of the softmaxed logits.
        """
        hidden = self.dropout(outputs[0])
        logits = self.classifier(hidden)
        probabilities = F.softmax(logits, dim=2).detach().cpu().numpy()
        preds = np.argmax(probabilities, axis=2).tolist()
        return logits, preds

    def init_weights(self):
        """Initialise any ``Conv2d`` / ``BatchNorm2d`` sub-modules."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                xavier_init(module)
            elif isinstance(module, nn.BatchNorm2d):
                uniform_init(module)
@ENCODERS.register_module()
class BertEncoder(nn.Module):
    """Bert encoder
    Args:
        num_hidden_layers (int): The number of hidden layers.
        initializer_range (float): Forwarded unchanged to ``BertModel``.
        vocab_size (int): Number of words supported.
        hidden_size (int): Hidden size.
        max_position_embeddings (int): Max positions embedding size.
        type_vocab_size (int): The size of type_vocab.
        layer_norm_eps (float): Epsilon of layer norm.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        output_attentions (bool): Whether use the attentions in output.
        output_hidden_states (bool): Whether use the hidden_states in output.
        num_attention_heads (int): The number of attention heads.
        attention_probs_dropout_prob (float): The dropout probability
            of attention.
        intermediate_size (int): The size of intermediate layer.
        hidden_act (str):  Hidden layer activation.
        pretrained (str, optional): Checkpoint passed to ``load_checkpoint``
            when not None.
    """

    def __init__(self,
                 num_hidden_layers=12,
                 initializer_range=0.02,
                 vocab_size=21128,
                 hidden_size=768,
                 max_position_embeddings=128,
                 type_vocab_size=2,
                 layer_norm_eps=1e-12,
                 hidden_dropout_prob=0.1,
                 output_attentions=False,
                 output_hidden_states=False,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 intermediate_size=3072,
                 hidden_act='gelu_new',
                 pretrained=None):
        super().__init__()
        # Every architecture option is handed to BertModel unchanged.
        bert_kwargs = dict(
            num_hidden_layers=num_hidden_layers,
            initializer_range=initializer_range,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout_prob=hidden_dropout_prob,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act)
        self.bert = BertModel(**bert_kwargs)
        self.init_weights(pretrained=pretrained)

    def forward(self, results):
        """Move the inputs to the model's device and run ``self.bert``.

        Args:
            results (dict): Must contain the tensors ``input_ids``,
                ``attention_masks`` and ``token_type_ids``.

        Returns:
            The output of ``self.bert``.
        """
        device = next(self.bert.parameters()).device
        input_keys = ('input_ids', 'attention_masks', 'token_type_ids')
        bert_inputs = {key: results[key].to(device) for key in input_keys}
        return self.bert(**bert_inputs)

    def init_weights(self, pretrained=None):
        """Load ``pretrained`` non-strictly, else init conv/bn sub-modules."""
        if pretrained is not None:
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
            return

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                xavier_init(module)
            elif isinstance(module, nn.BatchNorm2d):
                uniform_init(module)
+ - texts (list): The words of the sequence. + - input_ids (list): The ids for each word of + the sequence. + - attention_mask (list): The mask for each word + of the sequence. The mask has 1 for real tokens + and 0 for padding tokens. Only real tokens are + attended to. + - token_type_ids (list): The tokens for each word + of the sequence. + ''' + + labels = img_metas['labels'] + attention_masks = img_metas['attention_masks'] + + # Only keep active parts of the loss + if attention_masks is not None: + active_loss = attention_masks.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = self.criterion(active_logits, active_labels) + else: + loss = self.criterion( + logits.view(-1, self.num_labels), labels.view(-1)) + return {'loss_cls': loss} diff --git a/mmocr/models/ner/loss/masked_focal_loss.py b/mmocr/models/ner/loss/masked_focal_loss.py new file mode 100644 index 00000000..c5147778 --- /dev/null +++ b/mmocr/models/ner/loss/masked_focal_loss.py @@ -0,0 +1,55 @@ +from torch import nn + +from mmdet.models.builder import LOSSES +from mmocr.models.common.losses.focal_loss import FocalLoss + + +@LOSSES.register_module() +class MaskedFocalLoss(nn.Module): + """The implementation of masked focal loss. + + The mask has 1 for real tokens and 0 for padding tokens, + which only keep active parts of the focal loss + Args: + num_labels (int): Number of classes in labels. + ignore_index (int): Specifies a target value that is ignored + and does not contribute to the input gradient. + """ + + def __init__(self, num_labels=None, ignore_index=0): + super().__init__() + self.num_labels = num_labels + self.criterion = FocalLoss(ignore_index=ignore_index) + + def forward(self, logits, img_metas): + '''Loss forword. + Args: + logits: Model output with shape [N, C]. + img_metas (dict): A dict containing the following keys: + - img (list]): This parameter is reserved. 
+ - labels (list[int]): The labels for each word + of the sequence. + - texts (list): The words of the sequence. + - input_ids (list): The ids for each word of + the sequence. + - attention_mask (list): The mask for each word + of the sequence. The mask has 1 for real tokens + and 0 for padding tokens. Only real tokens are + attended to. + - token_type_ids (list): The tokens for each word + of the sequence. + ''' + + labels = img_metas['labels'] + attention_masks = img_metas['attention_masks'] + + # Only keep active parts of the loss + if attention_masks is not None: + active_loss = attention_masks.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = self.criterion(active_logits, active_labels) + else: + loss = self.criterion( + logits.view(-1, self.num_labels), labels.view(-1)) + return {'loss_cls': loss} diff --git a/mmocr/models/ner/utils/__init__.py b/mmocr/models/ner/utils/__init__.py new file mode 100644 index 00000000..2fb123d0 --- /dev/null +++ b/mmocr/models/ner/utils/__init__.py @@ -0,0 +1,4 @@ +from .activations import ACT2FN +from .bert import BertModel + +__all__ = ['BertModel', 'ACT2FN'] diff --git a/mmocr/models/ner/utils/activations.py b/mmocr/models/ner/utils/activations.py new file mode 100644 index 00000000..8a7abbe8 --- /dev/null +++ b/mmocr/models/ner/utils/activations.py @@ -0,0 +1,39 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/lonePatient/BERT-NER-Pytorch +# Original licence: Copyright (c) 2020 Weitang Liu, under the MIT License. +# ------------------------------------------------------------------------------ + +import math + +import torch +from mmcv.cnn import Swish + + +def gelu(x): + """Original Implementation of the gelu activation function in Google Bert + repo when initially created. 
For information: OpenAI GPT's gelu is slightly + different (and gives slightly different results): + + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * + (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu_new(x): + """Implementation of the gelu activation function currently in Google Bert + repo (identical to OpenAI GPT). + + Also see https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +ACT2FN = { + 'gelu': gelu, + 'relu': torch.nn.functional.relu, + 'swish': Swish, + 'gelu_new': gelu_new +} diff --git a/mmocr/models/ner/utils/bert.py b/mmocr/models/ner/utils/bert.py new file mode 100644 index 00000000..94264504 --- /dev/null +++ b/mmocr/models/ner/utils/bert.py @@ -0,0 +1,486 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/lonePatient/BERT-NER-Pytorch +# Original licence: Copyright (c) 2020 Weitang Liu, under the MIT License. +# ------------------------------------------------------------------------------ + +import math + +import torch +import torch.nn as nn + +from mmocr.models.ner.utils.activations import ACT2FN + + +class BertModel(nn.Module): + """Implement Bert model for named entity recognition task. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch + Args: + num_hidden_layers (int): The number of hidden layers. + initializer_range (float): + vocab_size (int): Number of words supported. + hidden_size (int): Hidden size. + max_position_embeddings (int): Max positionsembedding size. + type_vocab_size (int): The size of type_vocab. + layer_norm_eps (float): eps. + hidden_dropout_prob (float): The dropout probability of hidden layer. + output_attentions (bool): Whether use the attentions in output + output_hidden_states (bool): Whether use the hidden_states in output. 
+ num_attention_heads (int): The number of attention heads. + attention_probs_dropout_prob (float): The dropout probability + for the attention probabilities normalized from + the attention scores. + intermediate_size (int): The size of intermediate layer. + hidden_act (str): hidden layer activation + """ + + def __init__(self, + num_hidden_layers=12, + initializer_range=0.02, + vocab_size=21128, + hidden_size=768, + max_position_embeddings=128, + type_vocab_size=2, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1, + output_attentions=False, + output_hidden_states=False, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + intermediate_size=3072, + hidden_act='gelu_new'): + super().__init__() + self.embeddings = BertEmbeddings( + vocab_size=vocab_size, + hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob) + self.encoder = BertEncoder( + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + intermediate_size=intermediate_size, + hidden_act=hidden_act) + self.pooler = BertPooler(hidden_size=hidden_size) + self.num_hidden_layers = num_hidden_layers + self.initializer_range = initializer_range + self.init_weights() + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.embeddings.word_embeddings + new_embeddings = self._get_resized_embeddings(old_embeddings, + new_num_tokens) + self.embeddings.word_embeddings = new_embeddings + return self.embeddings.word_embeddings + + def forward(self, + input_ids, + attention_masks=None, + token_type_ids=None, + position_ids=None, + head_mask=None): + if attention_masks is None: + 
attention_masks = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + attention_masks = attention_masks[:, None, None] + attention_masks = attention_masks.to( + dtype=next(self.parameters()).dtype) + attention_masks = (1.0 - attention_masks) * -10000.0 + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask[None, None, :, None, None] + elif head_mask.dim() == 2: + head_mask = head_mask[None, :, None, None] + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) + else: + head_mask = [None] * self.num_hidden_layers + + embedding_output = self.embeddings( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) + sequence_output, *encoder_outputs = self.encoder( + embedding_output, attention_masks, head_mask=head_mask) + # sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + # add hidden_states and attentions if they are here + # sequence_output, pooled_output, (hidden_states), (attentions) + outputs = ( + sequence_output, + pooled_output, + ) + tuple(encoder_outputs) + return outputs + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which + # uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, torch.nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def init_weights(self): + """Initialize and prunes weights if needed.""" + # Initialize weights + self.apply(self._init_weights) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. 
+ Args: + vocab_size (int): Number of words supported. + hidden_size (int): Hidden size. + max_position_embeddings (int): Max positions embedding size. + type_vocab_size (int): The size of type_vocab. + layer_norm_eps (float): eps. + hidden_dropout_prob (float): The dropout probability of hidden layer. + """ + + def __init__(self, + vocab_size=21128, + hidden_size=768, + max_position_embeddings=128, + type_vocab_size=2, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1): + super().__init__() + + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(max_position_embeddings, + hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with + # TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_emb = self.word_embeddings(input_ids) + position_emb = self.position_embeddings(position_ids) + token_type_emb = self.token_type_embeddings(token_type_ids) + embeddings = words_emb + position_emb + token_type_emb + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEncoder(nn.Module): + """The code is adapted from https://github.com/lonePatient/BERT-NER- + Pytorch.""" + + def __init__(self, + output_attentions=False, + output_hidden_states=False, + num_hidden_layers=12, + hidden_size=768, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + 
layer_norm_eps=1e-12, + hidden_dropout_prob=0.1, + intermediate_size=3072, + hidden_act='gelu_new'): + super().__init__() + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.layer = nn.ModuleList([ + BertLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + output_attentions=output_attentions, + attention_probs_dropout_prob=attention_probs_dropout_prob, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + intermediate_size=intermediate_size, + hidden_act=hidden_act) for _ in range(num_hidden_layers) + ]) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_outputs = layer_module(hidden_states, attention_mask, + head_mask[i]) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + # last-layer hidden state, (all hidden states), (all attentions) + return outputs + + +class BertPooler(nn.Module): + + def __init__(self, hidden_size=768): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertLayer(nn.Module): + """Bert layer. 
+ + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. + """ + + def __init__(self, + hidden_size=768, + num_attention_heads=12, + output_attentions=False, + attention_probs_dropout_prob=0.1, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1, + intermediate_size=3072, + hidden_act='gelu_new'): + super().__init__() + self.attention = BertAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + output_attentions=output_attentions, + attention_probs_dropout_prob=attention_probs_dropout_prob, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob) + self.intermediate = BertIntermediate( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act) + self.output = BertOutput( + intermediate_size=intermediate_size, + hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[ + 1:] # add attentions if we output them + return outputs + + +class BertSelfAttention(nn.Module): + """Bert self attention module. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. 
+ """ + + def __init__(self, + hidden_size=768, + num_attention_heads=12, + output_attentions=False, + attention_probs_dropout_prob=0.1): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError('The hidden size (%d) is not a multiple of' + 'the number of attention heads (%d)' % + (hidden_size, num_attention_heads)) + self.output_attentions = output_attentions + + self.num_attention_heads = num_attention_heads + self.att_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.att_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.att_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and + # "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.att_head_size) + if attention_mask is not None: + # Apply the attention mask is precomputed for + # all layers in BertModel forward() function. + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to. + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if self.output_attentions else ( + context_layer, ) + return outputs + + +class BertSelfOutput(nn.Module): + """Bert self output. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. + """ + + def __init__(self, + hidden_size=768, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + """Bert Attention module implementation. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. 
+ """ + + def __init__(self, + hidden_size=768, + num_attention_heads=12, + output_attentions=False, + attention_probs_dropout_prob=0.1, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1): + super().__init__() + self.self = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + output_attentions=output_attentions, + attention_probs_dropout_prob=attention_probs_dropout_prob) + self.output = BertSelfOutput( + hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask=None, head_mask=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + """Bert BertIntermediate module implementation. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. + """ + + def __init__(self, + hidden_size=768, + intermediate_size=3072, + hidden_act='gelu_new'): + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + if isinstance(hidden_act, str): + self.intermediate_act_fn = ACT2FN[hidden_act] + else: + self.intermediate_act_fn = hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + """Bert output module. + + The code is adapted from https://github.com/lonePatient/BERT-NER-Pytorch. 
+ """ + + def __init__(self, + intermediate_size=3072, + hidden_size=768, + layer_norm_eps=1e-12, + hidden_dropout_prob=0.1): + + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states diff --git a/tests/test_dataset/test_ner_dataset.py b/tests/test_dataset/test_ner_dataset.py new file mode 100644 index 00000000..f79ba7b9 --- /dev/null +++ b/tests/test_dataset/test_ner_dataset.py @@ -0,0 +1,114 @@ +import json +import os.path as osp +import tempfile + +import torch + +from mmocr.datasets.ner_dataset import NerDataset +from mmocr.models.ner.convertors.ner_convertor import NerConvertor + + +def _create_dummy_ann_file(ann_file): + data = { + 'text': '彭小军认为,国内银行现在走的是台湾的发卡模式', + 'label': { + 'address': { + '台湾': [[15, 16]] + }, + 'name': { + '彭小军': [[0, 2]] + } + } + } + + with open(ann_file, 'w') as fw: + fw.write(json.dumps(data, ensure_ascii=False) + '\n') + + +def _create_dummy_vocab_file(vocab_file): + with open(vocab_file, 'w') as fw: + for char in list(map(chr, range(ord('a'), ord('z') + 1))): + fw.write(char + '\n') + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict(type='LineJsonParser', keys=['text', 'label'])) + return loader + + +def test_ner_dataset(): + # test initialization + loader = _create_dummy_loader() + categories = [ + 'address', 'book', 'company', 'game', 'government', 'movie', 'name', + 'organization', 'position', 'scene' + ] + + # create dummy data + tmp_dir = tempfile.TemporaryDirectory() + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt') + _create_dummy_ann_file(ann_file) + 
_create_dummy_vocab_file(vocab_file) + + max_len = 128 + ner_convertor = dict( + type='NerConvertor', + annotation_type='bio', + vocab_file=vocab_file, + categories=categories, + max_len=max_len) + + test_pipeline = [ + dict( + type='NerTransform', + label_convertor=ner_convertor, + max_len=max_len), + dict(type='ToTensorNER') + ] + dataset = NerDataset(ann_file, loader, pipeline=test_pipeline) + + # test pre_pipeline + img_info = dataset.data_infos[0] + results = dict(img_info=img_info) + dataset.pre_pipeline(results) + + # test prepare_train_img + dataset.prepare_train_img(0) + + # test evaluation + result = [[['address', 15, 16], ['name', 0, 2]]] + + dataset.evaluate(result) + + # test pred convert2entity function + pred = [ + 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11, + 21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, + 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, + 21, 21 + ] + preds = [pred[:128]] + mask = [0] * 128 + for i in range(10): + mask[i] = 1 + assert len(preds[0]) == len(mask) + masks = torch.tensor([mask]) + convertor = NerConvertor( + annotation_type='bio', + vocab_file=vocab_file, + categories=categories, + max_len=128) + all_entities = convertor.convert_pred2entities(preds=preds, masks=masks) + assert len(all_entities[0][0]) == 3 + + tmp_dir.cleanup() diff --git a/tests/test_models/test_ner_model.py b/tests/test_models/test_ner_model.py new file mode 100644 index 00000000..c5dc4764 --- /dev/null +++ b/tests/test_models/test_ner_model.py 
@@ -0,0 +1,83 @@ +import copy +import os.path as osp +import tempfile + +import pytest +import torch + +from mmocr.models import build_detector +from mmocr.models.ner.utils.activations import gelu, gelu_new + + +def _create_dummy_vocab_file(vocab_file): + with open(vocab_file, 'w') as fw: + for char in list(map(chr, range(ord('a'), ord('z') + 1))): + fw.write(char + '\n') + + +def _get_config_module(fname): + """Load a configuration as a python module.""" + from mmcv import Config + config_mod = Config.fromfile(fname) + return config_mod + + +def _get_detector_cfg(fname): + """Grab configs necessary to create a detector. + + These are deep copied to allow for safe modification of parameters without + influencing other tests. + """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + return model + + +@pytest.mark.parametrize( + 'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py']) +def test_encoder_decoder_pipeline(cfg_file): + # prepare data + texts = ['中'] * 47 + img = [31] * 47 + labels = [31] * 128 + input_ids = [0] * 128 + attention_mask = [0] * 128 + token_type_ids = [0] * 128 + img_metas = { + 'texts': texts, + 'labels': torch.tensor(labels).unsqueeze(0), + 'img': img, + 'input_ids': torch.tensor(input_ids).unsqueeze(0), + 'attention_masks': torch.tensor(attention_mask).unsqueeze(0), + 'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0) + } + + # create dummy data + tmp_dir = tempfile.TemporaryDirectory() + vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt') + _create_dummy_vocab_file(vocab_file) + + model = _get_detector_cfg(cfg_file) + model['label_convertor']['vocab_file'] = vocab_file + model['pretrained'] = None + + detector = build_detector(model) + losses = detector.forward(img, img_metas) + assert isinstance(losses, dict) + + model['loss']['type'] = 'MaskedFocalLoss' + detector = build_detector(model) + losses = detector.forward(img, img_metas) + assert isinstance(losses, dict) + + tmp_dir.cleanup() + 
+ # Test forward test + with torch.no_grad(): + batch_results = [] + result = detector.forward(None, img_metas, return_loss=False) + batch_results.append(result) + + # Test activations + gelu(torch.tensor(0.5)) + gelu_new(torch.tensor(0.5)) diff --git a/tools/train.py b/tools/train.py index 92fe3066..117072ad 100644 --- a/tools/train.py +++ b/tools/train.py @@ -22,6 +22,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='Train a detector.') parser.add_argument('config', help='Train config file path.') parser.add_argument('--work-dir', help='The dir to save logs and models.') + parser.add_argument( + '--load-from', help='The checkpoint file to load from.') parser.add_argument( '--resume-from', help='The checkpoint file to resume from.') parser.add_argument( @@ -123,6 +125,8 @@ def main(): # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) + if args.load_from is not None: + cfg.load_from = args.load_from if args.resume_from is not None: cfg.resume_from = args.resume_from if args.gpu_ids is not None: