# mirror of https://github.com/YifanXu74/MQ-Det.git
from copy import deepcopy

import numpy as np
import torch
from torch import nn


class RNNEnoder(nn.Module):
    """RNN language backbone: embeds zero-padded word-id sequences and encodes
    them with an (optionally bidirectional) RNN/LSTM/GRU."""

    def __init__(self, cfg):
        super(RNNEnoder, self).__init__()
        self.cfg = cfg

        self.rnn_type = cfg.MODEL.LANGUAGE_BACKBONE.RNN_TYPE
        self.variable_length = cfg.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH
        self.word_embedding_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE
        self.word_vec_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE
        self.hidden_size = cfg.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE
        self.bidirectional = cfg.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL
        self.input_dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P
        self.dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.DROPOUT_P
        self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
        self.corpus_path = cfg.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH
        self.vocab_size = cfg.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE

        # language encoder
        self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size)
        self.input_dropout = nn.Dropout(self.input_dropout_p)
        self.mlp = nn.Sequential(nn.Linear(self.word_embedding_size, self.word_vec_size), nn.ReLU())
        self.rnn = getattr(nn, self.rnn_type.upper())(self.word_vec_size,
                                                      self.hidden_size,
                                                      self.n_layers,
                                                      batch_first=True,
                                                      bidirectional=self.bidirectional,
                                                      dropout=self.dropout_p)
        self.num_dirs = 2 if self.bidirectional else 1

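    # Illustrative example (assumed values, not taken from any MQ-Det config):
    # with RNN_TYPE='lstm', WORD_VEC_SIZE=512, HIDDEN_SIZE=512, N_LAYERS=1,
    # BIDIRECTIONAL=True and DROPOUT_P=0.0, the constructor above builds
    #     nn.LSTM(512, 512, 1, batch_first=True, bidirectional=True, dropout=0.0)
    # so num_dirs == 2 and each token feature in `output` has 2 * 512 = 1024 dims.
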
    def forward(self, input, mask=None):
        """Encode a (batch, seq_len) LongTensor of word ids, with 0 used as padding."""
        word_id = input
        max_len = (word_id != 0).sum(1).max().item()
        word_id = word_id[:, :max_len]  # drop all-padding trailing columns

        # embedding + RNN encoding
        output, hidden, embedded, final_output = self.encode(word_id)

        return {
            'hidden': hidden,
            'output': output,
            'embedded': embedded,
            'final_output': final_output,
        }

    def encode(self, input_labels):
        """
        Inputs:
        - input_labels: long (batch, seq_len), zero-padded word ids
        Outputs:
        - output      : float (batch, max_len, hidden_size * num_dirs)
        - hidden      : float (batch, num_layers * num_dirs * hidden_size)
        - embedded    : float (batch, max_len, word_vec_size)
        - final_output: float (batch, num_dirs * hidden_size), output at each sequence's last valid step
        """
        if self.variable_length:
            input_lengths_list, sorted_lengths_list, sort_idxs, recover_idxs = self.sort_inputs(input_labels)
            input_labels = input_labels[sort_idxs]
        else:
            input_lengths_list = (input_labels != 0).sum(1).tolist()

        embedded = self.embedding(input_labels)  # (n, seq_len, word_embedding_size)
        embedded = self.input_dropout(embedded)  # (n, seq_len, word_embedding_size)
        embedded = self.mlp(embedded)            # (n, seq_len, word_vec_size)

        if self.variable_length:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded,
                                                         sorted_lengths_list,
                                                         batch_first=True)

        # forward rnn
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(embedded)

        # recover
        if self.variable_length:
            # recover embedded
            embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded,
                                                           batch_first=True)  # (batch, max_len, word_vec_size)
            embedded = embedded[recover_idxs]

            # recover output
            output, _ = nn.utils.rnn.pad_packed_sequence(output,
                                                         batch_first=True)  # (batch, max_len, hidden_size * num_dirs)
            output = output[recover_idxs]

            # recover hidden
            if self.rnn_type == 'lstm':
                hidden = hidden[0]  # keep the hidden state, drop the cell state
            hidden = hidden[:, recover_idxs, :]           # (num_layers * num_dirs, batch, hidden_size)
            hidden = hidden.transpose(0, 1).contiguous()  # (batch, num_layers * num_dirs, hidden_size)
            hidden = hidden.view(hidden.size(0), -1)      # (batch, num_layers * num_dirs * hidden_size)

        # final output: take each sequence's output at its last non-padded time step
        final_output = []
        for ii in range(output.shape[0]):
            final_output.append(output[ii, int(input_lengths_list[ii] - 1), :])
        final_output = torch.stack(final_output, dim=0)  # (batch, num_dirs * hidden_size)

        return output, hidden, embedded, final_output

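    # A tiny worked example of the sort/recover bookkeeping (illustrative values,
    # not from the original code): for per-sequence lengths [3, 5, 2],
    # sort_inputs() returns sorted lengths [5, 3, 2], sort_idxs [1, 0, 2]
    # (descending-length order) and recover_idxs [1, 0, 2], so that
    # tensor[sort_idxs][recover_idxs] restores the original batch order after
    # the packed RNN pass.
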
    def sort_inputs(self, input_labels):  # sort inputs by descending sequence length
        device = input_labels.device
        input_lengths = (input_labels != 0).sum(1)
        input_lengths_list = input_lengths.data.cpu().numpy().tolist()
        sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist()  # lengths in descending order
        sort_idxs = np.argsort(input_lengths_list)[::-1].tolist()               # indices that sort the batch by descending length
        s2r = {s: r for r, s in enumerate(sort_idxs)}
        recover_idxs = [s2r[s] for s in range(len(input_lengths_list))]         # indices that undo the sort
        assert max(input_lengths_list) == input_labels.size(1)
        # move to long tensors on the input device
        sort_idxs = input_labels.data.new(sort_idxs).long().to(device)
        recover_idxs = input_labels.data.new(recover_idxs).long().to(device)
        return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs
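

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes a
# yacs-style CfgNode with the MODEL.LANGUAGE_BACKBONE fields read above; the
# concrete values are illustrative, not taken from the repo's configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.LANGUAGE_BACKBONE = CN()
    lb = cfg.MODEL.LANGUAGE_BACKBONE
    lb.RNN_TYPE = "lstm"
    lb.VARIABLE_LENGTH = True
    lb.WORD_EMBEDDING_SIZE = 512
    lb.WORD_VEC_SIZE = 512
    lb.HIDDEN_SIZE = 512
    lb.BIDIRECTIONAL = True
    lb.INPUT_DROPOUT_P = 0.5
    lb.DROPOUT_P = 0.0
    lb.N_LAYERS = 1
    lb.CORPUS_PATH = ""
    lb.VOCAB_SIZE = 1000

    encoder = RNNEnoder(cfg)
    # Batch of 2 zero-padded word-id sequences (0 is the padding id).
    word_ids = torch.tensor([[4, 7, 9, 0, 0], [3, 5, 0, 0, 0]])
    outputs = encoder(word_ids)
    print({k: v.shape for k, v in outputs.items()})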