add dtype param for arange API. (#10075)

This commit is contained in:
zxcd 2023-06-02 14:21:17 +08:00 committed by GitHub
parent a46a061082
commit 46a6950e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -46,7 +46,7 @@ def positionalencoding2d(d_model, height, width):
# Each dimension use half of d_model
d_model = int(d_model / 2)
div_term = paddle.exp(
paddle.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
paddle.arange(0., d_model, 2, dtype='int64') * -(math.log(10000.0) / d_model))
pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)