# Modified from https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/ppocr/modeling/heads
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.models.builder import HEADS
from ..necks.squence_encoder import Im2Seq, SequenceEncoder


class SAREncoder(nn.Module):
    """Holistic feature encoder of the SAR text recognizer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM/GRU encoder over the width (sequence) dimension
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            batch_first=True,
            dropout=enc_drop_rnn,
            bidirectional=enc_bi_rnn)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, valid_ratios=None):
        # Collapse the height dimension so width becomes the sequence axis.
        h_feat = feat.shape[2]  # feat: bsz * C * H * W
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            # Take the hidden state at the last valid (non-padded) step of
            # each sample rather than the last step overall.
            valid_hf = []
            T = holistic_feat.size(1)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = torch.stack(valid_hf, dim=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C

        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
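

# A minimal usage sketch (illustrative, not part of the upstream file): the
# batch size and the 512 x 6 x 40 feature-map shape below are assumptions;
# any bsz * d_model * H * W backbone output works.
def _demo_sar_encoder():
    encoder = SAREncoder(d_model=512, d_enc=512)
    feat = torch.randn(2, 512, 6, 40)  # bsz * C * H * W from the backbone
    # valid_ratios gives the fraction of each image width that is real
    # content rather than padding.
    holistic = encoder(feat, valid_ratios=[1.0, 0.5])
    assert holistic.shape == (2, 512)  # one holistic feature per sample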


class BaseDecoder(nn.Module):
    """Dispatches between the training and inference decoding paths."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, valid_ratios):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                valid_ratios=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, valid_ratios)
        return self.forward_test(feat, out_enc, valid_ratios)


class ParallelSARDecoder(BaseDecoder):
    """2D-attention decoder of the SAR text recognizer.

    Args:
        out_channels (int): Output class number (characters + unknown
            + start + padding).
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2d(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            batch_first=True,
            dropout=dec_drop_rnn,
            bidirectional=dec_bi_rnn)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        # Broadcast-add query and key over the spatial and step dimensions.
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.size()
        assert c == 1

        if valid_ratios is not None:
            # Mask attention weights that fall on width padding.
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))

        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)

        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = torch.sum(
            torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
            y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, valid_ratios=None):
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = torch.cat((out_enc, lab_embedding), dim=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, valid_ratios=None):
        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            # Greedily pick the best character and feed its embedding back
            # in as the decoder input for the next step.
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
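

# A minimal usage sketch (illustrative, not part of the upstream file):
# decode against a dummy encoder output. out_channels=93 and the 30-step
# labels are assumptions; out_channels must count the characters plus the
# unknown, start, and padding tokens.
def _demo_parallel_sar_decoder():
    decoder = ParallelSARDecoder(out_channels=93, d_model=512, d_enc=512)
    feat = torch.randn(2, 512, 6, 40)  # backbone feature map, bsz * C * H * W
    holistic = torch.randn(2, 512)  # SAREncoder output, bsz * C
    label = torch.randint(0, 91, (2, 30))  # character indices, bsz * seq_len
    out = decoder(feat, holistic, label, train_mode=True)
    assert out.shape == (2, 30, 92)  # bsz * seq_len * (num_classes - 1)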


@HEADS.register_module()
class SARHead(nn.Module):
    """SAR recognition head: holistic encoder plus 2D-attention decoder."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, label, valid_ratios=None):
        """`label` is only consumed in training mode; `valid_ratios` holds
        the valid-width ratio of each image in the batch."""
        holistic_feat = self.encoder(feat, valid_ratios)  # bsz * C

        if self.training:
            final_out = self.decoder(
                feat, holistic_feat, label, valid_ratios=valid_ratios)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                valid_ratios=valid_ratios,
                train_mode=False)

        return final_out
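

# A minimal usage sketch (illustrative, not part of the upstream file): the
# head in train and eval mode. The channel counts and out_channels=93 are
# assumptions.
def _demo_sar_head():
    head = SARHead(in_channels=512, out_channels=93)
    feat = torch.randn(2, 512, 6, 40)
    label = torch.randint(0, 91, (2, 30))
    head.train()
    logits = head(feat, label)  # teacher-forced decoding over the labels
    head.eval()
    with torch.no_grad():
        probs = head(feat, label=None)  # greedy decoding, softmax scores
    assert logits.shape == probs.shape == (2, 30, 92)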


@HEADS.register_module()
class CTCHead(nn.Module):
    """CTC recognition head: one or two linear layers over sequence features."""

    def __init__(self,
                 in_channels,
                 out_channels=6625,
                 fc_decay=0.0004,  # unused in this PyTorch port
                 mid_channels=None,
                 return_feats=False,
                 **kwargs):
        super(CTCHead, self).__init__()
        if mid_channels is None:
            self.fc = nn.Linear(
                in_channels,
                out_channels,
                bias=True,
            )
        else:
            self.fc1 = nn.Linear(
                in_channels,
                mid_channels,
                bias=True,
            )
            self.fc2 = nn.Linear(
                mid_channels,
                out_channels,
                bias=True,
            )

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, labels=None):
        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)

        if self.return_feats:
            result = (x, predicts)
        else:
            result = predicts

        # At inference time, return per-step class probabilities only.
        if not self.training:
            predicts = F.softmax(predicts, dim=2)
            result = predicts

        return result
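

# A minimal usage sketch (illustrative, not part of the upstream file): the
# 80-step sequence and 64 channels are assumptions; out_channels counts the
# characters plus the CTC blank.
def _demo_ctc_head():
    head = CTCHead(in_channels=64, out_channels=6625)
    x = torch.randn(2, 80, 64)  # bsz * seq_len * C from the sequence encoder
    head.eval()
    with torch.no_grad():
        probs = head(x)  # softmax scores in eval mode
    assert probs.shape == (2, 80, 6625)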


@HEADS.register_module()
class MultiHead(nn.Module):
    """Joint head that trains a CTC branch and a SAR branch together and
    returns only the CTC branch at inference time."""

    def __init__(self, in_channels, out_channels_list, **kwargs):
        super().__init__()
        self.head_list = kwargs.pop('head_list')
        head_names = [head.type for head in self.head_list]
        self.gtc_head = 'sar' if 'SARHead' in head_names else 'ctc'
        for idx, head_cfg in enumerate(self.head_list):
            name = head_cfg.type
            if name == 'SARHead':
                # sar head
                sar_args = self.head_list[idx]
                self.sar_head = eval(name)(
                    in_channels=in_channels,
                    out_channels=out_channels_list['SARLabelDecode'],
                    **sar_args)
            elif name == 'CTCHead':
                # ctc neck
                self.encoder_reshape = Im2Seq(in_channels)
                neck_args = self.head_list[idx].Neck
                encoder_type = neck_args.get('type')
                self.encoder = encoder_type
                self.ctc_encoder = SequenceEncoder(
                    in_channels=in_channels,
                    encoder_type=encoder_type,
                    **neck_args)
                # ctc head
                head_args = self.head_list[idx].Head
                self.ctc_head = eval(name)(
                    in_channels=self.ctc_encoder.out_channels,
                    out_channels=out_channels_list['CTCLabelDecode'],
                    **head_args)
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, label=None, valid_ratios=None):
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder)
        head_out = dict()
        head_out['ctc'] = ctc_out
        head_out['ctc_neck'] = ctc_encoder
        # eval mode: only the CTC branch is used for prediction
        if not self.training:
            return ctc_out
        if self.gtc_head == 'sar':
            sar_out = self.sar_head(x, label, valid_ratios)
            head_out['sar'] = sar_out
        return head_out
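

# A hedged configuration sketch (illustrative, not part of the upstream
# file): MultiHead expects `head_list` entries with attribute access (e.g.
# an mmcv Config / EasyDict), a `Neck` sub-config for the CTC sequence
# encoder and a `Head` sub-config for CTCHead. The encoder type, channel
# counts, and class numbers below are assumptions and must match the
# SequenceEncoder implementation and the label decoders actually used.
_EXAMPLE_MULTI_HEAD_CFG = dict(
    type='MultiHead',
    in_channels=64,
    out_channels_list=dict(CTCLabelDecode=6625, SARLabelDecode=93),
    head_list=[
        dict(type='CTCHead', Neck=dict(type='svtr'), Head=dict()),
        dict(type='SARHead', enc_dim=512, max_text_length=25),
    ],
)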