Update loss for FP16 tobj (#7088)

This commit is contained in:
Glenn Jocher 2022-03-21 19:18:34 +01:00 committed by GitHub
parent 6f128031d0
commit a2d617ece9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -125,7 +125,7 @@ class ComputeLoss:
# Losses
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
n = b.shape[0] # number of targets
if n: