fix face 2d keypoints devices bug

pull/200/head
shouzhou.bx 2022-09-20 16:10:23 +08:00
parent a73667708c
commit f9566c7027
1 changed files with 1 additions and 1 deletions

View File

@ -54,7 +54,7 @@ class WingLossWithPose(nn.Module):
self.part_weight = None
if part_weight is not None:
self.part_weight = torch.from_numpy(part_weight).cuda()
self.part_weight = torch.from_numpy(part_weight)
def forward(self, pred, target, pose):
weight = 5.0 * (1.0 - torch.cos(pose * np.pi / 180.0)) + 1.0