mirror of https://github.com/WongKinYiu/yolov7.git
incoroperating CrossEntropy() loss with designated FocalLoss. w/ weighted loss
Adding Asymentrical loss for Multi-label hasn;t tested yet!pull/2071/head
parent
0a12c34633
commit
565c17e6a4
|
@ -5,7 +5,7 @@ weight_decay: 0.005 # optimizer weight decay 5e-4 It resolve mAP of overfittin
|
||||||
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
||||||
warmup_momentum: 0.8 # warmup initial momentum
|
warmup_momentum: 0.8 # warmup initial momentum
|
||||||
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
|
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
|
||||||
loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
|
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
|
||||||
box: 0.05 # box loss gain
|
box: 0.05 # box loss gain
|
||||||
cls: 0.5 # cls loss gain
|
cls: 0.5 # cls loss gain
|
||||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||||
|
|
2
test.py
2
test.py
|
@ -441,7 +441,7 @@ def test(data,
|
||||||
|
|
||||||
|
|
||||||
# Plot images aa = np.repeat(img[0,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32') cv2.imwrite('test/exp40/test_batch88_labels__.jpg', aa*255)
|
# Plot images aa = np.repeat(img[0,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32') cv2.imwrite('test/exp40/test_batch88_labels__.jpg', aa*255)
|
||||||
if (plots and batch_i > 10) or 1:
|
if (plots and batch_i > 10):
|
||||||
# conf_thresh_plot = 0.1 # the plot threshold the connfidence
|
# conf_thresh_plot = 0.1 # the plot threshold the connfidence
|
||||||
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
|
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
|
||||||
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
|
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
|
||||||
|
|
20
train.py
20
train.py
|
@ -458,7 +458,10 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
else:
|
else:
|
||||||
scaler = torch.amp.GradScaler("cuda", enabled=opt.amp) if is_torch_240 else torch.cuda.amp.GradScaler(enabled=opt.amp)
|
scaler = torch.amp.GradScaler("cuda", enabled=opt.amp) if is_torch_240 else torch.cuda.amp.GradScaler(enabled=opt.amp)
|
||||||
|
|
||||||
loss_weight = torch.tensor([])
|
loss_weight = torch.tensor([]) # for BCE
|
||||||
|
if opt.multi_class_no_multi_label:
|
||||||
|
loss_weight = torch.ones(1)
|
||||||
|
|
||||||
if opt.loss_weight:
|
if opt.loss_weight:
|
||||||
loss_weight = class_inverse_freq
|
loss_weight = class_inverse_freq
|
||||||
if 0:
|
if 0:
|
||||||
|
@ -466,8 +469,12 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
# Replaced YOLO classification loss with Focal Loss using per-class α values. Kept Objectness Loss and BBox Loss unchanged.
|
# Replaced YOLO classification loss with Focal Loss using per-class α values. Kept Objectness Loss and BBox Loss unchanged.
|
||||||
if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
|
if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
|
||||||
compute_loss_ota = ComputeLossOTA(model, loss_weight=loss_weight) # init loss class
|
compute_loss_ota = ComputeLossOTA(model, loss_weight=loss_weight) # init loss class
|
||||||
|
if opt.multi_class_no_multi_label:
|
||||||
|
raise ValueError('Not imp yet!')
|
||||||
|
|
||||||
compute_loss = ComputeLoss(model, loss_weight=loss_weight) # init loss class it is required for the test set as well hance mandatory
|
compute_loss = ComputeLoss(device, model, loss_weight=loss_weight,
|
||||||
|
multi_class_no_multi_label=opt.multi_class_no_multi_label,
|
||||||
|
multi_label_asymetric_focal_loss=opt.multi_label_asymetric_focal_loss) # init loss class it is required for the test set as well hance mandatory
|
||||||
|
|
||||||
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
|
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
|
||||||
f'Using {dataloader.num_workers} dataloader workers\n'
|
f'Using {dataloader.num_workers} dataloader workers\n'
|
||||||
|
@ -621,7 +628,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
# end epoch ----------------------------------------------------------------------------------------------------
|
# end epoch ----------------------------------------------------------------------------------------------------
|
||||||
if epoch >= opt.ohem_start_ep and opt.ohem_start_ep >0:
|
if epoch >= opt.ohem_start_ep and opt.ohem_start_ep >0:
|
||||||
# dataloader_orig = copy.deepcopy(dataloader)
|
dataloader_orig = copy.deepcopy(dataloader)
|
||||||
|
|
||||||
print('OHEM')
|
print('OHEM')
|
||||||
|
|
||||||
|
@ -837,9 +844,16 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
parser.add_argument('--cosine-anneal', action='store_true', help='')
|
parser.add_argument('--cosine-anneal', action='store_true', help='')
|
||||||
|
|
||||||
|
parser.add_argument('--multi-class-no-multi-label', action='store_true', help='disbale multi-label')
|
||||||
|
|
||||||
|
parser.add_argument('--multi-label-asymetric-focal-loss', action='store_true', help='disbale multi-label')
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
# Only for clearML env
|
# Only for clearML env
|
||||||
|
|
||||||
|
if opt.multi_class_no_multi_label and opt.multi_label_asymetric_focal_loss:
|
||||||
|
raise ValueError('ASL is for multi label rather than multi class')
|
||||||
|
|
||||||
if opt.tir_channel_expansion: # operates over 3 channels
|
if opt.tir_channel_expansion: # operates over 3 channels
|
||||||
opt.input_channels = 3
|
opt.input_channels = 3
|
||||||
|
|
||||||
|
|
|
@ -688,7 +688,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
x[:, 0] = 0
|
x[:, 0] = 0
|
||||||
|
|
||||||
if nm > 0:
|
if nm > 0:
|
||||||
print('Remove missing annotations file avoiding unlabeled images that would considered as BG')
|
print(100*'/*/')
|
||||||
|
print('Remove missing annotations file avoiding unlabeled images that would considered as BG. Before', len(self.labels))
|
||||||
for ix in range(len(self.labels) - 1, -1, -1): # safe remove by reverrse iteration #enumerate(self.labels):
|
for ix in range(len(self.labels) - 1, -1, -1): # safe remove by reverrse iteration #enumerate(self.labels):
|
||||||
if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
|
if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
|
||||||
del self.labels[ix]
|
del self.labels[ix]
|
||||||
|
@ -696,6 +697,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
del self.label_files[ix]
|
del self.label_files[ix]
|
||||||
del shapes[ix]
|
del shapes[ix]
|
||||||
|
|
||||||
|
print('after', len(self.labels))
|
||||||
|
|
||||||
self.shapes = np.array(shapes, dtype=np.float64)
|
self.shapes = np.array(shapes, dtype=np.float64)
|
||||||
n = len(shapes) # number of images
|
n = len(shapes) # number of images
|
||||||
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
|
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
|
||||||
|
@ -770,6 +773,24 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
print(f'{fname} fname WARNING: Ignoring corrupted image and/or label {file_name}: {e}')
|
print(f'{fname} fname WARNING: Ignoring corrupted image and/or label {file_name}: {e}')
|
||||||
|
|
||||||
|
|
||||||
|
def resample_ohem(self):
|
||||||
|
|
||||||
|
for ix in range(len(self.labels) - 1, -1, -1): # safe remove by reverrse iteration #enumerate(self.labels):
|
||||||
|
if (self.labels[ix][:, 1:] > 1).any() or self.labels[ix].size < 5:
|
||||||
|
del self.labels[ix]
|
||||||
|
del self.img_files[ix]
|
||||||
|
del self.label_files[ix]
|
||||||
|
# del shapes[ix]
|
||||||
|
del self.imgs[ix]
|
||||||
|
self.n = self.n - 1
|
||||||
|
# self.shapes = np.array(shapes, dtype=np.float64)
|
||||||
|
# n = len(shapes) # number of images
|
||||||
|
bi = np.floor(np.arange(self.n) / self.batch).astype(int) # batch index
|
||||||
|
nb = bi[-1] + 1 # number of batches
|
||||||
|
# self.batch = bi # batch index of image
|
||||||
|
|
||||||
|
self.indices = range(self.n)
|
||||||
|
self.mosiac_no = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
109
utils/loss.py
109
utils/loss.py
|
@ -120,13 +120,14 @@ class SigmoidBin(nn.Module):
|
||||||
|
|
||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
|
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
|
||||||
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25, multi_class_no_multi_label=False):
|
||||||
super(FocalLoss, self).__init__()
|
super(FocalLoss, self).__init__()
|
||||||
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
|
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.reduction = loss_fcn.reduction
|
self.reduction = loss_fcn.reduction
|
||||||
self.loss_fcn.reduction = 'none' # required to apply FL to each element
|
self.loss_fcn.reduction = 'none' # required to apply FL to each element
|
||||||
|
self.multi_class_no_multi_label = multi_class_no_multi_label
|
||||||
|
|
||||||
def forward(self, pred, true):
|
def forward(self, pred, true):
|
||||||
loss = self.loss_fcn(pred, true)
|
loss = self.loss_fcn(pred, true)
|
||||||
|
@ -134,17 +135,36 @@ class FocalLoss(nn.Module):
|
||||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||||
|
|
||||||
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
||||||
pred_prob = torch.sigmoid(pred) # prob from logits
|
if not self.multi_class_no_multi_label:
|
||||||
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
|
pred_prob = torch.sigmoid(pred) # prob from logits
|
||||||
|
# For BCE the positive and negative class has different treatment
|
||||||
|
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
|
||||||
|
if isinstance(self.alpha, torch.Tensor): # weights between classes or labels rather than between each label and BG
|
||||||
|
if self.alpha.numel() > 0:
|
||||||
|
alpha_factor = true * self.alpha
|
||||||
|
else:# old scalar case of class vs. all but not supporting multi-class weight
|
||||||
|
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
|
||||||
|
modulating_factor = (1.0 - p_t) ** self.gamma
|
||||||
|
loss *= alpha_factor * modulating_factor
|
||||||
|
|
||||||
if isinstance(self.alpha, torch.Tensor): # weights between classes or labels rather than between each label and BG
|
else:
|
||||||
if self.alpha.numel() > 0:
|
# Get softmax probabilities
|
||||||
alpha_factor = true * self.alpha
|
probs = F.softmax(pred, dim=-1) # (batch_size, num_classes)
|
||||||
else:# old scalar case of class vs. all but not supporting multi-class weight
|
pt = probs.gather(dim=-1, index=true.unsqueeze(-1)).squeeze(-1) # True class probabilities
|
||||||
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
|
|
||||||
|
# Compute focal loss factor
|
||||||
|
focal_weight = (1 - pt) ** self.gamma # (batch_size,)
|
||||||
|
|
||||||
|
# Apply alpha weighting
|
||||||
|
if isinstance(self.alpha, torch.Tensor):
|
||||||
|
if self.alpha.numel()>1:# Per-class weighting
|
||||||
|
alpha_t = self.alpha[true]
|
||||||
|
else:
|
||||||
|
alpha_t = self.alpha
|
||||||
|
|
||||||
|
# Compute focal loss
|
||||||
|
loss = alpha_t * focal_weight * loss
|
||||||
|
|
||||||
modulating_factor = (1.0 - p_t) ** self.gamma
|
|
||||||
loss *= alpha_factor * modulating_factor
|
|
||||||
|
|
||||||
if self.reduction == 'mean':
|
if self.reduction == 'mean':
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
@ -164,6 +184,39 @@ alpha_factor = torch.tensor([alpha_list[i] for i in labels]).to(device)
|
||||||
alpha_factor = alpha_factor * labels + (1 - alpha_factor) * (1 - labels)
|
alpha_factor = alpha_factor * labels + (1 - alpha_factor) * (1 - labels)
|
||||||
loss *= alpha_factor
|
loss *= alpha_factor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
|
||||||
|
class AsymmetricLoss(nn.Module):
|
||||||
|
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05):
|
||||||
|
super(AsymmetricLoss, self).__init__()
|
||||||
|
self.gamma_neg = gamma_neg
|
||||||
|
self.gamma_pos = gamma_pos
|
||||||
|
self.clip = clip
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""
|
||||||
|
Asymmetric loss for handling label imbalance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Model predictions
|
||||||
|
y (torch.Tensor): Binary label matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Asymmetric loss
|
||||||
|
"""
|
||||||
|
# Positive and negative losses
|
||||||
|
xs_pos = torch.sigmoid(x)
|
||||||
|
xs_neg = 1 - xs_pos
|
||||||
|
|
||||||
|
# Asymmetric clipping
|
||||||
|
y = torch.clamp(y, min=self.clip, max=1 - self.clip)
|
||||||
|
|
||||||
|
# Focal-like modulation
|
||||||
|
los_pos = y * torch.log(xs_pos) * (1 - xs_pos) ** self.gamma_pos
|
||||||
|
los_neg = (1 - y) * torch.log(xs_neg) * xs_neg ** self.gamma_neg
|
||||||
|
|
||||||
|
return -torch.mean(los_pos + los_neg)
|
||||||
|
|
||||||
class QFocalLoss(nn.Module):
|
class QFocalLoss(nn.Module):
|
||||||
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
|
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
|
||||||
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
||||||
|
@ -437,13 +490,19 @@ class APLoss(torch.autograd.Function):
|
||||||
# Dual obj and cls losses and outputs inherited from Joseph Redmon's original YOLOv3
|
# Dual obj and cls losses and outputs inherited from Joseph Redmon's original YOLOv3
|
||||||
class ComputeLoss:
|
class ComputeLoss:
|
||||||
# Compute losses
|
# Compute losses
|
||||||
def __init__(self, model, autobalance=False, loss_weight=torch.tensor([])):
|
def __init__(self, device, model, autobalance=False, loss_weight=torch.tensor([]),
|
||||||
|
multi_class_no_multi_label=False, multi_label_asymetric_focal_loss=False):
|
||||||
super(ComputeLoss, self).__init__()
|
super(ComputeLoss, self).__init__()
|
||||||
device = next(model.parameters()).device # get model device
|
# device = next(model.parameters()).device # get model device
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
|
self.multi_class_no_multi_label = multi_class_no_multi_label
|
||||||
|
self.multi_label_asymetric_focal_loss = multi_label_asymetric_focal_loss
|
||||||
# Define criteria
|
# Define criteria
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
|
if self.multi_class_no_multi_label:
|
||||||
|
xCEcls = nn.CrossEntropyLoss() #weight=loss_weight
|
||||||
|
else:
|
||||||
|
xCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
|
||||||
|
|
||||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
|
||||||
self.loss_weight = loss_weight
|
self.loss_weight = loss_weight
|
||||||
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||||
|
@ -455,14 +514,25 @@ class ComputeLoss:
|
||||||
alpha = 0.25 # default by base code
|
alpha = 0.25 # default by base code
|
||||||
if loss_weight.numel()>0:
|
if loss_weight.numel()>0:
|
||||||
alpha = loss_weight # Overide the default from the paper
|
alpha = loss_weight # Overide the default from the paper
|
||||||
BCEcls, BCEobj = FocalLoss(BCEcls, g, alpha=alpha), FocalLoss(BCEobj, g)
|
|
||||||
|
if multi_label_asymetric_focal_loss:
|
||||||
|
xCEcls, BCEobj = (AsymmetricLoss(xCEcls, g,
|
||||||
|
alpha=alpha.to(device),
|
||||||
|
multi_class_no_multi_label=self.multi_class_no_multi_label,
|
||||||
|
multi_label_asymetric_focal_loss=self.multi_label_asymetric_focal_loss),
|
||||||
|
FocalLoss(BCEobj, g))
|
||||||
|
else:
|
||||||
|
xCEcls, BCEobj = (FocalLoss(xCEcls, g,
|
||||||
|
alpha=alpha,
|
||||||
|
multi_class_no_multi_label=self.multi_class_no_multi_label),
|
||||||
|
FocalLoss(BCEobj, g))
|
||||||
|
|
||||||
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
||||||
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
|
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
|
||||||
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.1, .05]) # P3-P7
|
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.1, .05]) # P3-P7
|
||||||
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.5, 0.4, .1]) # P3-P7
|
#self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.5, 0.4, .1]) # P3-P7
|
||||||
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
|
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
|
||||||
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
|
self.xCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = xCEcls, BCEobj, model.gr, h, autobalance
|
||||||
for k in 'na', 'nc', 'nl', 'anchors':
|
for k in 'na', 'nc', 'nl', 'anchors':
|
||||||
setattr(self, k, getattr(det, k))
|
setattr(self, k, getattr(det, k))
|
||||||
|
|
||||||
|
@ -495,7 +565,10 @@ class ComputeLoss:
|
||||||
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
|
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
|
||||||
t[range(n), tcls[i]] = self.cp
|
t[range(n), tcls[i]] = self.cp
|
||||||
#t[t==self.cp] = iou.detach().clamp(0).type(t.dtype)
|
#t[t==self.cp] = iou.detach().clamp(0).type(t.dtype)
|
||||||
lcls += self.BCEcls(ps[:, 5:], t) # BCE
|
if self.multi_class_no_multi_label:
|
||||||
|
lcls += self.xCEcls(ps[:, 5:], tcls[i].long().to(device)) # CE
|
||||||
|
else:
|
||||||
|
lcls += self.xCEcls(ps[:, 5:], t) # BCE
|
||||||
|
|
||||||
# Append targets to text file
|
# Append targets to text file
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
|
@ -1715,8 +1788,10 @@ class ComputeLossAuxOTA:
|
||||||
anch.append(anchors[a]) # anchors
|
anch.append(anchors[a]) # anchors
|
||||||
|
|
||||||
return indices, anch
|
return indices, anch
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue