# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.decoders.base import BaseDecoder
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample

@MODELS.register_module()
class ABCNetRecDecoder(BaseDecoder):
    """Decoder for ABCNet.

    Args:
        in_channels (int): Number of input channels. Defaults to 256.
        dropout_prob (float): Probability of dropout. Defaults to 0.5.
        teach_prob (float): Probability of teacher forcing. Defaults to 0.5.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Max sequence length. Defaults to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Xavier', layer='Conv2d')``.
    """
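
    # A typical configuration for this decoder inside an MMOCR model config
    # might look like the sketch below. The ``dict_file`` path is
    # hypothetical; point it at the character dictionary of your dataset.
    #
    #   decoder=dict(
    #       type='ABCNetRecDecoder',
    #       dictionary=dict(
    #           type='Dictionary',
    #           dict_file='dicts/abcnet.txt',
    #           with_start=True,
    #           with_end=True,
    #           with_padding=True,
    #           with_unknown=True),
    #       max_seq_len=30)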

    def __init__(self,
                 in_channels: int = 256,
                 dropout_prob: float = 0.5,
                 teach_prob: float = 0.5,
                 dictionary: Union[Dictionary, Dict] = None,
                 module_loss: Dict = None,
                 postprocessor: Dict = None,
                 max_seq_len: int = 30,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(
            init_cfg=init_cfg,
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len)
        self.in_channels = in_channels
        self.teach_prob = teach_prob
        self.embedding = nn.Embedding(self.dictionary.num_classes, in_channels)
        self.attn_combine = nn.Linear(in_channels * 2, in_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.gru = nn.GRU(in_channels, in_channels)
        self.out = nn.Linear(in_channels, self.dictionary.num_classes)
        self.vat = nn.Linear(in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: torch.Tensor,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
            :math:`C` is ``num_classes``.
        """
        bs = out_enc.size()[1]
        trg_seq = []
        for target in data_samples:
            trg_seq.append(target.gt_text.padded_indexes.to(feat.device))
        decoder_input = torch.zeros(bs).long().to(out_enc.device)
        trg_seq = torch.stack(trg_seq, dim=0)
        decoder_hidden = torch.zeros(1, bs,
                                     self.in_channels).to(out_enc.device)
        decoder_outputs = []
        for index in range(trg_seq.shape[1]):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            teach_forcing = random.random() > self.teach_prob
            if teach_forcing:
                decoder_input = trg_seq[:, index]  # Teacher forcing
            else:
                _, topi = decoder_output.data.topk(1)
                decoder_input = topi.squeeze()
            decoder_outputs.append(decoder_output)

        return torch.stack(decoder_outputs, dim=1)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): A Tensor of shape :math:`(N, C, 1, W)`.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing ``gt_text`` information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        bs = out_enc.size()[1]
        outputs = []
        decoder_input = torch.zeros(bs).long().to(out_enc.device)
        decoder_hidden = torch.zeros(1, bs,
                                     self.in_channels).to(out_enc.device)
        for _ in range(self.max_seq_len):
            # decoder_output (nbatch, ncls)
            decoder_output, decoder_hidden = self._attention(
                decoder_input, decoder_hidden, out_enc)
            _, topi = decoder_output.data.topk(1)
            decoder_input = topi.squeeze()
            outputs.append(decoder_output)
        outputs = torch.stack(outputs, dim=1)
        return self.softmax(outputs)

    def _attention(self, input, hidden, encoder_outputs):
        """Single additive-attention decoding step.

        Scores each encoder time step with ``vat(tanh(hidden + out_enc))``,
        softmax-normalizes the scores over the time dimension, and feeds the
        attended context together with the embedded previous symbol through
        the GRU.

        Args:
            input (Tensor): Previous symbol indices, of shape :math:`(N, )`.
            hidden (Tensor): GRU hidden state, of shape :math:`(1, N, C)`.
            encoder_outputs (Tensor): Time-major encoder output, of shape
                :math:`(T, N, C)`.

        Returns:
            tuple(Tensor, Tensor): The output logits, of shape
            ``(N, num_classes)``, and the updated hidden state, of shape
            :math:`(1, N, C)`.
        """
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)

        batch_size = encoder_outputs.shape[1]

        alpha = hidden + encoder_outputs
        alpha = alpha.view(-1, alpha.shape[-1])  # (T * n, hidden_size)
        attn_weights = self.vat(torch.tanh(alpha))  # (T * n, 1)
        attn_weights = attn_weights.view(-1, 1, batch_size).permute(
            (2, 1, 0))  # (T, 1, n) -> (n, 1, T)
        attn_weights = F.softmax(attn_weights, dim=2)

        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs.permute((1, 0, 2)))

        if embedded.dim() == 1:
            embedded = embedded.unsqueeze(0)
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
        output = self.attn_combine(output).unsqueeze(0)  # (1, n, hidden_size)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)  # (1, n, hidden_size)
        output = self.out(output[0])
        return output, hidden
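

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a working MMOCR installation. It builds the
# decoder from a dictionary config and runs a greedy ``forward_test`` pass on
# a random, time-major encoder output. The ``dict_file`` path is hypothetical;
# point it at the character dictionary of your dataset before running.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    dictionary_cfg = dict(
        type='Dictionary',
        dict_file='dicts/abcnet.txt',  # hypothetical path
        with_start=True,
        with_end=True,
        with_padding=True,
        with_unknown=True)
    decoder = ABCNetRecDecoder(in_channels=256, dictionary=dictionary_cfg)
    # The decoder expects a time-major encoder output of shape (W, N, C),
    # where W is the number of time steps and C equals ``in_channels``.
    out_enc = torch.rand(32, 2, 256)
    probs = decoder.forward_test(out_enc=out_enc)
    print(probs.shape)  # (2, max_seq_len, num_classes)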