mirror of https://github.com/WongKinYiu/yolov7.git
Incorporating CrossEntropy() loss with the designated FocalLoss, with weighted loss.
Adding Asymmetric loss for multi-label; hasn't been tested yet!
parent 0a12c34633
commit 565c17e6a4
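
In outline: classification can now run through nn.CrossEntropyLoss() (optionally focal-wrapped and class-weighted) instead of the BCE-only path. The "weighted loss" part in its simplest form, as a standalone sketch with invented weights:

    import torch
    import torch.nn as nn

    # CrossEntropyLoss accepts per-class weights directly,
    # e.g. inverse-frequency values for imbalanced data (made up here)
    w = torch.tensor([0.5, 2.0, 1.0])
    ce = nn.CrossEntropyLoss(weight=w)
    logits = torch.randn(4, 3)            # 4 samples, 3 classes
    targets = torch.tensor([0, 1, 2, 1])
    print(ce(logits, targets))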
@@ -5,7 +5,7 @@ weight_decay: 0.005  # optimizer weight decay 5e-4; it resolves mAP overfitting
 warmup_epochs: 3.0  # warmup epochs (fractions ok)
 warmup_momentum: 0.8  # warmup initial momentum
 warmup_bias_lr: 0.001  #0.001 # warmup initial bias lr
-loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
+loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
 box: 0.05  # box loss gain
 cls: 0.5  # cls loss gain
 cls_pw: 1.0  # cls BCELoss positive_weight
test.py  (2 changes)
@@ -441,7 +441,7 @@ def test(data,
             # Plot images
             # aa = np.repeat(img[0,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32')
             # cv2.imwrite('test/exp40/test_batch88_labels__.jpg', aa*255)
-            if (plots and batch_i > 10) or 1:
+            if (plots and batch_i > 10):
                 # conf_thresh_plot = 0.1  # the plot threshold for the confidence
                 f = save_dir / f'test_batch{batch_i}_labels.jpg'  # labels
                 Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
train.py  (20 changes)
@@ -458,7 +458,10 @@ def train(hyp, opt, device, tb_writer=None):
     else:
         scaler = torch.amp.GradScaler("cuda", enabled=opt.amp) if is_torch_240 else torch.cuda.amp.GradScaler(enabled=opt.amp)

-    loss_weight = torch.tensor([])
+    loss_weight = torch.tensor([])  # for BCE
+    if opt.multi_class_no_multi_label:
+        loss_weight = torch.ones(1)
     if opt.loss_weight:
         loss_weight = class_inverse_freq
     if 0:
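
class_inverse_freq is consumed here but not defined in this diff; one plausible construction (the helper name and normalization are assumptions, not the repo's code) counts class occurrences over the label set and inverts them:

    import torch

    def inverse_class_freq(labels, nc):
        # labels: list of (num_objects, 5) arrays, column 0 = class id
        counts = torch.zeros(nc)
        for lb in labels:
            for c in lb[:, 0]:
                counts[int(c)] += 1
        counts = counts.clamp(min=1)         # avoid division by zero
        return counts.sum() / (nc * counts)  # inverse frequency, mean-normalized

    lbls = [torch.tensor([[0, .5, .5, .1, .1], [2, .2, .2, .1, .1]]),
            torch.tensor([[0, .4, .4, .2, .2]])]
    print(inverse_class_freq(lbls, nc=3))    # tensor([0.6667, 1.3333, 1.3333])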
@@ -466,8 +469,12 @@ def train(hyp, opt, device, tb_writer=None):
     # Replaced YOLO classification loss with Focal Loss using per-class α values. Kept objectness loss and bbox loss unchanged.
     if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
         compute_loss_ota = ComputeLossOTA(model, loss_weight=loss_weight)  # init loss class
+        if opt.multi_class_no_multi_label:
+            raise ValueError('Not implemented yet!')

-    compute_loss = ComputeLoss(model, loss_weight=loss_weight)  # init loss class; required for the test set as well, hence mandatory
+    compute_loss = ComputeLoss(device, model, loss_weight=loss_weight,
+                               multi_class_no_multi_label=opt.multi_class_no_multi_label,
+                               multi_label_asymetric_focal_loss=opt.multi_label_asymetric_focal_loss)  # init loss class; required for the test set as well, hence mandatory

     logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                 f'Using {dataloader.num_workers} dataloader workers\n'
@@ -621,7 +628,7 @@ def train(hyp, opt, device, tb_writer=None):
         # end batch ------------------------------------------------------------------------------------------------
         # end epoch ----------------------------------------------------------------------------------------------------
         if epoch >= opt.ohem_start_ep and opt.ohem_start_ep > 0:
-            # dataloader_orig = copy.deepcopy(dataloader)
+            dataloader_orig = copy.deepcopy(dataloader)
             print('OHEM')
@@ -837,9 +844,16 @@ if __name__ == '__main__':
     parser.add_argument('--cosine-anneal', action='store_true', help='')

+    parser.add_argument('--multi-class-no-multi-label', action='store_true', help='disable multi-label: single-label multi-class classification')
+    parser.add_argument('--multi-label-asymetric-focal-loss', action='store_true', help='use asymmetric focal loss for multi-label')

     opt = parser.parse_args()
     # Only for ClearML env

+    if opt.multi_class_no_multi_label and opt.multi_label_asymetric_focal_loss:
+        raise ValueError('ASL is for multi-label rather than multi-class')

     if opt.tir_channel_expansion:  # operates over 3 channels
         opt.input_channels = 3
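
A hypothetical invocation of the new flags (all other arguments elided; the two options are mutually exclusive, as the check above enforces):

    python train.py --multi-class-no-multi-label ...        # focal cross-entropy, one label per box
    python train.py --multi-label-asymetric-focal-loss ...  # asymmetric loss over multi-label targets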
@@ -688,7 +688,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 x[:, 0] = 0

         if nm > 0:
-            print('Remove missing annotation files, avoiding unlabeled images that would be considered as BG')
+            print(100*'/*/')
+            print('Remove missing annotation files, avoiding unlabeled images that would be considered as BG. Before', len(self.labels))
             for ix in range(len(self.labels) - 1, -1, -1):  # safe removal by reverse iteration
                 if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
                     del self.labels[ix]
@@ -696,6 +697,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     del self.label_files[ix]
                     del shapes[ix]
+            print('after', len(self.labels))

         self.shapes = np.array(shapes, dtype=np.float64)
         n = len(shapes)  # number of images
         bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
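
The removal predicate leans on the YOLO label layout: each row is (class, x, y, w, h) with coordinates normalized to [0, 1], so any value above 1 in columns 1: signals a malformed or pixel-coordinate file, and size < 5 means not even one complete row. A small self-contained illustration (values invented):

    import numpy as np

    lb_ok = np.array([[0, 0.5, 0.5, 0.2, 0.1]])     # one valid normalized box
    lb_bad = np.array([[0, 640.0, 360.0, 50, 40]])  # pixel coords leaked in: values > 1
    lb_empty = np.zeros((0, 5))                     # no annotations at all

    for lb in (lb_ok, lb_bad, lb_empty):
        drop = (lb[:, 1:] > 1).any() or lb.size < 5
        print(drop)   # False, True, True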
@@ -770,6 +773,24 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 print(f'{fname} fname WARNING: Ignoring corrupted image and/or label {file_name}: {e}')

+    def resample_ohem(self):
+        for ix in range(len(self.labels) - 1, -1, -1):  # safe removal by reverse iteration
+            if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
+                del self.labels[ix]
+                del self.img_files[ix]
+                del self.label_files[ix]
+                # del shapes[ix]
+                del self.imgs[ix]
+                self.n = self.n - 1
+        # self.shapes = np.array(shapes, dtype=np.float64)
+        # n = len(shapes)  # number of images
+        bi = np.floor(np.arange(self.n) / self.batch).astype(int)  # batch index
+        nb = bi[-1] + 1  # number of batches
+        # self.batch = bi  # batch index of image
+        self.indices = range(self.n)
+        self.mosiac_no = 0
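
After pruning, resample_ohem() redoes the same batch-index bookkeeping the constructor performs; a toy rerun of that arithmetic (sizes invented):

    import numpy as np

    n, batch_size = 10, 4
    n -= 3                                                 # pretend 3 samples were pruned
    bi = np.floor(np.arange(n) / batch_size).astype(int)   # batch index per image
    nb = bi[-1] + 1                                        # number of batches
    print(bi, nb)                                          # [0 0 0 0 1 1 1] 2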
utils/loss.py  (109 changes)
@@ -120,13 +120,14 @@ class SigmoidBin(nn.Module):

 class FocalLoss(nn.Module):
     # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
-    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25, multi_class_no_multi_label=False):
         super(FocalLoss, self).__init__()
         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
         self.gamma = gamma
         self.alpha = alpha
         self.reduction = loss_fcn.reduction
         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
+        self.multi_class_no_multi_label = multi_class_no_multi_label

     def forward(self, pred, true):
         loss = self.loss_fcn(pred, true)
@@ -134,17 +135,36 @@ class FocalLoss(nn.Module):
         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
-        pred_prob = torch.sigmoid(pred)  # prob from logits
-        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
-        if isinstance(self.alpha, torch.Tensor):  # weights between classes or labels rather than between each label and BG
-            if self.alpha.numel() > 0:
-                alpha_factor = true * self.alpha
-            else:  # old scalar case of class vs. all, not supporting multi-class weights
-                alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
-        modulating_factor = (1.0 - p_t) ** self.gamma
-        loss *= alpha_factor * modulating_factor
+        if not self.multi_class_no_multi_label:
+            pred_prob = torch.sigmoid(pred)  # prob from logits
+            # For BCE the positive and negative classes get different treatment
+            p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
+            if isinstance(self.alpha, torch.Tensor) and self.alpha.numel() > 0:
+                # weights between classes or labels rather than between each label and BG
+                alpha_factor = true * self.alpha
+            else:  # old scalar case of class vs. all, not supporting per-class weights
+                alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+            modulating_factor = (1.0 - p_t) ** self.gamma
+            loss *= alpha_factor * modulating_factor
+        else:
+            # Softmax probabilities for the single-label multi-class path
+            probs = F.softmax(pred, dim=-1)  # (batch_size, num_classes)
+            pt = probs.gather(dim=-1, index=true.unsqueeze(-1)).squeeze(-1)  # true-class probabilities
+
+            # Focal modulation factor
+            focal_weight = (1 - pt) ** self.gamma  # (batch_size,)
+
+            # Alpha weighting: per-class when a weight tensor is given, scalar otherwise
+            if isinstance(self.alpha, torch.Tensor) and self.alpha.numel() > 1:
+                alpha_t = self.alpha[true]
+            else:
+                alpha_t = self.alpha
+
+            loss = alpha_t * focal_weight * loss

         if self.reduction == 'mean':
             return loss.mean()
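
The BCE path keeps the original element-wise treatment; the new CE path applies the same modulation to per-sample cross-entropy. A standalone miniature of the CE branch's math (logits and targets invented):

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([[2.0, 0.1, -1.0], [0.2, 0.3, 0.1]])  # logits: 2 samples, 3 classes
    true = torch.tensor([0, 2])                               # integer class ids
    loss = F.cross_entropy(pred, true, reduction='none')      # what loss_fcn yields with reduction='none'
    pt = F.softmax(pred, dim=-1).gather(-1, true.unsqueeze(-1)).squeeze(-1)
    print((0.25 * (1 - pt) ** 1.5 * loss).mean())             # scalar alpha, gamma = 1.5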
@@ -164,6 +184,39 @@ alpha_factor = torch.tensor([alpha_list[i] for i in labels]).to(device)
 alpha_factor = alpha_factor * labels + (1 - alpha_factor) * (1 - labels)
 loss *= alpha_factor
 """

+# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
+class AsymmetricLoss(nn.Module):
+    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05):
+        super(AsymmetricLoss, self).__init__()
+        self.gamma_neg = gamma_neg
+        self.gamma_pos = gamma_pos
+        self.clip = clip
+
+    def forward(self, x, y):
+        """
+        Asymmetric loss for handling label imbalance.
+
+        Args:
+            x (torch.Tensor): model predictions (logits)
+            y (torch.Tensor): binary label matrix
+
+        Returns:
+            torch.Tensor: asymmetric loss
+        """
+        # Positive and negative probabilities
+        xs_pos = torch.sigmoid(x)
+        xs_neg = 1 - xs_pos
+
+        # Asymmetric clipping
+        y = torch.clamp(y, min=self.clip, max=1 - self.clip)
+
+        # Focal-like modulation
+        los_pos = y * torch.log(xs_pos) * (1 - xs_pos) ** self.gamma_pos
+        los_neg = (1 - y) * torch.log(xs_neg) * xs_neg ** self.gamma_neg
+
+        return -torch.mean(los_pos + los_neg)

 class QFocalLoss(nn.Module):
     # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
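
Two cautions on the block above, which the commit message flags as untested: torch.log can hit log(0) without an eps guard, and the paper's reference implementation (Ridnik et al., github.com/Alibaba-MIIL/ASL) applies the clip as a probability shift on the negative term rather than clamping the labels y. A condensed sketch of that reference formulation (mean reduction chosen here):

    import torch
    import torch.nn as nn

    class ASLReference(nn.Module):
        # Condensed from the official ASL code; eps guards log(0)
        def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8):
            super().__init__()
            self.gamma_neg, self.gamma_pos, self.clip, self.eps = gamma_neg, gamma_pos, clip, eps

        def forward(self, x, y):
            xs_pos = torch.sigmoid(x)
            xs_neg = (1 - xs_pos + self.clip).clamp(max=1)  # asymmetric probability shift
            loss = y * torch.log(xs_pos.clamp(min=self.eps)) \
                 + (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
            pt = xs_pos * y + xs_neg * (1 - y)              # prob of the ground-truth side
            gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            return -(loss * (1 - pt) ** gamma).mean()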
@@ -437,13 +490,19 @@ class APLoss(torch.autograd.Function):
 # Dual obj and cls losses and outputs inherited from Joseph Redmon's original YOLOv3
 class ComputeLoss:
     # Compute losses
-    def __init__(self, model, autobalance=False, loss_weight=torch.tensor([])):
+    def __init__(self, device, model, autobalance=False, loss_weight=torch.tensor([]),
+                 multi_class_no_multi_label=False, multi_label_asymetric_focal_loss=False):
         super(ComputeLoss, self).__init__()
-        device = next(model.parameters()).device  # get model device
+        # device = next(model.parameters()).device  # get model device
         h = model.hyp  # hyperparameters

+        self.multi_class_no_multi_label = multi_class_no_multi_label
+        self.multi_label_asymetric_focal_loss = multi_label_asymetric_focal_loss
         # Define criteria
-        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
+        if self.multi_class_no_multi_label:
+            xCEcls = nn.CrossEntropyLoss()  # weight=loss_weight
+        else:
+            xCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))

         BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
+        self.loss_weight = loss_weight
         # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
@@ -455,14 +514,25 @@ class ComputeLoss:
         alpha = 0.25  # default from the base code
         if loss_weight.numel() > 0:
             alpha = loss_weight  # override the default from the paper
-        BCEcls, BCEobj = FocalLoss(BCEcls, g, alpha=alpha), FocalLoss(BCEobj, g)
+        if multi_label_asymetric_focal_loss:
+            # AsymmetricLoss as defined above is a standalone criterion taking
+            # (gamma_neg, gamma_pos, clip); it replaces the focal cls loss
+            # outright rather than wrapping xCEcls, and takes no per-class alpha
+            xCEcls, BCEobj = (AsymmetricLoss(gamma_neg=4, gamma_pos=1),
+                              FocalLoss(BCEobj, g))
+        else:
+            xCEcls, BCEobj = (FocalLoss(xCEcls, g, alpha=alpha,
+                                        multi_class_no_multi_label=self.multi_class_no_multi_label),
+                              FocalLoss(BCEobj, g))

         det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
         self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7
         #self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.1, .05])  # P3-P7
         #self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.5, 0.4, .1])  # P3-P7
         self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 index
-        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
+        self.xCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = xCEcls, BCEobj, model.gr, h, autobalance
         for k in 'na', 'nc', 'nl', 'anchors':
             setattr(self, k, getattr(det, k))
@@ -495,7 +565,10 @@ class ComputeLoss:
                     t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                     t[range(n), tcls[i]] = self.cp
                     #t[t==self.cp] = iou.detach().clamp(0).type(t.dtype)
-                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE
+                    if self.multi_class_no_multi_label:
+                        lcls += self.xCEcls(ps[:, 5:], tcls[i].long().to(device))  # CE over class indices
+                    else:
+                        lcls += self.xCEcls(ps[:, 5:], t)  # BCE over cp/cn targets

                     # Append targets to text file
                     # with open('targets.txt', 'a') as file:
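
The two criteria consume different target encodings: BCEWithLogitsLoss wants an (n, nc) matrix filled with cp/cn, CrossEntropyLoss wants an (n,) vector of class indices. In miniature (cp=1, cn=0, i.e. no label smoothing):

    import torch
    import torch.nn as nn

    n, nc = 3, 5
    ps_cls = torch.randn(n, nc)                 # class logits for n matched anchors
    tcls = torch.tensor([1, 4, 0])              # target class per anchor

    t = torch.zeros(n, nc)                      # cn = 0
    t[range(n), tcls] = 1.0                     # cp = 1
    print(nn.BCEWithLogitsLoss()(ps_cls, t))    # multi-label path
    print(nn.CrossEntropyLoss()(ps_cls, tcls))  # single-label path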
@@ -1715,8 +1788,10 @@ class ComputeLossAuxOTA:
             anch.append(anchors[a])  # anchors

         return indices, anch
 """

+import torch
+import torch.nn.functional as F