mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
* [Feature] Add RepeatAugSampler * initial commit * spts inference done * merge repeat_aug (bug in multi-node?) * fix inference * train done * rm readme * Revert "merge repeat_aug (bug in multi-node?)" This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3. * Revert "[Feature] Add RepeatAugSampler" This reverts commit 2089b02b4844157670033766f257b5d1bca452ce. * remove utils * readme & conversion script * update readme * fix * optimize * rename cfg & del compose * fix * fix
131 lines
5.0 KiB
Python
131 lines
5.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from typing import Dict
|
|
|
|
import torch
|
|
|
|
from mmocr.registry import MODELS
|
|
from mmocr.utils.typing_utils import (ConfigType, InitConfigType,
|
|
OptConfigType, OptRecSampleList,
|
|
RecForwardResults, RecSampleList)
|
|
from .base_text_spotter import BaseTextSpotter
|
|
|
|
|
|
@MODELS.register_module()
class EncoderDecoderTextSpotter(BaseTextSpotter):
    """Base class for encoder-decoder text spotters.

    Inputs are passed through the optional ``preprocessor`` and ``backbone``
    and then through the optional ``encoder``. The ``decoder`` always
    receives both the extracted features and the encoder output (``None``
    when no encoder is configured).

    Args:
        preprocessor (dict, optional): Config dict for preprocessor. Defaults
            to None.
        backbone (dict, optional): Backbone config. Defaults to None.
        encoder (dict, optional): Encoder config. If None, the decoder only
            gets the extracted features and its encoder input is ``None``.
            Defaults to None.
        decoder (dict): Decoder config. Required: the signature default is
            None only to keep the argument order, and passing None fails an
            assertion.
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input image data. Keys allowed are
            ``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
            float), ``mean``(int or float) and ``std``(int or float).
            Preprocessing order: 1. to rgb; 2. normalization 3. pad.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 preprocessor: OptConfigType = None,
                 backbone: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: InitConfigType = None) -> None:

        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Optional modules are only attached when a config is given, so the
        # attribute is simply absent otherwise.
        # NOTE(review): the ``with_*`` flags used below are not defined in
        # this class; presumably the base class derives them from the
        # presence of these attributes - confirm.

        # Preprocessor module, e.g., TPS
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        # Encoder module
        if encoder is not None:
            self.encoder = MODELS.build(encoder)

        # Decoder module (mandatory)
        assert decoder is not None, (
            f'{type(self).__name__} requires a decoder config, but got None.')
        self.decoder = MODELS.build(decoder)

    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the optional preprocessor and backbone on the inputs.

        Args:
            inputs (torch.Tensor): Image input tensor.

        Returns:
            torch.Tensor: Output of the last enabled stage, or ``inputs``
            unchanged when neither stage is enabled.
        """
        if self.with_preprocessor:
            inputs = self.preprocessor(inputs)
        if self.with_backbone:
            inputs = self.backbone(inputs)
        return inputs

    def _extract_and_encode(self, inputs: torch.Tensor,
                            data_samples: OptRecSampleList = None) -> tuple:
        """Shared front half of ``loss``, ``predict`` and ``_forward``.

        Args:
            inputs (torch.Tensor): Image input tensor.
            data_samples (OptRecSampleList): Data samples forwarded to the
                encoder as its second argument. Defaults to None.

        Returns:
            tuple: ``(feat, out_enc)`` where ``feat`` is the result of
            :meth:`extract_feat` and ``out_enc`` is the encoder output, or
            ``None`` when no encoder is enabled.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return feat, out_enc

    def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
             **kwargs) -> Dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            data_samples (RecSampleList): A list of N datasamples,
                containing meta information and gold annotations for each
                of the images.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict[str, tensor]: A dictionary of loss components, as returned
            by ``decoder.loss``.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder.loss(feat, out_enc, data_samples)

    def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
                **kwargs) -> RecSampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (torch.Tensor): Image input tensor.
            data_samples (RecSampleList): A list of N datasamples,
                containing meta information and gold annotations for each of
                the images.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            RecSampleList: The value returned by ``decoder.predict``, i.e.
            the datasamples carrying the prediction results.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder.predict(feat, out_enc, data_samples)

    def _forward(self,
                 inputs: torch.Tensor,
                 data_samples: OptRecSampleList = None,
                 **kwargs) -> RecForwardResults:
        """Network forward process. Usually includes backbone, encoder and
        decoder forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (OptRecSampleList): A list of N datasamples,
                containing meta information and gold annotations for each of
                the images. Defaults to None.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            RecForwardResults: The raw output of the ``decoder`` forward
            call.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder(feat, out_enc, data_samples)
|