mirror of
https://github.com/WongKinYiu/yolov7.git
synced 2025-06-03 21:54:57 +08:00
main code
Fix NaN loss in aux training: https://github.com/WongKinYiu/yolov7/issues/250#issue-1312356380 (@hudingding)
This commit is contained in:
parent
de6a5e733d
commit
4f6e390c99
@@ -1218,47 +1218,56 @@ class ComputeLossAuxOTA:
|
|||||||
tobj_aux = torch.zeros_like(pi_aux[..., 0], device=device) # target obj
|
tobj_aux = torch.zeros_like(pi_aux[..., 0], device=device) # target obj
|
||||||
|
|
||||||
n = b.shape[0] # number of targets
|
n = b.shape[0] # number of targets
|
||||||
n_aux = b_aux.shape[0] # number of targets
|
|
||||||
if n:
|
if n:
|
||||||
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
|
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
|
||||||
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets
|
|
||||||
|
|
||||||
# Regression
|
# Regression
|
||||||
grid = torch.stack([gi, gj], dim=1)
|
grid = torch.stack([gi, gj], dim=1)
|
||||||
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
|
|
||||||
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
||||||
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
|
|
||||||
#pxy = ps[:, :2].sigmoid() * 3. - 1.
|
|
||||||
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
||||||
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
|
|
||||||
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||||
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
|
|
||||||
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
|
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
|
||||||
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
|
|
||||||
selected_tbox[:, :2] -= grid
|
selected_tbox[:, :2] -= grid
|
||||||
selected_tbox_aux[:, :2] -= grid_aux
|
|
||||||
iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||||
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
lbox += (1.0 - iou).mean() # iou loss
|
||||||
lbox += (1.0 - iou).mean() + 0.25 * (1.0 - iou_aux).mean() # iou loss
|
|
||||||
|
|
||||||
# Objectness
|
# Objectness
|
||||||
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
|
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
|
||||||
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio
|
|
||||||
|
|
||||||
# Classification
|
# Classification
|
||||||
selected_tcls = targets[i][:, 1].long()
|
selected_tcls = targets[i][:, 1].long()
|
||||||
selected_tcls_aux = targets_aux[i][:, 1].long()
|
|
||||||
if self.nc > 1: # cls loss (only if multiple classes)
|
if self.nc > 1: # cls loss (only if multiple classes)
|
||||||
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
|
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
|
||||||
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
|
|
||||||
t[range(n), selected_tcls] = self.cp
|
t[range(n), selected_tcls] = self.cp
|
||||||
t_aux[range(n_aux), selected_tcls_aux] = self.cp
|
lcls += self.BCEcls(ps[:, 5:], t) # BCE
|
||||||
lcls += self.BCEcls(ps[:, 5:], t) + 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE
|
|
||||||
|
|
||||||
# Append targets to text file
|
# Append targets to text file
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||||
|
|
||||||
|
n_aux = b_aux.shape[0] # number of targets
|
||||||
|
if n_aux:
|
||||||
|
ps_aux = pi_aux[b_aux, a_aux, gj_aux, gi_aux] # prediction subset corresponding to targets
|
||||||
|
grid_aux = torch.stack([gi_aux, gj_aux], dim=1)
|
||||||
|
pxy_aux = ps_aux[:, :2].sigmoid() * 2. - 0.5
|
||||||
|
#pxy_aux = ps_aux[:, :2].sigmoid() * 3. - 1.
|
||||||
|
pwh_aux = (ps_aux[:, 2:4].sigmoid() * 2) ** 2 * anchors_aux[i]
|
||||||
|
pbox_aux = torch.cat((pxy_aux, pwh_aux), 1) # predicted box
|
||||||
|
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains_aux[i]
|
||||||
|
selected_tbox_aux[:, :2] -= grid_aux
|
||||||
|
iou_aux = bbox_iou(pbox_aux.T, selected_tbox_aux, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||||
|
lbox += 0.25 * (1.0 - iou_aux).mean() # iou loss
|
||||||
|
|
||||||
|
# Objectness
|
||||||
|
tobj_aux[b_aux, a_aux, gj_aux, gi_aux] = (1.0 - self.gr) + self.gr * iou_aux.detach().clamp(0).type(tobj_aux.dtype) # iou ratio
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
selected_tcls_aux = targets_aux[i][:, 1].long()
|
||||||
|
if self.nc > 1: # cls loss (only if multiple classes)
|
||||||
|
t_aux = torch.full_like(ps_aux[:, 5:], self.cn, device=device) # targets
|
||||||
|
t_aux[range(n_aux), selected_tcls_aux] = self.cp
|
||||||
|
lcls += 0.25 * self.BCEcls(ps_aux[:, 5:], t_aux) # BCE
|
||||||
|
|
||||||
obji = self.BCEobj(pi[..., 4], tobj)
|
obji = self.BCEobj(pi[..., 4], tobj)
|
||||||
obji_aux = self.BCEobj(pi_aux[..., 4], tobj_aux)
|
obji_aux = self.BCEobj(pi_aux[..., 4], tobj_aux)
|
||||||
lobj += obji * self.balance[i] + 0.25 * obji_aux * self.balance[i] # obj loss
|
lobj += obji * self.balance[i] + 0.25 * obji_aux * self.balance[i] # obj loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user