Update arcmargin.py

This commit is contained in:
Bin Lu 2021-05-31 20:46:45 +08:00 committed by GitHub
parent 77a2c4571c
commit f61a64aab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,7 @@ class ArcMargin(nn.Layer):
input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm) input = paddle.divide(input, input_norm)
weight = self.fc0.weight weight = self.fc.weight
weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True)) weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
weight = paddle.divide(weight, weight_norm) weight = paddle.divide(weight, weight_norm)