mirror of https://github.com/open-mmlab/mmocr.git
update abi framework
parent
a26114d9c7
commit
d4dbad56ee
|
@ -1,8 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
|
||||
|
@ -14,177 +10,3 @@ class ABINet(EncodeDecodeRecognizer):
|
|||
|
||||
<https://arxiv.org/pdf/2103.06495.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
             preprocessor=None,
             backbone=None,
             encoder=None,
             decoder=None,
             iter_size=1,
             fuser=None,
             loss=None,
             label_convertor=None,
             train_cfg=None,
             test_cfg=None,
             max_seq_len=40,
             pretrained=None,
             init_cfg=None):
    """Build the ABINet sub-modules from their config dicts.

    Args:
        preprocessor (dict, optional): Config of an optional preprocessing
            module (e.g. TPS); built via ``MODELS.build`` when given.
        backbone (dict): Config of the feature-extraction backbone.
            Required.
        encoder (dict, optional): Config of the encoder module.
        decoder (dict, optional): Config of the decoder module; it is
            completed with class/index info from the label convertor
            before being built.
        iter_size (int): Number of decode/fuse iterations used by the
            forward passes.
        fuser (dict, optional): Config of the module that fuses encoder
            and decoder features.
        loss (dict): Config of the loss module. Required.
        label_convertor (dict): Config of the str<->tensor convertor.
            Required.
        train_cfg (dict, optional): Training config, stored as-is.
        test_cfg (dict, optional): Testing config, stored as-is.
        max_seq_len (int): Maximum decoded sequence length; propagated to
            both the label convertor and the decoder.
        pretrained (str, optional): Deprecated; converted into an
            ``init_cfg`` of type ``Pretrained``.
        init_cfg (dict, optional): Initialization config forwarded to the
            base class.
    """
    # Deliberately bypass EncodeDecodeRecognizer.__init__: ABINet wires
    # its sub-modules itself, so only the grandparent init runs.
    super(EncodeDecodeRecognizer, self).__init__(init_cfg=init_cfg)

    def build_optional(cfg):
        # Optional sub-modules stay None when not configured.
        return None if cfg is None else MODELS.build(cfg)

    # Label convertor (str2tensor, tensor2str) is mandatory and must
    # agree with the recognizer on the maximum sequence length.
    assert label_convertor is not None
    label_convertor.update(max_seq_len=max_seq_len)
    self.label_convertor = MODELS.build(label_convertor)

    # Optional preprocessor module, e.g. TPS rectification.
    self.preprocessor = build_optional(preprocessor)

    # Backbone is mandatory.
    assert backbone is not None
    self.backbone = MODELS.build(backbone)

    # Optional encoder module.
    self.encoder = build_optional(encoder)

    # Optional decoder module; it needs vocabulary/index details from the
    # label convertor, so fill them in before building.
    self.decoder = None
    if decoder is not None:
        decoder.update(num_classes=self.label_convertor.num_classes())
        decoder.update(start_idx=self.label_convertor.start_idx)
        decoder.update(padding_idx=self.label_convertor.padding_idx)
        decoder.update(max_seq_len=max_seq_len)
        self.decoder = MODELS.build(decoder)

    # Loss is mandatory.
    assert loss is not None
    self.loss = MODELS.build(loss)

    self.train_cfg = train_cfg
    self.test_cfg = test_cfg
    self.max_seq_len = max_seq_len

    if pretrained is not None:
        warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
        self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    self.iter_size = iter_size

    # Optional fuser combining encoder and decoder outputs.
    self.fuser = build_optional(fuser)
def forward_train(self, img, img_metas):
    """Compute training losses for a batch.

    Args:
        img (tensor): Input images of shape (N, C, H, W).
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): Per-image info dicts; each must carry
            'resize_shape' and 'text', and is updated in place with
            'valid_ratio'.

    Returns:
        dict[str, tensor]: A dictionary of loss components.
    """
    # Fraction of each (padded) image width that holds real content.
    full_width = img.size(-1)
    for meta in img_metas:
        meta['valid_ratio'] = 1.0 * meta['resize_shape'][1] / full_width

    feat = self.extract_feat(img)

    gt_labels = [meta['text'] for meta in img_metas]
    targets_dict = self.label_convertor.str2tensor(gt_labels)

    out_enc = None
    text_logits = None
    if self.encoder is not None:
        out_enc = self.encoder(feat)
        text_logits = out_enc['logits']

    # Iterative refinement: each round decodes, optionally fuses encoder
    # and decoder features, and feeds the fused logits to the next round.
    out_decs = []
    out_fusers = []
    for _ in range(self.iter_size):
        if self.decoder is not None:
            out_dec = self.decoder(
                feat,
                text_logits,
                targets_dict,
                img_metas,
                train_mode=True)
            out_decs.append(out_dec)

        if self.fuser is not None:
            # NOTE(review): assumes encoder and decoder are both
            # configured whenever a fuser is — confirm against configs.
            out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
            text_logits = out_fuser['logits']
            out_fusers.append(out_fuser)

    outputs = dict(
        out_enc=out_enc, out_decs=out_decs, out_fusers=out_fusers)
    return self.loss(outputs, targets_dict, img_metas)
def simple_test(self, img, img_metas, **kwargs):
    """Test function (single forward, no test-time augmentation loop).

    Args:
        img (torch.Tensor): Image input tensor of shape (N, C, H, W).
        img_metas (list[dict]): Per-image info dicts; each must carry
            'resize_shape', and is updated in place with 'valid_ratio'.

    Returns:
        list[dict]: One ``dict(text=..., score=...)`` per input image.
    """
    # Fraction of each (padded) image width that holds real content.
    full_width = img.size(-1)
    for meta in img_metas:
        meta['valid_ratio'] = 1.0 * meta['resize_shape'][1] / full_width

    feat = self.extract_feat(img)

    out_enc = None
    text_logits = None
    if self.encoder is not None:
        out_enc = self.encoder(feat)
        text_logits = out_enc['logits']

    # Iterative refinement, mirroring forward_train but in eval mode.
    out_decs = []
    out_fusers = []
    for _ in range(self.iter_size):
        if self.decoder is not None:
            out_dec = self.decoder(
                feat, text_logits, img_metas=img_metas, train_mode=False)
            out_decs.append(out_dec)

        if self.fuser is not None:
            # NOTE(review): assumes encoder and decoder are both
            # configured whenever a fuser is — confirm against configs.
            out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
            text_logits = out_fuser['logits']
            out_fusers.append(out_fuser)

    # Prefer the most refined output available: fuser > decoder > encoder.
    if out_fusers:
        ret = out_fusers[-1]
    elif out_decs:
        ret = out_decs[-1]
    else:
        ret = out_enc

    # Early return during ONNX export to avoid post processing.
    if torch.onnx.is_in_onnx_export():
        return ret['logits']

    label_indexes, label_scores = self.label_convertor.tensor2idx(
        ret['logits'], img_metas)
    label_strings = self.label_convertor.idx2str(label_indexes)

    # Flatten batch results into one dict per image.
    return [
        dict(text=string, score=score)
        for string, score in zip(label_strings, label_scores)
    ]
Loading…
Reference in New Issue