paddle.shape return int64 tensor (#14318)
parent
83323b55d5
commit
9c01b43301
|
@ -242,7 +242,7 @@ class ParallelSARDecoder(BaseDecoder):
|
|||
# bsz * (seq_len + 1) * h * w * attn_size
|
||||
attn_weight = self.conv1x1_2(attn_weight)
|
||||
# bsz * (seq_len + 1) * h * w * 1
|
||||
bsz, T, h, w, c = paddle.shape(attn_weight)
|
||||
bsz, T, h, w, c = paddle.shape(attn_weight).astype("int32")
|
||||
assert c == 1
|
||||
|
||||
if valid_ratios is not None:
|
||||
|
|
Loading…
Reference in New Issue