fix a bug for sar decoder when bi-rnn is used (#690)

* fix bug for sar when bi-rnn is used

* Update sar_decoder.py

fix lint check
pull/692/head
Minghui Liao 2021-12-26 21:43:06 -05:00 committed by GitHub
parent fb1892a1ae
commit f66336761c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -98,7 +98,8 @@ class ParallelSARDecoder(BaseDecoder):
self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = num_classes - 1 # ignore padding_idx in prediction
if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
fc_in_channel = decoder_rnn_out_size + d_model + \
encoder_rnn_out_size
else:
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)