mirror of https://github.com/alibaba/EasyCV.git
fix face 2d keypoints devices bug
parent
a73667708c
commit
f9566c7027
|
@ -54,7 +54,7 @@ class WingLossWithPose(nn.Module):
|
|||
|
||||
self.part_weight = None
|
||||
if part_weight is not None:
|
||||
self.part_weight = torch.from_numpy(part_weight).cuda()
|
||||
self.part_weight = torch.from_numpy(part_weight)
|
||||
|
||||
def forward(self, pred, target, pose):
|
||||
weight = 5.0 * (1.0 - torch.cos(pose * np.pi / 180.0)) + 1.0
|
||||
|
|
Loading…
Reference in New Issue