
fix nan of aux training https://github.com/WongKinYiu/yolov7/issues/250#issue-1312356380 @hudingding
pull/90/merge
Kin-Yiu, Wong 2022-07-21 12:09:25 +08:00 committed by GitHub
parent de6a5e733d
commit 4f6e390c99
1 changed file with 25 additions and 16 deletions


@@ -1218,47 +1218,56 @@ class ComputeLossAuxOTA:
             tobj_aux = torch.zeros_like(pi_aux[..., 0], device=device)  # target obj
 
             n = b.shape[0]  # number of targets
-            n_aux = b_aux.shape[0]  # number of targets
             if n:
                 ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
-                ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux]  # prediction subset corresponding to targets
 
                 # Regression
                 grid = torch.stack([gi, gj], dim=1)
-                grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
                 pxy = ps[:, :2].sigmoid() * 2. - 0.5
-                pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
                 #pxy = ps[:, :2].sigmoid() * 3. - 1.
                 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-                pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
                 pbox = torch.cat((pxy, pwh), 1)  # predicted box
-                pbox_aux = torch.cat((pxy_aux, pwh_aux), 1)  # predicted box
                 selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
-                selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
                 selected_tbox[:, :2] -= grid
-                selected_tbox_aux[:, :2] -= grid_aux
                 iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
-                iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
-                lbox += (1.0 - iou).mean() + 0.25 * (1.0 - iou_aux).mean()  # iou loss
+                lbox += (1.0 - iou).mean()  # iou loss
 
                 # Objectness
                 tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio
-                tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype)  # iou ratio
 
                 # Classification
                 selected_tcls = targets[i][:, 1].long()
-                selected_tcls_aux = targets_aux[i][:, 1].long()
                 if self.nc > 1:  # cls loss (only if multiple classes)
                     t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
-                    t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device)  # targets
                     t[range(n), selected_tcls] = self.cp
-                    t_aux[range(n_aux), selected_tcls_aux] = self.cp
-                    lcls += self.BCEcls(ps[:, 5:], t) + 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux)  # BCE
+                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE
 
                 # Append targets to text file
                 # with open('targets.txt', 'a') as file:
                 #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
 
+            n_aux = b_aux.shape[0]  # number of targets
+            if n_aux:
+                ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux]  # prediction subset corresponding to targets
+                grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
+                pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
+                #pxy_aux = ps_aux[:, :2].sigmoid() * 3. - 1.
+                pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
+                pbox_aux = torch.cat((pxy_aux, pwh_aux), 1)  # predicted box
+                selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
+                selected_tbox_aux[:, :2] -= grid_aux
+                iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
+                lbox += 0.25 * (1.0 - iou_aux).mean()  # iou loss
+
+                # Objectness
+                tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype)  # iou ratio
+
+                # Classification
+                selected_tcls_aux = targets_aux[i][:, 1].long()
+                if self.nc > 1:  # cls loss (only if multiple classes)
+                    t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device)  # targets
+                    t_aux[range(n_aux), selected_tcls_aux] = self.cp
+                    lcls += 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux)  # BCE
+
             obji = self.BCEobj(pi[..., 4], tobj)
             obji_aux = self.BCEobj(pi_aux[..., 4], tobj_aux)
             lobj += obji * self.balance[i] + 0.25 * obji_aux * self.balance[i]  # obj loss
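
Why the separate if n_aux: guard fixes the NaN, as a minimal standalone sketch (an illustration of the presumed failure mode behind the linked issue, not code from this patch): in PyTorch, reducing an empty tensor with .mean() returns NaN, so when a pyramid level has lead-head matches (n > 0) but no aux-head matches (n_aux == 0), the old term 0.25 * (1.0 - iou_aux).mean() computed inside if n: would turn lbox, and therefore the whole loss, into NaN.

import torch

# Assumed failure mode: the mean of an empty tensor is NaN, so adding the aux
# IoU term when the aux head has no matched targets poisons the box loss.
iou_aux = torch.zeros(0)                   # this level has no aux matches
print((1.0 - iou_aux).mean())              # tensor(nan)

# With the commit's separate guard, the empty case simply contributes nothing.
lbox = torch.zeros(1)
n_aux = iou_aux.shape[0]                   # number of aux targets, here 0
if n_aux:
    lbox += 0.25 * (1.0 - iou_aux).mean()  # iou loss (aux head)
print(lbox)                                # tensor([0.])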