From d4dbad56ee523c452b710e5a5ef2d7bc0a8d9903 Mon Sep 17 00:00:00 2001
From: gaotongxiao
Date: Fri, 8 Jul 2022 14:43:40 +0800
Subject: [PATCH] update abi framework

---
 mmocr/models/textrecog/recognizers/abinet.py | 178 -------------------
 1 file changed, 178 deletions(-)

diff --git a/mmocr/models/textrecog/recognizers/abinet.py b/mmocr/models/textrecog/recognizers/abinet.py
index 40084f7f..52919674 100644
--- a/mmocr/models/textrecog/recognizers/abinet.py
+++ b/mmocr/models/textrecog/recognizers/abinet.py
@@ -1,8 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
-import torch
-
 from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer
 
@@ -14,177 +10,3 @@ class ABINet(EncodeDecodeRecognizer):
     `_
     """
-
-    def __init__(self,
-                 preprocessor=None,
-                 backbone=None,
-                 encoder=None,
-                 decoder=None,
-                 iter_size=1,
-                 fuser=None,
-                 loss=None,
-                 label_convertor=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 max_seq_len=40,
-                 pretrained=None,
-                 init_cfg=None):
-        super(EncodeDecodeRecognizer, self).__init__(init_cfg=init_cfg)
-
-        # Label convertor (str2tensor, tensor2str)
-        assert label_convertor is not None
-        label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = MODELS.build(label_convertor)
-
-        # Preprocessor module, e.g., TPS
-        self.preprocessor = None
-        if preprocessor is not None:
-            self.preprocessor = MODELS.build(preprocessor)
-
-        # Backbone
-        assert backbone is not None
-        self.backbone = MODELS.build(backbone)
-
-        # Encoder module
-        self.encoder = None
-        if encoder is not None:
-            self.encoder = MODELS.build(encoder)
-
-        # Decoder module
-        self.decoder = None
-        if decoder is not None:
-            decoder.update(num_classes=self.label_convertor.num_classes())
-            decoder.update(start_idx=self.label_convertor.start_idx)
-            decoder.update(padding_idx=self.label_convertor.padding_idx)
-            decoder.update(max_seq_len=max_seq_len)
-            self.decoder = MODELS.build(decoder)
-
-        # Loss
-        assert loss is not None
-        self.loss = MODELS.build(loss)
-
-        self.train_cfg = train_cfg
-        self.test_cfg = test_cfg
-        self.max_seq_len = max_seq_len
-
-        if pretrained is not None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated \
-                key, please consider using init_cfg')
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-
-        self.iter_size = iter_size
-
-        self.fuser = None
-        if fuser is not None:
-            self.fuser = MODELS.build(fuser)
-
-    def forward_train(self, img, img_metas):
-        """
-        Args:
-            img (tensor): Input images of shape (N, C, H, W).
-                Typically these should be mean centered and std scaled.
-            img_metas (list[dict]): A list of image info dict where each dict
-                contains: 'img_shape', 'filename', and may also contain
-                'ori_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                :class:`mmdet.datasets.pipelines.Collect`.
-
-        Returns:
-            dict[str, tensor]: A dictionary of loss components.
-        """
-        for img_meta in img_metas:
-            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
-            img_meta['valid_ratio'] = valid_ratio
-
-        feat = self.extract_feat(img)
-
-        gt_labels = [img_meta['text'] for img_meta in img_metas]
-
-        targets_dict = self.label_convertor.str2tensor(gt_labels)
-
-        text_logits = None
-        out_enc = None
-        if self.encoder is not None:
-            out_enc = self.encoder(feat)
-            text_logits = out_enc['logits']
-
-        out_decs = []
-        out_fusers = []
-        for _ in range(self.iter_size):
-            if self.decoder is not None:
-                out_dec = self.decoder(
-                    feat,
-                    text_logits,
-                    targets_dict,
-                    img_metas,
-                    train_mode=True)
-                out_decs.append(out_dec)
-
-            if self.fuser is not None:
-                out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
-                text_logits = out_fuser['logits']
-                out_fusers.append(out_fuser)
-
-        outputs = dict(
-            out_enc=out_enc, out_decs=out_decs, out_fusers=out_fusers)
-
-        losses = self.loss(outputs, targets_dict, img_metas)
-
-        return losses
-
-    def simple_test(self, img, img_metas, **kwargs):
-        """Test function with test time augmentation.
-
-        Args:
-            imgs (torch.Tensor): Image input tensor.
-            img_metas (list[dict]): List of image information.
-
-        Returns:
-            list[str]: Text label result of each image.
-        """
-        for img_meta in img_metas:
-            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
-            img_meta['valid_ratio'] = valid_ratio
-
-        feat = self.extract_feat(img)
-
-        text_logits = None
-        out_enc = None
-        if self.encoder is not None:
-            out_enc = self.encoder(feat)
-            text_logits = out_enc['logits']
-
-        out_decs = []
-        out_fusers = []
-        for _ in range(self.iter_size):
-            if self.decoder is not None:
-                out_dec = self.decoder(
-                    feat, text_logits, img_metas=img_metas, train_mode=False)
-                out_decs.append(out_dec)
-
-            if self.fuser is not None:
-                out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
-                text_logits = out_fuser['logits']
-                out_fusers.append(out_fuser)
-
-        if len(out_fusers) > 0:
-            ret = out_fusers[-1]
-        elif len(out_decs) > 0:
-            ret = out_decs[-1]
-        else:
-            ret = out_enc
-
-        # early return to avoid post processing
-        if torch.onnx.is_in_onnx_export():
-            return ret['logits']
-
-        label_indexes, label_scores = self.label_convertor.tensor2idx(
-            ret['logits'], img_metas)
-        label_strings = self.label_convertor.idx2str(label_indexes)
-
-        # flatten batch results
-        results = []
-        for string, score in zip(label_strings, label_scores):
-            results.append(dict(text=string, score=score))
-
-        return results