# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import (InitConfigType, OptConfigType,
                                      OptRecSampleList, RecForwardResults,
                                      RecSampleList)
from .base_text_spotter import BaseTextSpotter


@MODELS.register_module()
class EncoderDecoderTextSpotter(BaseTextSpotter):
    """Base class for encoder-decoder text spotters.

    Args:
        preprocessor (dict, optional): Config dict for preprocessor. Defaults
            to None.
        backbone (dict, optional): Backbone config. Defaults to None.
        encoder (dict, optional): Encoder config. If None, the output from
            the backbone will be directly fed into ``decoder``. Defaults to
            None.
        decoder (dict): Decoder config. Must not be None, since the decoder
            is the only required component.
        data_preprocessor (dict, optional): Model preprocessing config for
            processing the input image data. Keys allowed are ``to_rgb``
            (bool), ``pad_size_divisor`` (int), ``pad_value`` (int or float),
            ``mean`` (int or float) and ``std`` (int or float). Preprocessing
            order: 1. to RGB; 2. normalization; 3. padding. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """
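
    # A minimal sketch of how this class is typically instantiated through
    # the registry. The component ``type`` names below are hypothetical
    # placeholders for registered modules, not names guaranteed to exist in
    # MMOCR:
    #
    #     cfg = dict(
    #         type='EncoderDecoderTextSpotter',
    #         backbone=dict(type='SomeBackbone'),
    #         encoder=dict(type='SomeEncoder'),
    #         decoder=dict(type='SomeDecoder'),
    #         data_preprocessor=dict(mean=[127.5], std=[127.5]))
    #     model = MODELS.build(cfg)
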
    def __init__(self,
                 preprocessor: OptConfigType = None,
                 backbone: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: InitConfigType = None) -> None:
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Preprocessor module, e.g., TPS
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        # Encoder module
        if encoder is not None:
            self.encoder = MODELS.build(encoder)

        # Decoder module: the only required component
        assert decoder is not None
        self.decoder = MODELS.build(decoder)

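    # Note: ``MODELS.build(cfg)`` looks up the class registered under
    # ``cfg['type']`` and instantiates it with the remaining keys of ``cfg``
    # as keyword arguments, so every component above can be swapped purely
    # through config.
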
    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Extract features from the input images, applying the optional
        preprocessor and backbone when they are configured."""
        if self.with_preprocessor:
            inputs = self.preprocessor(inputs)
        if self.with_backbone:
            inputs = self.backbone(inputs)
        return inputs

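    # ``with_preprocessor``, ``with_backbone`` and ``with_encoder`` used in
    # the methods here are presumed to be convenience properties defined on
    # ``BaseTextSpotter`` that check whether the corresponding submodule has
    # been built (e.g. via ``hasattr``); they are not defined in this file.
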
    def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
             **kwargs) -> Dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            data_samples (list[TextRecogDataSample]): A list of N data
                samples, containing meta information and gold annotations
                for each of the images.

        Returns:
            dict[str, torch.Tensor]: A dictionary of loss components.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder.loss(feat, out_enc, data_samples)

    def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
                **kwargs) -> RecSampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            inputs (torch.Tensor): Image input tensor.
            data_samples (list[TextRecogDataSample]): A list of N data
                samples, containing meta information and gold annotations
                for each of the images.

        Returns:
            list[TextRecogDataSample]: A list of N data samples of prediction
            results. Results are stored in ``pred_text``.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder.predict(feat, out_enc, data_samples)

    def _forward(self,
                 inputs: torch.Tensor,
                 data_samples: OptRecSampleList = None,
                 **kwargs) -> RecForwardResults:
        """Network forward process. Usually includes backbone, encoder and
        decoder forward without any post-processing.

        Args:
            inputs (torch.Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[TextRecogDataSample], optional): A list of N
                data samples, containing meta information and gold
                annotations for each of the images. Defaults to None.

        Returns:
            Tensor or tuple[Tensor]: Raw features from ``decoder`` forward,
            without any post-processing.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder(feat, out_enc, data_samples)
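

# A minimal usage sketch of the three entry points above: ``loss`` for
# training, ``predict`` for inference with post-processing, and ``_forward``
# for raw tensor outputs. The config, variables, and shapes here are
# illustrative assumptions, not part of this module:
#
#     import torch
#     from mmocr.registry import MODELS
#
#     model = MODELS.build(cfg)                  # cfg as sketched above
#     imgs = torch.randn(2, 3, 32, 128)          # (N, C, H, W) batch
#     losses = model.loss(imgs, data_samples)    # dict of loss tensors
#     preds = model.predict(imgs, data_samples)  # list of data samples
#     raw = model._forward(imgs, data_samples)   # raw decoder outputs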