mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Update arcmargin.py
This commit is contained in:
parent
77a2c4571c
commit
f61a64aab6
@ -36,7 +36,7 @@ class ArcMargin(nn.Layer):
|
|||||||
input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
|
input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
|
||||||
input = paddle.divide(input, input_norm)
|
input = paddle.divide(input, input_norm)
|
||||||
|
|
||||||
weight = self.fc0.weight
|
weight = self.fc.weight
|
||||||
weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
|
weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
|
||||||
weight = paddle.divide(weight, weight_norm)
|
weight = paddle.divide(weight, weight_norm)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user