From 206c4ccc65f5e3620e21c493ecd88ab2829746c7 Mon Sep 17 00:00:00 2001
From: liukuikun <641417025@qq.com>
Date: Thu, 26 May 2022 14:32:06 +0000
Subject: [PATCH] [Refactor] encoder_decoder_recognizer

---
 .dev_scripts/covignore.cfg                     |   1 +
 mmocr/models/textrecog/recognizers/base.py     |  46 ++--
 .../recognizers/encode_decode_recognizer.py    | 233 ++++++++----------
 3 files changed, 130 insertions(+), 150 deletions(-)

diff --git a/.dev_scripts/covignore.cfg b/.dev_scripts/covignore.cfg
index bafcba45..82368a9a 100644
--- a/.dev_scripts/covignore.cfg
+++ b/.dev_scripts/covignore.cfg
@@ -4,6 +4,7 @@
 # mmocr/models/textdet/postprocess/utils.py
 # .*/utils.py
 mmocr/models/textrecog/recognizers/base.py
+mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
 .*/__init__.py
 # It will be removed after all transforms have been refactored into processing.py
 mmocr/datasets/pipelines/transforms.py

diff --git a/mmocr/models/textrecog/recognizers/base.py b/mmocr/models/textrecog/recognizers/base.py
index 1e1f24b0..4bfbe3e8 100644
--- a/mmocr/models/textrecog/recognizers/base.py
+++ b/mmocr/models/textrecog/recognizers/base.py
@@ -56,6 +56,31 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
     def device(self) -> torch.device:
         return self.pixel_mean.device
 
+    @property
+    def with_backbone(self):
+        """bool: whether the recognizer has a backbone"""
+        return getattr(self, 'backbone', None) is not None
+
+    @property
+    def with_encoder(self):
+        """bool: whether the recognizer has an encoder"""
+        return getattr(self, 'encoder', None) is not None
+
+    @property
+    def with_preprocessor(self):
+        """bool: whether the recognizer has a preprocessor"""
+        return getattr(self, 'preprocessor', None) is not None
+
+    @property
+    def with_dictionary(self):
+        """bool: whether the recognizer has a dictionary"""
+        return getattr(self, 'dictionary', None) is not None
+
+    @property
+    def with_decoder(self):
+        """bool: whether the recognizer has a decoder"""
+        return getattr(self, 'decoder', None) is not None
+
     @abstractmethod
     def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
         """Extract features from images."""
@@ -77,8 +102,8 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
         # NOTE the batched image size information may be useful for
         # calculating valid ratio.
         batch_input_shape = tuple(inputs[0].size()[-2:])
-        for data_samples in data_samples:
-            data_samples.set_metainfo({'batch_input_shape': batch_input_shape})
+        for data_sample in data_samples:
+            data_sample.set_metainfo({'batch_input_shape': batch_input_shape})
 
     @abstractmethod
     def simple_test(self, inputs: torch.Tensor,
@@ -86,7 +111,6 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
                     **kwargs) -> Sequence[TextRecogDataSample]:
         pass
 
-    @abstractmethod
     def aug_test(self, imgs: torch.Tensor,
                  data_samples: Sequence[Sequence[TextRecogDataSample]],
                  **kwargs):
@@ -230,16 +254,8 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):
             ``pred_text``.
         """
         # TODO: Consider merging with forward_train logic
-        batch_size = len(data_samples)
-        batch_img_metas = []
-        for batch_index in range(batch_size):
-            metainfo = data_samples[batch_index].metainfo
+        batch_input_shape = tuple(inputs[0].size()[-2:])
+        for data_sample in data_samples:
+            data_sample.set_metainfo({'batch_input_shape': batch_input_shape})
 
-            # TODO: maybe remove to stack_batch
-            # NOTE the batched image size information may be useful for
-            # calculating valid ratio.
-            metainfo['batch_input_shape'] = \
-                tuple(inputs.size()[-2:])
-            batch_img_metas.append(metainfo)
-
-        return self.simple_test(inputs, batch_img_metas, **kwargs)
+        return self.simple_test(inputs, data_samples, **kwargs)

diff --git a/mmocr/models/textrecog/recognizers/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
index 52614444..be20b858 100644
--- a/mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
+++ b/mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
@@ -1,181 +1,144 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+
 import warnings
+from typing import Dict, Optional, Sequence
 
 import torch
 
-from mmocr.registry import MODELS
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.registry import MODELS, TASK_UTILS
 from .base import BaseRecognizer
 
 
 @MODELS.register_module()
 class EncodeDecodeRecognizer(BaseRecognizer):
-    """Base class for encode-decode recognizer."""
+    """Base class for encode-decode recognizer.
+
+    Args:
+        backbone (dict, optional): Backbone config. Defaults to None.
+        encoder (dict, optional): Encoder config. If None, the output from
+            backbone will be directly fed into ``decoder``. Defaults to None.
+        decoder (dict, optional): Decoder config. Defaults to None.
+        dictionary (dict, optional): Dictionary config. Defaults to None.
+        max_seq_len (int): Maximum sequence length. Defaults to 40.
+        preprocess_cfg (dict, optional): Model preprocessing config
+            for processing the input image data. Keys allowed are
+            ``to_rgb`` (bool), ``pad_size_divisor`` (int), ``pad_value``
+            (int or float), ``mean`` (int or float) and ``std`` (int or
+            float). Preprocessing order: 1. to rgb; 2. normalization;
+            3. pad. Defaults to None.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
+ """ def __init__(self, - preprocessor=None, - backbone=None, - encoder=None, - decoder=None, - loss=None, - label_convertor=None, - train_cfg=None, - test_cfg=None, - max_seq_len=40, - pretrained=None, - init_cfg=None): + preprocessor: Optional[Dict] = None, + backbone: Optional[Dict] = None, + encoder: Optional[Dict] = None, + decoder: Optional[Dict] = None, + dictionary: Optional[Dict] = None, + max_seq_len: int = 40, + preprocess_cfg: Dict = None, + init_cfg: Optional[Dict] = None) -> None: - super().__init__(init_cfg=init_cfg) - - # Label convertor (str2tensor, tensor2str) - assert label_convertor is not None - label_convertor.update(max_seq_len=max_seq_len) - self.label_convertor = MODELS.build(label_convertor) + super().__init__(init_cfg=init_cfg, preprocess_cfg=preprocess_cfg) # Preprocessor module, e.g., TPS - self.preprocessor = None if preprocessor is not None: self.preprocessor = MODELS.build(preprocessor) # Backbone - assert backbone is not None - self.backbone = MODELS.build(backbone) + if backbone is not None: + self.backbone = MODELS.build(backbone) # Encoder module - self.encoder = None if encoder is not None: self.encoder = MODELS.build(encoder) + # Dictionary + if dictionary is not None: + self.dictionary = TASK_UTILS.build(dictionary) # Decoder module assert decoder is not None - decoder.update(num_classes=self.label_convertor.num_classes()) - decoder.update(start_idx=self.label_convertor.start_idx) - decoder.update(padding_idx=self.label_convertor.padding_idx) - decoder.update(max_seq_len=max_seq_len) + + if self.with_dictionary: + if decoder.get('dictionary', None) is None: + decoder.update(dictionary=self.dictionary) + else: + warnings.warn(f"Using dictionary {decoder['dictionary']} " + "in decoder's config.") + if decoder.get('max_seq_len', None) is None: + decoder.update(max_seq_len=max_seq_len) + else: + warnings.warn(f"Using max_seq_len {decoder['max_seq_len']} " + "in decoder's config.") self.decoder = MODELS.build(decoder) - # Loss - assert loss is not None - loss.update(ignore_index=self.label_convertor.padding_idx) - self.loss = MODELS.build(loss) - - self.train_cfg = train_cfg - self.test_cfg = test_cfg - self.max_seq_len = max_seq_len - - if pretrained is not None: - warnings.warn('DeprecationWarning: pretrained is a deprecated \ - key, please consider using init_cfg') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) - - def extract_feat(self, img): + def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor: """Directly extract features from the backbone.""" - if self.preprocessor is not None: - img = self.preprocessor(img) + if self.with_preprocessor: + inputs = self.preprocessor(inputs) + if self.with_backbone: + inputs = self.backbone(inputs) + return inputs - x = self.backbone(img) - - return x - - def forward_train(self, img, img_metas): + def forward_train(self, inputs: torch.Tensor, + data_samples: Sequence[TextRecogDataSample], + **kwargs) -> Dict: """ - Args: - img (tensor): Input images of shape (N, C, H, W). - Typically these should be mean centered and std scaled. - img_metas (list[dict]): A list of image info dict where each dict - contains: 'img_shape', 'filename', and may also contain - 'ori_shape', and 'img_norm_cfg'. - For details on the values of these keys see - :class:`mmdet.datasets.pipelines.Collect`. + Args: + inputs (tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. 
+            data_samples (list[TextRecogDataSample]): A list of N
+                datasamples, containing meta information and gold
+                annotations for each of the images.
 
-        Returns:
-            dict[str, tensor]: A dictionary of loss components.
+        Returns:
+            dict[str, tensor]: A dictionary of loss components.
         """
-        for img_meta in img_metas:
-            valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
-            img_meta['valid_ratio'] = valid_ratio
-
-        feat = self.extract_feat(img)
-
-        gt_labels = [img_meta['text'] for img_meta in img_metas]
-
-        targets_dict = self.label_convertor.str2tensor(gt_labels)
+        # TODO move to preprocess to update valid ratio
+        super().forward_train(inputs, data_samples, **kwargs)
+        for data_sample in data_samples:
+            valid_ratio = data_sample.valid_ratio * data_sample.img_shape[
+                1] / data_sample.batch_input_shape[1]
+            data_sample.set_metainfo(dict(valid_ratio=valid_ratio))
+        feat = self.extract_feat(inputs)
         out_enc = None
-        if self.encoder is not None:
-            out_enc = self.encoder(feat, img_metas)
+        if self.with_encoder:
+            out_enc = self.encoder(feat, data_samples)
+        data_samples = self.decoder.loss.get_target(data_samples)
+        out_dec = self.decoder(feat, out_enc, data_samples, train_mode=True)
 
-        out_dec = self.decoder(
-            feat, out_enc, targets_dict, img_metas, train_mode=True)
-
-        loss_inputs = (
-            out_dec,
-            targets_dict,
-            img_metas,
-        )
-        losses = self.loss(*loss_inputs)
+        losses = self.decoder.loss(out_dec, data_samples)
 
         return losses
 
-    def simple_test(self, img, img_metas, **kwargs):
-        """Test function with test time augmentation.
+    def simple_test(self, inputs: torch.Tensor,
+                    data_samples: Sequence[TextRecogDataSample],
+                    **kwargs) -> Sequence[TextRecogDataSample]:
+        """Test function without test-time augmentation.
 
         Args:
-            imgs (torch.Tensor): Image input tensor.
-            img_metas (list[dict]): List of image information.
+            inputs (torch.Tensor): Image input tensor.
+            data_samples (list[TextRecogDataSample]): A list of N datasamples,
+                containing meta information and gold annotations for each of
+                the images.
 
         Returns:
-            list[str]: Text label result of each image.
+            list[TextRecogDataSample]: A list of N datasamples of prediction
+                results. Results are stored in ``pred_text``.
""" - for img_meta in img_metas: - valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1) - img_meta['valid_ratio'] = valid_ratio - - feat = self.extract_feat(img) + # TODO move to preprocess to update valid ratio + for data_sample in data_samples: + valid_ratio = data_sample.valid_ratio * data_sample.img_shape[ + 1] / data_sample.batch_input_shape[1] + data_sample.set_metainfo(dict(valid_ratio=valid_ratio)) + feat = self.extract_feat(inputs) out_enc = None - if self.encoder is not None: - out_enc = self.encoder(feat, img_metas) - - out_dec = self.decoder( - feat, out_enc, None, img_metas, train_mode=False) - - # early return to avoid post processing - if torch.onnx.is_in_onnx_export(): - return out_dec - - label_indexes, label_scores = self.label_convertor.tensor2idx( - out_dec, img_metas) - label_strings = self.label_convertor.idx2str(label_indexes) - - # flatten batch results - results = [] - for string, score in zip(label_strings, label_scores): - results.append(dict(text=string, score=score)) - - return results - - def merge_aug_results(self, aug_results): - out_text, out_score = '', -1 - for result in aug_results: - text = result[0]['text'] - score = sum(result[0]['score']) / max(1, len(text)) - if score > out_score: - out_text = text - out_score = score - out_results = [dict(text=out_text, score=out_score)] - return out_results - - def aug_test(self, imgs, img_metas, **kwargs): - """Test function as well as time augmentation. - - Args: - imgs (list[tensor]): Tensor should have shape NxCxHxW, - which contains all images in the batch. - img_metas (list[list[dict]]): The metadata of images. - """ - aug_results = [] - for img, img_meta in zip(imgs, img_metas): - result = self.simple_test(img, img_meta, **kwargs) - aug_results.append(result) - - return self.merge_aug_results(aug_results) + if self.with_encoder: + out_enc = self.encoder(feat, data_samples) + out_dec = self.decoder(feat, out_enc, data_samples, train_mode=False) + data_samples = self.decoder.postprocessor(out_dec, data_samples) + return data_samples