mirror of https://github.com/WongKinYiu/yolov7.git
Add option to use YOLOv5 AP metric (#775)
* Add YOLOv5 metric option * Inform if using v5 metricpull/464/head^2
parent
b1850c7dca
commit
55b90e1119
14
test.py
14
test.py
|
@ -39,7 +39,8 @@ def test(data,
|
|||
compute_loss=None,
|
||||
half_precision=True,
|
||||
trace=False,
|
||||
is_coco=False):
|
||||
is_coco=False,
|
||||
v5_metric=False):
|
||||
# Initialize/load model and set device
|
||||
training = model is not None
|
||||
if training: # called by train.py
|
||||
|
@ -89,6 +90,9 @@ def test(data,
|
|||
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
|
||||
prefix=colorstr(f'{task}: '))[0]
|
||||
|
||||
if v5_metric:
|
||||
print("Testing with YOLOv5 AP metric...")
|
||||
|
||||
seen = 0
|
||||
confusion_matrix = ConfusionMatrix(nc=nc)
|
||||
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
|
||||
|
@ -217,7 +221,7 @@ def test(data,
|
|||
# Compute statistics
|
||||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names)
|
||||
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
||||
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
||||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
||||
|
@ -304,6 +308,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--name', default='exp', help='save to project/name')
|
||||
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
||||
parser.add_argument('--no-trace', action='store_true', help='don`t trace model')
|
||||
parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
|
||||
opt = parser.parse_args()
|
||||
opt.save_json |= opt.data.endswith('coco.yaml')
|
||||
opt.data = check_file(opt.data) # check file
|
||||
|
@ -325,11 +330,12 @@ if __name__ == '__main__':
|
|||
save_hybrid=opt.save_hybrid,
|
||||
save_conf=opt.save_conf,
|
||||
trace=not opt.no_trace,
|
||||
v5_metric=opt.v5_metric
|
||||
)
|
||||
|
||||
elif opt.task == 'speed': # speed benchmarks
|
||||
for w in opt.weights:
|
||||
test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False)
|
||||
test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, v5_metric=opt.v5_metric)
|
||||
|
||||
elif opt.task == 'study': # run over a range of settings and save/plot
|
||||
# python test.py --task study --data coco.yaml --iou 0.65 --weights yolov7.pt
|
||||
|
@ -340,7 +346,7 @@ if __name__ == '__main__':
|
|||
for i in x: # img-size
|
||||
print(f'\nRunning {f} point {i}...')
|
||||
r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
|
||||
plots=False)
|
||||
plots=False, v5_metric=opt.v5_metric)
|
||||
y.append(r + t) # results and times
|
||||
np.savetxt(f, y, fmt='%10.4g') # save
|
||||
os.system('zip -r study.zip study_*.txt')
|
||||
|
|
7
train.py
7
train.py
|
@ -423,7 +423,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
plots=plots and final_epoch,
|
||||
wandb_logger=wandb_logger,
|
||||
compute_loss=compute_loss,
|
||||
is_coco=is_coco)
|
||||
is_coco=is_coco,
|
||||
v5_metric=opt.v5_metric)
|
||||
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
|
@ -502,7 +503,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
save_dir=save_dir,
|
||||
save_json=True,
|
||||
plots=False,
|
||||
is_coco=is_coco)
|
||||
is_coco=is_coco,
|
||||
v5_metric=opt.v5_metric)
|
||||
|
||||
# Strip optimizers
|
||||
final = best if best.exists() else last # final model
|
||||
|
@ -559,6 +561,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
|
||||
parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
|
||||
opt = parser.parse_args()
|
||||
|
||||
# Set DDP variables
|
||||
|
|
|
@ -420,7 +420,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
plots=plots and final_epoch,
|
||||
wandb_logger=wandb_logger,
|
||||
compute_loss=compute_loss,
|
||||
is_coco=is_coco)
|
||||
is_coco=is_coco,
|
||||
v5_metric=opt.v5_metric)
|
||||
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
|
@ -499,7 +500,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
save_dir=save_dir,
|
||||
save_json=True,
|
||||
plots=False,
|
||||
is_coco=is_coco)
|
||||
is_coco=is_coco,
|
||||
v5_metric=opt.v5_metric)
|
||||
|
||||
# Strip optimizers
|
||||
final = best if best.exists() else last # final model
|
||||
|
@ -555,6 +557,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
||||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||
parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
|
||||
opt = parser.parse_args()
|
||||
|
||||
# Set DDP variables
|
||||
|
|
|
@ -15,7 +15,7 @@ def fitness(x):
|
|||
return (x[:, :4] * w).sum(1)
|
||||
|
||||
|
||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
|
||||
def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=()):
|
||||
""" Compute the average precision, given the recall and precision curves.
|
||||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
||||
# Arguments
|
||||
|
@ -62,7 +62,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
|
|||
|
||||
# AP from recall-precision curve
|
||||
for j in range(tp.shape[1]):
|
||||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
||||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j], v5_metric=v5_metric)
|
||||
if plot and j == 0:
|
||||
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
|
||||
|
||||
|
@ -78,17 +78,21 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
|
|||
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
|
||||
|
||||
|
||||
def compute_ap(recall, precision):
|
||||
def compute_ap(recall, precision, v5_metric=False):
|
||||
""" Compute the average precision, given the recall and precision curves
|
||||
# Arguments
|
||||
recall: The recall curve (list)
|
||||
precision: The precision curve (list)
|
||||
v5_metric: Assume maximum recall to be 1.0, as in YOLOv5, MMDetetion etc.
|
||||
# Returns
|
||||
Average precision, precision curve, recall curve
|
||||
"""
|
||||
|
||||
# Append sentinel values to beginning and end
|
||||
mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
|
||||
if v5_metric: # New YOLOv5 metric, same as MMDetection and Detectron2 repositories
|
||||
mrec = np.concatenate(([0.], recall, [1.0]))
|
||||
else: # Old YOLOv5 metric, i.e. default YOLOv7 metric
|
||||
mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
|
||||
mpre = np.concatenate(([1.], precision, [0.]))
|
||||
|
||||
# Compute the precision envelope
|
||||
|
|
Loading…
Reference in New Issue