mmocr/projects/SPTS/spts/model/encoder_decoder_text_spotter.py
Tong Gao 2d743cfa19
[Model] SPTS (#1696)
* [Feature] Add RepeatAugSampler

* initial commit

* spts inference done

* merge repeat_aug (bug in multi-node?)

* fix inference

* train done

* rm readme

* Revert "merge repeat_aug (bug in multi-node?)"

This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3.

* Revert "[Feature] Add RepeatAugSampler"

This reverts commit 2089b02b4844157670033766f257b5d1bca452ce.

* remove utils

* readme & conversion script

* update readme

* fix

* optimize

* rename cfg & del compose

* fix

* fix
2023-02-01 11:58:03 +08:00

131 lines
5.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import (ConfigType, InitConfigType,
OptConfigType, OptRecSampleList,
RecForwardResults, RecSampleList)
from .base_text_spotter import BaseTextSpotter
@MODELS.register_module()
class EncoderDecoderTextSpotter(BaseTextSpotter):
    """Base class for encode-decode text spotter.

    The pipeline is: optional ``preprocessor`` -> optional ``backbone`` ->
    optional ``encoder`` -> mandatory ``decoder``. The decoder always
    receives the extracted features together with the encoder output (or
    ``None`` when no encoder is configured).

    Args:
        preprocessor (dict, optional): Config dict for preprocessor. Defaults
            to None.
        backbone (dict, optional): Backbone config. Defaults to None.
        encoder (dict, optional): Encoder config. If None, the output from
            backbone will be directly fed into ``decoder`` (with ``None``
            passed as the encoder output). Defaults to None.
        decoder (dict): Decoder config. Required: an ``AssertionError`` is
            raised if it is None.
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input image data. Keys allowed are
            ``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or
            float), ``mean``(int or float) and ``std``(int or float).
            Preprocessing order: 1. to rgb; 2. normalization 3. pad.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 preprocessor: OptConfigType = None,
                 backbone: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: InitConfigType = None) -> None:
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Preprocessor module, e.g., TPS
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        # Encoder module
        if encoder is not None:
            self.encoder = MODELS.build(encoder)

        # Decoder module. It is mandatory because every forward path
        # (loss / predict / _forward) delegates to it.
        assert decoder is not None, (
            f'{type(self).__name__} requires a decoder config, got None')
        self.decoder = MODELS.build(decoder)

    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Extract features by running the optional preprocessor and the
        optional backbone, in that order.

        Args:
            inputs (torch.Tensor): Input images of shape (N, C, H, W).

        Returns:
            torch.Tensor: Backbone features. If no backbone is configured,
            the (possibly preprocessed) inputs are returned unchanged.
        """
        if self.with_preprocessor:
            inputs = self.preprocessor(inputs)
        if self.with_backbone:
            inputs = self.backbone(inputs)
        return inputs

    def _extract_and_encode(self,
                            inputs: torch.Tensor,
                            data_samples: OptRecSampleList = None) -> tuple:
        """Shared front half of ``loss``, ``predict`` and ``_forward``.

        Args:
            inputs (torch.Tensor): Input images of shape (N, C, H, W).
            data_samples (list[TextRecogDataSample], optional): Data samples
                forwarded to the encoder, if one is configured.

        Returns:
            tuple: ``(feat, out_enc)`` where ``feat`` is the output of
            :meth:`extract_feat` and ``out_enc`` is the encoder output, or
            ``None`` when no encoder is configured.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return feat, out_enc

    def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
             **kwargs) -> Dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            data_samples (list[TextRecogDataSample]): A list of N
                datasamples, containing meta information and gold
                annotations for each of the images.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder.loss(feat, out_enc, data_samples)

    def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
                **kwargs) -> RecSampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (torch.Tensor): Image input tensor.
            data_samples (list[TextRecogDataSample]): A list of N datasamples,
                containing meta information and gold annotations for each of
                the images.

        Returns:
            list[TextRecogDataSample]: A list of N datasamples of prediction
            results, as returned by ``decoder.predict``.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder.predict(feat, out_enc, data_samples)

    def _forward(self,
                 inputs: torch.Tensor,
                 data_samples: OptRecSampleList = None,
                 **kwargs) -> RecForwardResults:
        """Network forward process. Usually includes backbone, encoder and
        decoder forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[TextRecogDataSample], optional): A list of N
                datasamples, containing meta information and gold
                annotations for each of the images. Defaults to None.

        Returns:
            RecForwardResults: The raw output of the ``decoder`` forward.
        """
        feat, out_enc = self._extract_and_encode(inputs, data_samples)
        return self.decoder(feat, out_enc, data_samples)