mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
Update rec_nrtr_optim_head.py
This commit is contained in:
parent
c635925895
commit
c8094e6575
@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer):
|
|||||||
new_shape = (n_curr_active_inst * n_bm, *d_hs)
|
new_shape = (n_curr_active_inst * n_bm, *d_hs)
|
||||||
|
|
||||||
beamed_tensor = beamed_tensor.reshape(
|
beamed_tensor = beamed_tensor.reshape(
|
||||||
[n_prev_active_inst, -1]) #contiguous()
|
[n_prev_active_inst, -1])
|
||||||
beamed_tensor = beamed_tensor.index_select(
|
beamed_tensor = beamed_tensor.index_select(
|
||||||
paddle.to_tensor(curr_active_inst_idx), axis=0)
|
paddle.to_tensor(curr_active_inst_idx), axis=0)
|
||||||
beamed_tensor = beamed_tensor.reshape([*new_shape])
|
beamed_tensor = beamed_tensor.reshape([*new_shape])
|
||||||
@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer):
|
|||||||
n_inst, len_s, d_h = src_enc.shape
|
n_inst, len_s, d_h = src_enc.shape
|
||||||
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
|
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
|
||||||
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
|
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
|
||||||
[1, 0, 2]) #repeat(1, n_bm, 1)
|
[1, 0, 2])
|
||||||
#-- Prepare beams
|
#-- Prepare beams
|
||||||
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
|
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user