Incorporating CrossEntropy() loss with a dedicated FocalLoss, w/ weighted loss

Adding Asymmetric loss for multi-label; hasn't been tested yet!
pull/2071/head
hanoch3 2025-03-30 11:14:54 +03:00
parent 0a12c34633
commit 565c17e6a4
5 changed files with 133 additions and 23 deletions
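Note: the train.py hunk below assigns `class_inverse_freq` as the per-class loss weight, but its computation is outside this diff. A minimal sketch of how such inverse-frequency weights could be derived (hypothetical helper; assumes `labels` is the loader's list of per-image Nx5 arrays with the class id in column 0):

import numpy as np
import torch

def inverse_class_frequency(labels, nc):
    counts = np.zeros(nc, dtype=np.float64)  # instance count per class
    for l in labels:
        np.add.at(counts, l[:, 0].astype(int), 1)
    counts = np.maximum(counts, 1)  # guard against classes absent from the set
    weights = counts.sum() / (nc * counts)  # rarer classes get larger weights
    return torch.tensor(weights, dtype=torch.float32)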

View File

@@ -5,7 +5,7 @@ weight_decay: 0.005 # optimizer weight decay 5e-4 It resolve mAP of overfittin
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight

View File

@@ -441,7 +441,7 @@ def test(data,
# Plot images
# aa = np.repeat(img[0,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32')
# cv2.imwrite('test/exp40/test_batch88_labels__.jpg', aa*255)
if (plots and batch_i > 10) or 1:
if (plots and batch_i > 10):
# conf_thresh_plot = 0.1  # confidence threshold for the plot
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()

View File

@@ -458,7 +458,10 @@ def train(hyp, opt, device, tb_writer=None):
else:
scaler = torch.amp.GradScaler("cuda", enabled=opt.amp) if is_torch_240 else torch.cuda.amp.GradScaler(enabled=opt.amp)
loss_weight = torch.tensor([])
loss_weight = torch.tensor([]) # for BCE
if opt.multi_class_no_multi_label:
    loss_weight = torch.ones(1)
if opt.loss_weight:
    loss_weight = class_inverse_freq
if 0:
@@ -466,8 +469,12 @@ def train(hyp, opt, device, tb_writer=None):
# Replaced YOLO classification loss with Focal Loss using per-class α values. Kept Objectness Loss and BBox Loss unchanged.
if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
    compute_loss_ota = ComputeLossOTA(model, loss_weight=loss_weight)  # init loss class
    if opt.multi_class_no_multi_label:
        raise ValueError('Not implemented yet!')
compute_loss = ComputeLoss(model, loss_weight=loss_weight)  # init loss class; it is required for the test set as well, hence mandatory
compute_loss = ComputeLoss(device, model, loss_weight=loss_weight,
                           multi_class_no_multi_label=opt.multi_class_no_multi_label,
                           multi_label_asymetric_focal_loss=opt.multi_label_asymetric_focal_loss)  # init loss class; it is required for the test set as well, hence mandatory
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
f'Using {dataloader.num_workers} dataloader workers\n'
@@ -621,7 +628,7 @@ def train(hyp, opt, device, tb_writer=None):
# end batch ------------------------------------------------------------------------------------------------
# end epoch ----------------------------------------------------------------------------------------------------
if epoch >= opt.ohem_start_ep and opt.ohem_start_ep > 0:
    # dataloader_orig = copy.deepcopy(dataloader)
    dataloader_orig = copy.deepcopy(dataloader)
    print('OHEM')
@@ -837,9 +844,16 @@ if __name__ == '__main__':
parser.add_argument('--cosine-anneal', action='store_true', help='')
parser.add_argument('--multi-class-no-multi-label', action='store_true', help='disable multi-label')
parser.add_argument('--multi-label-asymetric-focal-loss', action='store_true', help='use asymmetric focal loss (ASL) for multi-label')
opt = parser.parse_args()
# Only for clearML env
if opt.multi_class_no_multi_label and opt.multi_label_asymetric_focal_loss:
    raise ValueError('ASL is for multi-label rather than multi-class')
if opt.tir_channel_expansion:  # operates over 3 channels
    opt.input_channels = 3
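For reference, a hypothetical invocation of the new weighted-CE path. The --loss-weight flag is implied by opt.loss_weight above, but its parser line is outside this diff; dataset and weight paths are placeholders:

python train.py --data data/custom.yaml --weights yolov7.pt \
    --multi-class-no-multi-label --loss-weight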

View File

@@ -688,7 +688,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
x[:, 0] = 0
if nm > 0:
print('Remove missing annotation files, avoiding unlabeled images that would be considered as BG')
print(100 * '/*/')
print('Remove missing annotation files, avoiding unlabeled images that would be considered as BG. Before:', len(self.labels))
for ix in range(len(self.labels) - 1, -1, -1):  # safe removal by reverse iteration
    if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
        del self.labels[ix]
@@ -696,6 +697,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
        del self.label_files[ix]
        del shapes[ix]
print('After:', len(self.labels))
self.shapes = np.array(shapes, dtype=np.float64)
n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
@@ -770,6 +773,24 @@ class LoadImagesAndLabels(Dataset): # for training/testing
print(f'{fname} fname WARNING: Ignoring corrupted image and/or label {file_name}: {e}')
def resample_ohem(self):
    for ix in range(len(self.labels) - 1, -1, -1):  # safe removal by reverse iteration
        if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
            del self.labels[ix]
            del self.img_files[ix]
            del self.label_files[ix]
            # del shapes[ix]
            del self.imgs[ix]
            self.n = self.n - 1
    # self.shapes = np.array(shapes, dtype=np.float64)
    # n = len(shapes)  # number of images
    bi = np.floor(np.arange(self.n) / self.batch).astype(int)  # batch index
    nb = bi[-1] + 1  # number of batches
    # self.batch = bi  # batch index of image
    self.indices = range(self.n)
    self.mosiac_no = 0
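A quick plain-Python illustration of why resample_ohem() deletes by index in reverse: forward deletion shifts the indices of entries not yet visited, so items get skipped, while reverse iteration leaves unvisited indices valid.

items = ['keep', 'drop', 'drop', 'keep']
for ix in range(len(items) - 1, -1, -1):  # reverse: deletions never shift unvisited indices
    if items[ix] == 'drop':
        del items[ix]
print(items)  # ['keep', 'keep']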

View File

@@ -120,13 +120,14 @@ class SigmoidBin(nn.Module):
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25, multi_class_no_multi_label=False):
super(FocalLoss, self).__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element
self.multi_class_no_multi_label = multi_class_no_multi_label
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
@@ -134,17 +135,36 @@ class FocalLoss(nn.Module):
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = torch.sigmoid(pred) # prob from logits
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
if not self.multi_class_no_multi_label:
    pred_prob = torch.sigmoid(pred)  # prob from logits
    # For BCE, the positive and negative classes receive different treatment
    p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
if isinstance(self.alpha, torch.Tensor): # weights between classes or labels rather than between each label and BG
if self.alpha.numel() > 0:
alpha_factor = true * self.alpha
else:# old scalar case of class vs. all but not supporting multi-class weight
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
modulating_factor = (1.0 - p_t) ** self.gamma
loss *= alpha_factor * modulating_factor
    if isinstance(self.alpha, torch.Tensor):  # weights between classes or labels rather than between each label and BG
        if self.alpha.numel() > 0:
            alpha_factor = true * self.alpha
    else:  # old scalar case of class vs. all, not supporting multi-class weights
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
    modulating_factor = (1.0 - p_t) ** self.gamma
    loss *= alpha_factor * modulating_factor
else:
    # Multi-class (softmax) path: `true` holds class indices rather than a one-hot matrix
    probs = F.softmax(pred, dim=-1)  # (batch_size, num_classes)
    pt = probs.gather(dim=-1, index=true.unsqueeze(-1)).squeeze(-1)  # true-class probabilities
    focal_weight = (1 - pt) ** self.gamma  # focal modulating factor, (batch_size,)
    # Apply alpha weighting
    if isinstance(self.alpha, torch.Tensor) and self.alpha.numel() > 1:
        alpha_t = self.alpha[true]  # per-class weighting
    else:
        alpha_t = self.alpha
    loss = alpha_t * focal_weight * loss
if self.reduction == 'mean':
return loss.mean()
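A minimal smoke test of the two FocalLoss paths (assumes the reconstructed forward() above; shapes and values are illustrative only):

import torch
import torch.nn as nn

# Multi-label / BCE path: targets are a one-hot (or multi-hot) float matrix
fl_bce = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
pred = torch.randn(8, 4)  # logits for 8 anchors, 4 classes
t_onehot = torch.zeros(8, 4).scatter_(1, torch.randint(0, 4, (8, 1)), 1.0)
print(fl_bce(pred, t_onehot))  # scalar (mean reduction)

# Multi-class / CE path: targets are integer class indices
fl_ce = FocalLoss(nn.CrossEntropyLoss(), gamma=1.5, alpha=torch.ones(4),
                  multi_class_no_multi_label=True)
t_idx = torch.randint(0, 4, (8,))
print(fl_ce(pred, t_idx))  # scalar (mean reduction)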
@@ -164,6 +184,39 @@ alpha_factor = torch.tensor([alpha_list[i] for i in labels]).to(device)
alpha_factor = alpha_factor * labels + (1 - alpha_factor) * (1 - labels)
loss *= alpha_factor
"""
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip

    def forward(self, x, y):
        """
        Asymmetric loss for handling label imbalance (Ridnik et al., ICCV 2021).
        Args:
            x (torch.Tensor): Model predictions (logits)
            y (torch.Tensor): Binary label matrix
        Returns:
            torch.Tensor: Asymmetric loss
        """
        # Positive and negative probabilities
        xs_pos = torch.sigmoid(x)
        xs_neg = 1 - xs_pos
        # Asymmetric clipping (probability shifting): shift negatives by the margin `clip`
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)
        # Asymmetric focusing: L+ = (1 - p)^gamma_pos * log(p), L- = (1 - xs_neg)^gamma_neg * log(xs_neg)
        los_pos = y * torch.log(xs_pos.clamp(min=1e-8)) * (1 - xs_pos) ** self.gamma_pos
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=1e-8)) * (1 - xs_neg) ** self.gamma_neg
        return -torch.mean(los_pos + los_neg)
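Since the commit notes ASL is untested, a minimal smoke test of the class above (illustrative shapes only):

import torch

asl = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
logits = torch.randn(8, 4)  # raw predictions for 8 samples, 4 labels
targets = (torch.rand(8, 4) > 0.7).float()  # multi-hot label matrix
print(asl(logits, targets))  # finite scalar; backward() works when logits requires grad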
class QFocalLoss(nn.Module):
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
@@ -437,13 +490,19 @@ class APLoss(torch.autograd.Function):
# Dual obj and cls losses and outputs inherited from Joseph Redmon's original YOLOv3
class ComputeLoss:
# Compute losses
def __init__(self, model, autobalance=False, loss_weight=torch.tensor([])):
def __init__(self, device, model, autobalance=False, loss_weight=torch.tensor([]),
multi_class_no_multi_label=False, multi_label_asymetric_focal_loss=False):
super(ComputeLoss, self).__init__()
device = next(model.parameters()).device # get model device
# device = next(model.parameters()).device # get model device
h = model.hyp # hyperparameters
self.multi_class_no_multi_label = multi_class_no_multi_label
self.multi_label_asymetric_focal_loss = multi_label_asymetric_focal_loss
# Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
if self.multi_class_no_multi_label:
    xCEcls = nn.CrossEntropyLoss()  # weight=loss_weight
else:
    xCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
self.loss_weight = loss_weight
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
@@ -455,14 +514,25 @@ class ComputeLoss:
alpha = 0.25  # default from the base code
if loss_weight.numel() > 0:
    alpha = loss_weight  # override the paper default with per-class weights
BCEcls, BCEobj = FocalLoss(BCEcls, g, alpha=alpha), FocalLoss(BCEobj, g)
if multi_label_asymetric_focal_loss:
    # AsymmetricLoss is a standalone criterion (gamma_neg, gamma_pos, clip) and
    # replaces the classification loss outright rather than wrapping it like FocalLoss
    xCEcls, BCEobj = AsymmetricLoss(), FocalLoss(BCEobj, g)
else:
    xCEcls, BCEobj = (FocalLoss(xCEcls, g,
                                alpha=alpha,
                                multi_class_no_multi_label=self.multi_class_no_multi_label),
                      FocalLoss(BCEobj, g))
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.1, .05]) # P3-P7
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.5, 0.4, .1]) # P3-P7
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
self.xCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = xCEcls, BCEobj, model.gr, h, autobalance
for k in 'na', 'nc', 'nl', 'anchors':
setattr(self, k, getattr(det, k))
@@ -495,7 +565,10 @@ class ComputeLoss:
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
t[range(n), tcls[i]] = self.cp
#t[t==self.cp] = iou.detach().clamp(0).type(t.dtype)
lcls += self.BCEcls(ps[:, 5:], t)  # BCE
if self.multi_class_no_multi_label:
    lcls += self.xCEcls(ps[:, 5:], tcls[i].long().to(device))  # CE over class indices
else:
    lcls += self.xCEcls(ps[:, 5:], t)  # BCE over one-hot targets
# Append targets to text file
# with open('targets.txt', 'a') as file:
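To make the branch above concrete, a standalone illustration (values are illustrative) of the two target encodings: BCE consumes the smoothed one-hot matrix t, while CrossEntropyLoss consumes the raw class indices tcls[i]:

import torch

n, nc = 3, 4                      # 3 matched anchors, 4 classes
tcls_i = torch.tensor([2, 0, 3])  # class index per matched anchor
cp, cn = 0.95, 0.05               # smoothed positive/negative target values
t = torch.full((n, nc), cn)       # BCE target matrix
t[range(n), tcls_i] = cp
print(t)              # one-hot-like matrix for BCEWithLogitsLoss / FocalLoss
print(tcls_i.long())  # index targets for nn.CrossEntropyLoss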
@@ -1715,8 +1788,10 @@ class ComputeLossAuxOTA:
anch.append(anchors[a]) # anchors
return indices, anch
"""
import torch
import torch.nn.functional as F