# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class RobustScannerFuser(BaseDecoder):
"""Decoder for RobustScanner.
|
|
|
|
Args:
|
|
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
|
|
the instance of `Dictionary`.
|
|
module_loss (dict, optional): Config to build module_loss. Defaults
|
|
to None.
|
|
postprocessor (dict, optional): Config to build postprocessor.
|
|
Defaults to None.
|
|
hybrid_decoder (dict): Config to build hybrid_decoder. Defaults to
|
|
dict(type='SequenceAttentionDecoder').
|
|
position_decoder (dict): Config to build position_decoder. Defaults to
|
|
dict(type='PositionAttentionDecoder').
|
|
fuser (dict): Config to build fuser. Defaults to
|
|
dict(type='RobustScannerFuser').
|
|
max_seq_len (int): Maximum sequence length. The
|
|
sequence is usually generated from decoder. Defaults to 30.
|
|
in_channels (list[int]): List of input channels.
|
|
Defaults to [512, 512].
|
|
dim (int): The dimension on which to split the input. Defaults to -1.
|
|
init_cfg (dict or list[dict], optional): Initialization configs.
|
|
Defaults to None.
|
|
"""
|
|
|
|
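    # A minimal build sketch (the ``dict_file`` path is hypothetical; the
    # component types are the registered mmocr defaults this class assumes):
    #
    # >>> from mmocr.registry import MODELS
    # >>> decoder = MODELS.build(
    # ...     dict(
    # ...         type='RobustScannerFuser',
    # ...         dictionary=dict(
    # ...             type='Dictionary',
    # ...             dict_file='dicts/english_digits_symbols.txt',
    # ...             with_start=True, with_end=True,
    # ...             with_padding=True, with_unknown=True),
    # ...         hybrid_decoder=dict(type='SequenceAttentionDecoder'),
    # ...         position_decoder=dict(type='PositionAttentionDecoder')))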
    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 hybrid_decoder: Dict = dict(type='SequenceAttentionDecoder'),
                 position_decoder: Dict = dict(
                     type='PositionAttentionDecoder'),
                 max_seq_len: int = 30,
                 in_channels: List[int] = [512, 512],
                 dim: int = -1,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)
        # Propagate the shared dictionary and max_seq_len into each
        # sub-decoder config unless they are explicitly overridden there.
        for cfg_name, cfg in [('hybrid_decoder', hybrid_decoder),
                              ('position_decoder', position_decoder)]:
            if cfg is not None:
                if cfg.get('dictionary', None) is None:
                    cfg.update(dictionary=self.dictionary)
                else:
                    warnings.warn(f"Using dictionary {cfg['dictionary']} "
                                  "in decoder's config.")
                if cfg.get('max_seq_len', None) is None:
                    cfg.update(max_seq_len=max_seq_len)
                else:
                    warnings.warn(f"Using max_seq_len {cfg['max_seq_len']} "
                                  "in decoder's config.")
            setattr(self, cfg_name, MODELS.build(cfg))
        in_channels = sum(in_channels)
        self.dim = dim

        self.linear_layer = nn.Linear(in_channels, in_channels)
        # The GLU halves the fused feature along ``dim``, so the prediction
        # head takes ``in_channels // 2`` features.
        self.glu_layer = nn.GLU(dim=dim)
        self.prediction = nn.Linear(in_channels // 2,
                                    self.dictionary.num_classes)
        self.softmax = nn.Softmax(dim=-1)
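    # Fusion shape sketch (assuming the defaults in_channels=[512, 512] and
    # dim=-1, with both glimpses shaped (N, T, 512)):
    #   cat([hybrid, position], dim=-1) -> (N, T, 1024)
    #   linear_layer                    -> (N, T, 1024)
    #   glu_layer (halves dim -1)       -> (N, T, 512)
    #   prediction                      -> (N, T, num_classes)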
    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: Raw logits of shape :math:`(N, T, C)`, where :math:`T` is
            the decode sequence length and :math:`C` is ``num_classes``.
        """
        hybrid_glimpse = self.hybrid_decoder(feat, out_enc, data_samples)
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)
        fusion_input = torch.cat([hybrid_glimpse, position_glimpse], self.dim)
        outputs = self.linear_layer(fusion_input)
        outputs = self.glu_layer(outputs)
        return self.prediction(outputs)
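    # Note: training fuses the full glimpse sequences in parallel, while
    # ``forward_test`` below decodes greedily step by step, writing each
    # step's argmax back into ``decode_sequence`` for the hybrid branch.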
    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing valid_ratio information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities of shape :math:`(N, T, C)`, where
            :math:`T` is ``max_seq_len`` and :math:`C` is ``num_classes``.
        """
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)

        # Initialize the decode sequence with the start token; each step's
        # greedy prediction is fed back as input for the next step.
        batch_size = feat.size(0)
        decode_sequence = (feat.new_ones((batch_size, self.max_seq_len)) *
                           self.dictionary.start_idx).long()
        outputs = []
        for step in range(self.max_seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, step, data_samples)

            fusion_input = torch.cat(
                [hybrid_glimpse_step, position_glimpse[:, step, :]], self.dim)
            output = self.linear_layer(fusion_input)
            output = self.glu_layer(output)
            output = self.prediction(output)
            _, max_idx = torch.max(output, dim=1, keepdim=False)
            if step < self.max_seq_len - 1:
                decode_sequence[:, step + 1] = max_idx
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        return self.softmax(outputs)
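# Inference sketch (hypothetical tensors; shapes must match the configured
# sub-decoders, and ``data_samples`` should carry valid_ratio):
#
# >>> feat = torch.randn(2, 512, 8, 32)   # (N, E, H, W) from the backbone
# >>> out_enc = encoder(feat)             # e.g. a ChannelReductionEncoder
# >>> probs = decoder.forward_test(feat, out_enc, data_samples)
# >>> probs.shape                         # (2, 30, num_classes)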