mirror of https://github.com/WongKinYiu/yolov7.git
main code
fix nan of aux training https://github.com/WongKinYiu/yolov7/issues/250#issue-1312356380 @hudingding pull/90/merge
parent
de6a5e733d
commit
4f6e390c99
|
@@ -1218,47 +1218,56 @@ class ComputeLossAuxOTA:
|
|||
tobj_aux = torch.zeros_like(pi_aux[..., 0], device=device) # target obj
|
||||
|
||||
n = b.shape[0] # number of targets
|
||||
n_aux = b_aux.shape[0] # number of targets
|
||||
if n:
|
||||
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
|
||||
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets
|
||||
|
||||
# Regression
|
||||
grid = torch.stack([gi, gj], dim=1)
|
||||
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
|
||||
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
||||
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
|
||||
#pxy = ps[:, :2].sigmoid() * 3. - 1.
|
||||
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
||||
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
|
||||
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
|
||||
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
|
||||
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
|
||||
selected_tbox[:, :2] -= grid
|
||||
selected_tbox_aux[:, :2] -= grid_aux
|
||||
iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||
lbox += (1.0 - iou).mean() + 0.25 * (1.0 - iou_aux).mean() # iou loss
|
||||
lbox += (1.0 - iou).mean() # iou loss
|
||||
|
||||
# Objectness
|
||||
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
|
||||
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio
|
||||
|
||||
# Classification
|
||||
selected_tcls = targets[i][:, 1].long()
|
||||
selected_tcls_aux = targets_aux[i][:, 1].long()
|
||||
if self.nc > 1: # cls loss (only if multiple classes)
|
||||
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
|
||||
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
|
||||
t[range(n), selected_tcls] = self.cp
|
||||
t_aux[range(n_aux), selected_tcls_aux] = self.cp
|
||||
lcls += self.BCEcls(ps[:, 5:], t) + 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE
|
||||
lcls += self.BCEcls(ps[:, 5:], t) # BCE
|
||||
|
||||
# Append targets to text file
|
||||
# with open('targets.txt', 'a') as file:
|
||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||
|
||||
n_aux = b_aux.shape[0] # number of targets
|
||||
if n_aux:
|
||||
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets
|
||||
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
|
||||
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
|
||||
#pxy_aux = ps_aux[:, :2].sigmoid() * 3. - 1.
|
||||
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
|
||||
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
|
||||
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
|
||||
selected_tbox_aux[:, :2] -= grid_aux
|
||||
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||
lbox += 0.25 * (1.0 - iou_aux).mean() # iou loss
|
||||
|
||||
# Objectness
|
||||
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio
|
||||
|
||||
# Classification
|
||||
selected_tcls_aux = targets_aux[i][:, 1].long()
|
||||
if self.nc > 1: # cls loss (only if multiple classes)
|
||||
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
|
||||
t_aux[range(n_aux), selected_tcls_aux] = self.cp
|
||||
lcls += 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE
|
||||
|
||||
obji = self.BCEobj(pi[..., 4], tobj)
|
||||
obji_aux = self.BCEobj(pi_aux[..., 4], tobj_aux)
|
||||
lobj += obji * self.balance[i] + 0.25 * obji_aux * self.balance[i] # obj loss
|
||||
|
|
Loading…
Reference in New Issue