fix aster loss for axis (#8674)

pull/8689/head
xiaoting 2022-12-20 13:22:03 +08:00 committed by GitHub
parent b28af5d865
commit 23e034c40e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -28,7 +28,7 @@ class CosineEmbeddingLoss(nn.Layer):
def forward(self, x1, x2, target):
similarity = paddle.sum(
x1 * x2, dim=-1) / (paddle.norm(
x1 * x2, axis=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1) + self.epsilon)
one_list = paddle.full_like(target, fill_value=1)