mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] encoder_decoder_recognizer
parent 58c59e80dd
commit 206c4ccc65
@@ -4,6 +4,7 @@
 # mmocr/models/textdet/postprocess/utils.py
 # .*/utils.py
 mmocr/models/textrecog/recognizers/base.py
 mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
 .*/__init__.py
 # It will be removed after all transforms have been refactored into processing.py
 mmocr/datasets/pipelines/transforms.py
mmocr/models/textrecog/recognizers/base.py
@@ -56,6 +56,31 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
     def device(self) -> torch.device:
         return self.pixel_mean.device

+    @property
+    def with_backbone(self):
+        """bool: whether the recognizer has a backbone"""
+        return getattr(self, 'backbone', None) is not None
+
+    @property
+    def with_encoder(self):
+        """bool: whether the recognizer has an encoder"""
+        return getattr(self, 'encoder', None) is not None
+
+    @property
+    def with_preprocessor(self):
+        """bool: whether the recognizer has a preprocessor"""
+        return getattr(self, 'preprocessor', None) is not None
+
+    @property
+    def with_dictionary(self):
+        """bool: whether the recognizer has a dictionary"""
+        return getattr(self, 'dictionary', None) is not None
+
+    @property
+    def with_decoder(self):
+        """bool: whether the recognizer has a decoder"""
+        return getattr(self, 'decoder', None) is not None
+
     @abstractmethod
     def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
         """Extract features from images."""
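These `with_*` guards all use the same optional-submodule idiom: a component exists as an attribute only when its config was supplied, and getattr with a None default keeps the check safe even when the attribute was never registered. A minimal, self-contained sketch of how such a property behaves (the TinyRecognizer class is hypothetical, not mmocr API):

import torch.nn as nn


class TinyRecognizer(nn.Module):
    """Hypothetical recognizer with an optional encoder submodule."""

    def __init__(self, encoder: nn.Module = None):
        super().__init__()
        # Register the submodule only when one is supplied, mirroring
        # how the refactored __init__ skips absent components.
        if encoder is not None:
            self.encoder = encoder

    @property
    def with_encoder(self) -> bool:
        """bool: whether the recognizer has an encoder."""
        # getattr with a None default returns None (not raising) when
        # the attribute was never set in __init__.
        return getattr(self, 'encoder', None) is not None


print(TinyRecognizer().with_encoder)               # False
print(TinyRecognizer(nn.Identity()).with_encoder)  # True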
@@ -77,8 +102,8 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
         # NOTE the batched image size information may be useful for
         # calculating valid ratio.
         batch_input_shape = tuple(inputs[0].size()[-2:])
-        for data_samples in data_samples:
-            data_samples.set_metainfo({'batch_input_shape': batch_input_shape})
+        for data_sample in data_samples:
+            data_sample.set_metainfo({'batch_input_shape': batch_input_shape})

     @abstractmethod
     def simple_test(self, inputs: torch.Tensor,
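The rename here is more than style: in the old line the loop variable shadows the list it iterates over. Iteration itself still works, because Python evaluates the iterable once before the loop starts, but afterwards the name `data_samples` is rebound to the last element, which is error-prone for any code that follows. A quick illustration with plain strings:

data_samples = ['a', 'b', 'c']

# Old pattern: the loop variable shadows the iterable.
for data_samples in data_samples:
    pass
print(data_samples)      # 'c' -- the name now points at the last element

# New pattern: singular loop variable; the list name survives the loop.
data_samples = ['a', 'b', 'c']
for data_sample in data_samples:
    pass
print(len(data_samples))  # 3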
@@ -86,7 +111,6 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
                     **kwargs) -> Sequence[TextRecogDataSample]:
         pass

     @abstractmethod
     def aug_test(self, imgs: torch.Tensor,
                  data_samples: Sequence[Sequence[TextRecogDataSample]],
                  **kwargs):
@@ -230,16 +254,8 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
             ``pred_text``.
         """
-        # TODO: Consider merging with forward_train logic
-        batch_size = len(data_samples)
-        batch_img_metas = []
-        for batch_index in range(batch_size):
-            metainfo = data_samples[batch_index].metainfo
-            # TODO: maybe remove to stack_batch
-            # NOTE the batched image size information may be useful for
-            # calculating valid ratio.
-            metainfo['batch_input_shape'] = \
-                tuple(inputs.size()[-2:])
-            batch_img_metas.append(metainfo)
-
-        return self.simple_test(inputs, batch_img_metas, **kwargs)
+        batch_input_shape = tuple(inputs[0].size()[-2:])
+        for data_sample in data_samples:
+            data_sample.set_metainfo({'batch_input_shape': batch_input_shape})
+
+        return self.simple_test(inputs, data_samples, **kwargs)
mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
@@ -1,181 +1,144 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, Optional, Sequence
+
+import torch

-from mmocr.registry import MODELS
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.registry import MODELS, TASK_UTILS
 from .base import BaseRecognizer


 @MODELS.register_module()
 class EncodeDecodeRecognizer(BaseRecognizer):
-    """Base class for encode-decode recognizer."""
+    """Base class for encode-decode recognizer.
+
+    Args:
+        backbone (dict, optional): Backbone config. Defaults to None.
+        encoder (dict, optional): Encoder config. If None, the output from
+            backbone will be directly fed into ``decoder``. Defaults to None.
+        decoder (dict, optional): Decoder config. Defaults to None.
+        dictionary (dict, optional): Dictionary config. Defaults to None.
+        max_seq_len (int): Maximum sequence length. Defaults to 40.
+        preprocess_cfg (dict, optional): Model preprocessing config
+            for processing the input image data. Keys allowed are
+            ``to_rgb`` (bool), ``pad_size_divisor`` (int), ``pad_value``
+            (int or float), ``mean`` (int or float) and ``std`` (int or
+            float). Preprocessing order: 1. to rgb; 2. normalization;
+            3. pad. Defaults to None.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
+    """
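Under the new signature every component is a sub-config resolved through the registry. A hypothetical config sketch (the type names and dict file path are illustrative placeholders, not verified registry entries):

# Hypothetical config dict; every key mirrors an __init__ argument.
recognizer_cfg = dict(
    type='EncodeDecodeRecognizer',
    backbone=dict(type='ResNet31OCR'),   # placeholder type name
    encoder=dict(type='SAREncoder'),     # placeholder type name
    decoder=dict(type='SARDecoder'),     # placeholder type name
    dictionary=dict(
        type='Dictionary',               # placeholder type name
        dict_file='dicts/lower_english_digits.txt'),
    max_seq_len=40,
    preprocess_cfg=dict(
        mean=127.5, std=127.5, to_rgb=True, pad_size_divisor=1),
)

# Built through the registry, e.g.:
# from mmocr.registry import MODELS
# recognizer = MODELS.build(recognizer_cfg)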
     def __init__(self,
-                 preprocessor=None,
-                 backbone=None,
-                 encoder=None,
-                 decoder=None,
-                 loss=None,
-                 label_convertor=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 max_seq_len=40,
-                 pretrained=None,
-                 init_cfg=None):
+                 preprocessor: Optional[Dict] = None,
+                 backbone: Optional[Dict] = None,
+                 encoder: Optional[Dict] = None,
+                 decoder: Optional[Dict] = None,
+                 dictionary: Optional[Dict] = None,
+                 max_seq_len: int = 40,
+                 preprocess_cfg: Dict = None,
+                 init_cfg: Optional[Dict] = None) -> None:

-        super().__init__(init_cfg=init_cfg)
-
-        # Label convertor (str2tensor, tensor2str)
-        assert label_convertor is not None
-        label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = MODELS.build(label_convertor)
+        super().__init__(init_cfg=init_cfg, preprocess_cfg=preprocess_cfg)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
             self.preprocessor = MODELS.build(preprocessor)

         # Backbone
-        assert backbone is not None
-        self.backbone = MODELS.build(backbone)
+        if backbone is not None:
+            self.backbone = MODELS.build(backbone)

         # Encoder module
         self.encoder = None
         if encoder is not None:
             self.encoder = MODELS.build(encoder)

+        # Dictionary
+        if dictionary is not None:
+            self.dictionary = TASK_UTILS.build(dictionary)
+
         # Decoder module
         assert decoder is not None
-        decoder.update(num_classes=self.label_convertor.num_classes())
-        decoder.update(start_idx=self.label_convertor.start_idx)
-        decoder.update(padding_idx=self.label_convertor.padding_idx)
-        decoder.update(max_seq_len=max_seq_len)
+        if self.with_dictionary:
+            if decoder.get('dictionary', None) is None:
+                decoder.update(dictionary=self.dictionary)
+            else:
+                warnings.warn(f"Using dictionary {decoder['dictionary']} "
+                              "in decoder's config.")
+        if decoder.get('max_seq_len', None) is None:
+            decoder.update(max_seq_len=max_seq_len)
+        else:
+            warnings.warn(f"Using max_seq_len {decoder['max_seq_len']} "
+                          "in decoder's config.")
         self.decoder = MODELS.build(decoder)

-        # Loss
-        assert loss is not None
-        loss.update(ignore_index=self.label_convertor.padding_idx)
-        self.loss = MODELS.build(loss)
-
-        self.train_cfg = train_cfg
-        self.test_cfg = test_cfg
-        self.max_seq_len = max_seq_len
-
-        if pretrained is not None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated \
-                key, please consider using init_cfg')
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
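The dictionary and max_seq_len handling implements a "child config wins, but warn" rule: a value set explicitly in the decoder config overrides the recognizer-level one, and the warning makes the override visible. A minimal sketch of the same pattern, detached from mmocr (the resolve helper is hypothetical):

import warnings


def resolve(decoder_cfg: dict, parent_max_seq_len: int = 40) -> dict:
    """Fill max_seq_len from the parent unless the child already set it."""
    if decoder_cfg.get('max_seq_len', None) is None:
        decoder_cfg.update(max_seq_len=parent_max_seq_len)
    else:
        # Explicit child value wins; warn so the override is visible.
        warnings.warn(f"Using max_seq_len {decoder_cfg['max_seq_len']} "
                      "in decoder's config.")
    return decoder_cfg


print(resolve({'type': 'SomeDecoder'}))                     # inherits 40
print(resolve({'type': 'SomeDecoder', 'max_seq_len': 25}))  # keeps 25, warns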
-    def extract_feat(self, img):
+    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
         """Directly extract features from the backbone."""
-        if self.preprocessor is not None:
-            img = self.preprocessor(img)
-
-        x = self.backbone(img)
-
-        return x
+        if self.with_preprocessor:
+            inputs = self.preprocessor(inputs)
+        if self.with_backbone:
+            inputs = self.backbone(inputs)
+        return inputs
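Under the new `with_*` guards, extract_feat degrades gracefully when components are absent: each stage applies only if its submodule exists, and the input passes through untouched otherwise. A sketch of the same control flow with stand-in modules (hypothetical, not mmocr classes):

import torch
import torch.nn as nn

preprocessor = nn.Identity()               # stand-in for, e.g., a TPS module
backbone = nn.Conv2d(3, 8, kernel_size=3)  # stand-in feature extractor

inputs = torch.randn(2, 3, 32, 100)

# Mirrors the refactored control flow: each stage runs only if present.
x = inputs
if preprocessor is not None:   # cf. self.with_preprocessor
    x = preprocessor(x)
if backbone is not None:       # cf. self.with_backbone
    x = backbone(x)
print(x.shape)                 # torch.Size([2, 8, 30, 98])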
-    def forward_train(self, img, img_metas):
+    def forward_train(self, inputs: torch.Tensor,
+                      data_samples: Sequence[TextRecogDataSample],
+                      **kwargs) -> Dict:
         """
         Args:
-            img (tensor): Input images of shape (N, C, H, W).
-                Typically these should be mean centered and std scaled.
-            img_metas (list[dict]): A list of image info dict where each dict
-                contains: 'img_shape', 'filename', and may also contain
-                'ori_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                :class:`mmdet.datasets.pipelines.Collect`.
+            inputs (tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            data_samples (list[TextRecogDataSample]): A list of N
+                datasamples, containing meta information and gold
+                annotations for each of the images.

         Returns:
             dict[str, tensor]: A dictionary of loss components.
         """
-        for img_meta in img_metas:
-            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
-            img_meta['valid_ratio'] = valid_ratio
-
-        feat = self.extract_feat(img)
-
-        gt_labels = [img_meta['text'] for img_meta in img_metas]
-
-        targets_dict = self.label_convertor.str2tensor(gt_labels)
+        # TODO move to preprocess to update valid ratio
+        super().forward_train(inputs, data_samples, **kwargs)
+        for data_sample in data_samples:
+            valid_ratio = data_sample.valid_ratio * data_sample.img_shape[
+                1] / data_sample.batch_input_shape[1]
+            data_sample.set_metainfo(dict(valid_ratio=valid_ratio))

+        feat = self.extract_feat(inputs)
         out_enc = None
-        if self.encoder is not None:
-            out_enc = self.encoder(feat, img_metas)
-
-        out_dec = self.decoder(
-            feat, out_enc, targets_dict, img_metas, train_mode=True)
-
-        loss_inputs = (
-            out_dec,
-            targets_dict,
-            img_metas,
-        )
-        losses = self.loss(*loss_inputs)
+        if self.with_encoder:
+            out_enc = self.encoder(feat, data_samples)
+        data_samples = self.decoder.loss.get_target(data_samples)
+        out_dec = self.decoder(feat, out_enc, data_samples, train_mode=True)
+        losses = self.decoder.loss(out_dec, data_samples)

         return losses
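The updated formula rescales each sample's valid_ratio from its own resized width (img_shape[1]) to the padded batch width (batch_input_shape[1]), replacing the old resize_shape-based computation. The arithmetic in isolation, with made-up shapes:

# One sample, after resizing but before batch padding:
valid_ratio = 1.0              # the full resized image is valid content
img_shape = (32, 80)           # (H, W) of this sample after resize
batch_input_shape = (32, 100)  # (H, W) after padding to the widest sample

# Fraction of the padded width that holds real content:
new_valid_ratio = valid_ratio * img_shape[1] / batch_input_shape[1]
print(new_valid_ratio)         # 0.8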
-    def simple_test(self, img, img_metas, **kwargs):
-        """Test function with test time augmentation.
+    def simple_test(self, inputs: torch.Tensor,
+                    data_samples: Sequence[TextRecogDataSample],
+                    **kwargs) -> Sequence[TextRecogDataSample]:
+        """Test function without test-time augmentation.

         Args:
-            imgs (torch.Tensor): Image input tensor.
-            img_metas (list[dict]): List of image information.
+            inputs (torch.Tensor): Image input tensor.
+            data_samples (list[TextRecogDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.

         Returns:
-            list[str]: Text label result of each image.
+            list[TextRecogDataSample]: A list of N datasamples of prediction
+                results. Results are stored in ``pred_text``.
         """
-        for img_meta in img_metas:
-            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
-            img_meta['valid_ratio'] = valid_ratio
-
-        feat = self.extract_feat(img)
+        # TODO move to preprocess to update valid ratio
+        for data_sample in data_samples:
+            valid_ratio = data_sample.valid_ratio * data_sample.img_shape[
+                1] / data_sample.batch_input_shape[1]
+            data_sample.set_metainfo(dict(valid_ratio=valid_ratio))
+        feat = self.extract_feat(inputs)

         out_enc = None
-        if self.encoder is not None:
-            out_enc = self.encoder(feat, img_metas)
-
-        out_dec = self.decoder(
-            feat, out_enc, None, img_metas, train_mode=False)
-
-        # early return to avoid post processing
-        if torch.onnx.is_in_onnx_export():
-            return out_dec
-
-        label_indexes, label_scores = self.label_convertor.tensor2idx(
-            out_dec, img_metas)
-        label_strings = self.label_convertor.idx2str(label_indexes)
-
-        # flatten batch results
-        results = []
-        for string, score in zip(label_strings, label_scores):
-            results.append(dict(text=string, score=score))
-
-        return results
-
-    def merge_aug_results(self, aug_results):
-        out_text, out_score = '', -1
-        for result in aug_results:
-            text = result[0]['text']
-            score = sum(result[0]['score']) / max(1, len(text))
-            if score > out_score:
-                out_text = text
-                out_score = score
-        out_results = [dict(text=out_text, score=out_score)]
-        return out_results
-
-    def aug_test(self, imgs, img_metas, **kwargs):
-        """Test function as well as time augmentation.
-
-        Args:
-            imgs (list[tensor]): Tensor should have shape NxCxHxW,
-                which contains all images in the batch.
-            img_metas (list[list[dict]]): The metadata of images.
-        """
-        aug_results = []
-        for img, img_meta in zip(imgs, img_metas):
-            result = self.simple_test(img, img_meta, **kwargs)
-            aug_results.append(result)
-
-        return self.merge_aug_results(aug_results)
+        if self.with_encoder:
+            out_enc = self.encoder(feat, data_samples)
+        out_dec = self.decoder(feat, out_enc, data_samples, train_mode=False)
+        data_samples = self.decoder.postprocessor(out_dec, data_samples)
+        return data_samples
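Taken together, the refactored simple_test is a straight pipeline: features, then an optional encoder, then the decoder, with the postprocessor writing predictions back into the data samples instead of returning raw strings. A schematic mirror of that control flow (stand-in callables; not the mmocr API):

from typing import Callable, Optional, Sequence


def simple_test_flow(feat_extractor: Callable,
                     encoder: Optional[Callable],
                     decoder: Callable,
                     postprocessor: Callable,
                     inputs,
                     data_samples: Sequence):
    """Schematic mirror of the refactored inference path."""
    feat = feat_extractor(inputs)
    # The encoder is optional; the decoder accepts out_enc=None.
    out_enc = encoder(feat, data_samples) if encoder is not None else None
    out_dec = decoder(feat, out_enc, data_samples)
    # Predictions are stored on the samples, not returned as raw strings.
    return postprocessor(out_dec, data_samples)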