mirror of https://github.com/WongKinYiu/yolov7.git
A true prediction is assigned to the same range cell as its matched GT. Fix bug where GT exist but there are no predictions. Weather statistics; ONNX: save the pre-processed file.
parent 4c0190e554
commit d00653795f
@@ -21,7 +21,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
from utils.metrics import ap_per_class
from utils.general import box_iou
from utils.plots import plot_one_box

import pickle
import os
from utils.general import xywh2xyxy
from collections import defaultdict
@@ -591,7 +591,9 @@ def main(opt):
tta_res = dict()
tta_res['predictions'] = predictions
tta_res['ground_truths'] = ground_truths
tta_res['iou_threshold'] = iou_threshold
tta_res['iou_threshold'] = opt.iou_thres
tta_res['conf_thres'] = opt.conf_thres
np.savetxt(os.path.join(opt.save_path, "image_preprocessed.csv"), img[0,:,:,0], delimiter=",")
with open(os.path.join(opt.save_path, 'metadata_for_pre_re_detection_threshold_' + str(det_threshold) + '.pkl'), 'wb') as f:
    pickle.dump(tta_res, f)
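These artifacts can be read back for offline analysis; a minimal sketch, assuming the same --save-path and detection threshold used in the run above:

import os
import pickle
import numpy as np

save_path = '/mnt/Data/hanoch/runs/tir_old_1.5'   # assumption: same --save-path as the run
det_threshold = 0.05                              # assumption: same threshold value used when dumping

# pre-processed image channel written with np.savetxt
img_pre = np.loadtxt(os.path.join(save_path, 'image_preprocessed.csv'), delimiter=',')

# predictions/GT metadata pickled per detection threshold
with open(os.path.join(save_path, f'metadata_for_pre_re_detection_threshold_{det_threshold}.pkl'), 'rb') as f:
    tta_res = pickle.load(f)

print(img_pre.shape, tta_res.keys())  # e.g. dict_keys(['predictions', 'ground_truths', 'iou_threshold', 'conf_thres'])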
@@ -692,8 +694,12 @@ if __name__ == '__main__':
Plotting P/R curve over detections, th=0.05
--cache-images --device 0 --weights /mnt/Data/hanoch/tir_old_tf/tir_od_1.5.onnx --img-size 512 --conf-thres 0.05 --iou-thres 0.5 --norm-type no_norm --save-path /mnt/Data/hanoch/runs/tir_old_1.5 --images-parent-folder /home/hanoch/projects/tir_frames_rois/marmon_noisy_sy --detection-no-gt

Detections only, new model yolov7999
Detection
--cache-images --device 0 --weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.onnx --img-size 640 --conf-thres 0.66 --iou-thres 0.6 --norm-type single_image_percentile_0_1 --images-parent-folder /home/hanoch/projects/tir_frames_rois/marmon_noisy_sy --save-path /mnt/Data/hanoch/runs/yolov7999_onnx_run --detection-no-gt --adding-ext-noise

Detections only, new model yolov7999, with noise addition
--cache-images --device 0 --weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.onnx --img-size 640 --conf-thres 0.48 --iou-thres 0.6 --norm-type single_image_percentile_0_1 --images-parent-folder /mnt/Data/hanoch/tir_frames_rois/onnx_bm --save-path /mnt/Data/hanoch/runs/yolov7999_onnx_run --detection-no-gt
P/R curve
--cache-images --device 0 --weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.onnx --img-size 640 --conf-thres 0.01 --iou-thres 0.6 --norm-type single_image_percentile_0_1 --test-files-path /home/hanoch/projects/tir_od/yolov7/tir_od/test_set/Test51a_Test40A_test_set.txt --save-path /mnt/Data/hanoch/runs/yolov7999_onnx_run/P_R_curve_test_set --adding-ext-noise
"""
test.py
@@ -347,21 +347,47 @@ def test(data,

# sensor type
if dataloader.dataset.use_csv_meta_data_file:
    weather_condition = (dataloader.dataset.df_metadata[dataloader.dataset.df_metadata['tir_frame_image_file_name'] == str(path).split('/')[-1]]['weather_condition'].item())
    if isinstance(weather_condition, str):
        weather_condition = weather_condition.lower()
    exec([x for x in weather_condition_vars if str(weather_condition) in x][0] + '.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))')
    try:
        weather_condition = (dataloader.dataset.df_metadata[dataloader.dataset.df_metadata['tir_frame_image_file_name'] == str(path).split('/')[-1]]['weather_condition'].item())
        if isinstance(weather_condition, str):
            weather_condition = weather_condition.lower()
        exec([x for x in weather_condition_vars if str(weather_condition) in x][0] + '.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))')
    except Exception as e:
        print(f'{weather_condition} fname WARNING: Ignoring corrupted image and/or label {weather_condition}: {e}')
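The exec() dispatch above appends each image's stats to a list whose variable name contains the weather string; it works because it only mutates an existing list rather than rebinding a name. A minimal dict-based sketch of the same per-condition accumulation, shown only as an illustrative equivalent (the accumulator names and sample values are assumptions, not the repo's):

from collections import defaultdict

# hypothetical accumulators keyed by weather condition instead of exec() on per-condition variable names
stats_per_weather = defaultdict(list)

def accumulate(weather_condition, correct, conf, pred_cls, tcls):
    # normalize the metadata string the same way the patch does
    if isinstance(weather_condition, str):
        weather_condition = weather_condition.lower()
    stats_per_weather[weather_condition].append((correct, conf, pred_cls, tcls))

accumulate('Clear', correct=[True], conf=[0.9], pred_cls=[0], tcls=[0])
print(list(stats_per_weather.keys()))  # ['clear']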
time_in_day = dataloader.dataset.df_metadata[dataloader.dataset.df_metadata['tir_frame_image_file_name'] == str(path).split('/')[-1]]['part_in_day'].item().lower()
# eval([x for x in time_vars if str(time_in_day) in x][0]).append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
exec([x for x in time_vars if str(time_in_day) in x][0] + '.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))')

sensor_type = dataloader.dataset.df_metadata[dataloader.dataset.df_metadata['tir_frame_image_file_name'] == str(path).split('/')[-1]]['sensor_type'].item()
obj_range_m = torch.tensor(
    [(object_size_to_range(obj_height_pixels=h.cpu(), focal=sensor_type, class_id=class_id.cpu().numpy().item()))
     for class_id, (x, y, w, h) in zip(pred[:, 5], xyxy2xywh(pred[:, :4]))])
# obj_range_m = torch.tensor(
#     [(object_size_to_range(obj_height_pixels=h.cpu(), focal=sensor_type, class_id=class_id.cpu().numpy().item()))
#      for class_id, (x, y, w, h) in zip(pred[:, 5], xyxy2xywh(pred[:, :4]))])
gt_range = [(object_size_to_range(obj_height_pixels=h, focal=sensor_type, class_id=class_id.numpy().item())) for
            class_id, (x, y, w, h) in zip(labels[:, 0].cpu(), labels[:, 1:5].cpu())]
# coupling the range cell between any overlapped IoU > TH between pred bbox and GT bbox
obj_range_m = list()
i = 0
for class_id_pred, (x1_p, y1_p, x2_p, y2_p) in zip(pred[:, 5].cpu(), pred[:, :4].cpu()):
    range_candidate = torch.tensor(
        object_size_to_range(obj_height_pixels=xyxy2xywh(pred[:, :4])[0][-1].cpu(),
                             focal=sensor_type,
                             class_id=class_id_pred.cpu().numpy().item()))
    # Find any IoU overlap between GT and prediction
    for class_id_gt, (x1_gt, y1_gt, x2_gt, y2_gt), (xc, yc, w, h) in zip(labels[:, 0].cpu(),
                                                                         xywh2xyxy(labels[:, 1:5].cpu()), labels[:, 1:5].cpu()):
        ious = box_iou(torch.tensor((x1_gt, y1_gt, x2_gt, y2_gt)).unsqueeze(axis=0), torch.tensor((x1_p, y1_p, x2_p, y2_p)).unsqueeze(axis=0))
        i += 1
        if ious > iouv.cpu()[0]:
            range_candidate = torch.tensor(
                object_size_to_range(obj_height_pixels=h.cpu(),
                                     focal=sensor_type,
                                     class_id=class_id_pred.cpu().numpy().item()))
            break  # the aligned GT/pred pair was found; no need to iterate further, this is the best candidate
    obj_range_m.append(range_candidate)
# else:  # ranges = func(sqrt(height*width))
#     obj_range_m = torch.tensor([(object_size_to_range(obj_height_pixels=(np.sqrt(h.cpu()*w.cpu())), focal=sensor_type)) for ix, (x, y, w, h) in enumerate(xyxy2xywh(pred[:, :4]))])
#     gt_range = [(object_size_to_range(obj_height_pixels=(np.sqrt(h*w)), focal=sensor_type)) for ix, (x, y, w, h) in enumerate(labels[:,1:5].cpu())]
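The loop above gives each prediction a default range from its own box height and, when it overlaps a GT box with IoU above the first threshold in iouv, recomputes the range from the GT height so that prediction and GT land in the same range cell. A minimal standalone sketch of that coupling, with a hypothetical pixels-to-range helper standing in for the repo's object_size_to_range (focal length, pixel pitch and object height below are assumptions):

import torch

def height_to_range_m(h_pixels, focal_mm=35.0, obj_height_m=1.8):
    # hypothetical pinhole-style stand-in for object_size_to_range(); pixel pitch is assumed
    pixel_pitch_mm = 0.017
    return (focal_mm * obj_height_m) / (h_pixels * pixel_pitch_mm)

def box_iou(a, b):
    # IoU of two [x1, y1, x2, y2] boxes
    x1, y1 = torch.max(a[0], b[0]), torch.max(a[1], b[1])
    x2, y2 = torch.min(a[2], b[2]), torch.min(a[3], b[3])
    inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)

def couple_ranges(pred_xyxy, gt_xyxy, iou_th=0.5):
    ranges = []
    for p in pred_xyxy:
        rng = height_to_range_m(p[3] - p[1])          # default: range from the prediction's own height
        for g in gt_xyxy:
            if box_iou(p, g) > iou_th:                # matched GT found
                rng = height_to_range_m(g[3] - g[1])  # take the range from the GT height instead
                break
        ranges.append(rng)
    return torch.stack(ranges)

preds = torch.tensor([[100., 100., 140., 220.], [300., 50., 320., 90.]])
gts = torch.tensor([[102., 98., 141., 224.]])
print(couple_ranges(preds, gts))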
@@ -490,6 +516,8 @@ def test(data,
nt_stat_list_per_range = np.array([0, 0])
r_stat_list_per_range = np.array([0, 0])
p_stat_list_per_range = np.array([0, 0])
map50_per_range = np.array(0)

# ind = np.array([])
# exec('ind = np.where(ranges == rng_100)[0]')
ind = eval('np.where(ranges == rng_100)[0]')
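The switch from exec() to eval() matters here: inside a function, exec('ind = ...') assigns into a temporary locals mapping and does not rebind the local variable ind, while eval() simply returns the computed value. A small illustrative sketch:

import numpy as np

def with_exec(ranges, rng_100):
    ind = np.array([])
    exec('ind = np.where(ranges == rng_100)[0]')  # does not rebind the local 'ind' inside a function
    return ind

def with_eval(ranges, rng_100):
    ind = eval('np.where(ranges == rng_100)[0]')  # returns the computed indices
    return ind

ranges = np.array([100, 200, 100, 300])
print(with_exec(ranges, 100))  # [] -- unchanged
print(with_eval(ranges, 100))  # [0 2]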
@@ -515,14 +543,17 @@ def test(data,
ap50_per_range, ap_per_range = ap_stat_list_per_range[:, 0], ap_stat_list_per_range.mean(1)  # AP@0.5, AP@0.5:0.95
mp_per_range, mr_per_range, map50_per_range, map_per_range = p_stat_list_per_range.mean(), r_stat_list_per_range.mean(), ap50_per_range.mean(), ap_per_range.mean()
else:  # no prediction at this range
if not bool(gt_per_range_bins[sensor_focal][rng_100]):  # there are GT but no pred
nt_stat_list_per_range = np.array([0,0])
r_stat_list_per_range = np.array([0,0])
p_stat_list_per_range = np.array([0,0])
fn = len(gt_per_range_bins[sensor_focal][rng_100])
recall = 0  # no TP 0/TP+FN
precision = 0
map50_per_range = np.array(0)
r_stat_list_per_range = np.array([0, 0])
p_stat_list_per_range = np.array([0, 0])
fn = len(gt_per_range_bins[sensor_focal][rng_100])
recall = 0  # no TP 0/TP+FN
precision = 0
map50_per_range = np.array(0)
if not bool(gt_per_range_bins[sensor_focal][rng_100]):  # there are no GT no pred
nt_stat_list_per_range = np.array([0,0])  # actual GT
else:
nt_stat_list_per_range = np.array(gt_per_range_bins[sensor_focal][rng_100]).sum()
# there are GT but no pred
# print(map50_per_range)

range_bins_map[sensor_focal][rng_100] = map50_per_range.item()
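This is the "GT exist but no predictions" fix from the commit message: with zero TP in a range bin, recall and precision both become 0 and FN equals the GT count in that bin. A small per-bin bookkeeping sketch with hypothetical inputs (not the repo's actual data structures):

import numpy as np

def per_bin_stats(pred_matches, n_gt):
    # pred_matches: list of booleans (True = TP) for predictions in this range bin; n_gt: GT count
    tp = int(np.sum(pred_matches))
    fp = len(pred_matches) - tp
    fn = n_gt - tp
    if len(pred_matches) == 0:      # no predictions at this range
        recall = 0.0                # 0 / (0 + FN)
        precision = 0.0
        return precision, recall, fn
    precision = tp / (tp + fp)
    recall = tp / (tp + fn) if n_gt else 0.0
    return precision, recall, fn

print(per_bin_stats([True, False, True], n_gt=4))  # bin with detections
print(per_bin_stats([], n_gt=3))                   # GT exist but no predictions -> P=R=0, FN=3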
@@ -827,8 +858,18 @@ tir_tiff_w_center_roi_validation_set_train_cls_usa.txt

# per day/night SY/ML
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --device 0 --batch-size 16 --data data/tir_od_test_set.yaml --img-size 640 --conf 0.001 --verbose --norm-type single_image_percentile_0_1 --input-channels 1 --project test --task test --iou-thres 0.6 --csv-metadata-path tir_od/tir_center_merged_seq_tiff_last_original_png.xlsx
P/R
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --device 0 --batch-size 16 --data data/tir_od_test_set.yaml --img-size 640 --verbose --norm-type single_image_percentile_0_1 --input-channels 1 --project test --task test --iou-thres 0.6 --csv-metadata-path tir_od/tir_center_merged_seq_tiff_last_original_png.xlsx --conf 0.65
mAP:
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --device 0 --batch-size 16 --data data/tir_od_test_set.yaml --img-size 640 --verbose --norm-type single_image_percentile_0_1 --input-channels 1 --project test --task test --iou-thres 0.6 --csv-metadata-path tir_od/tir_center_merged_seq_tiff_last_original_png.xlsx --conf 0.01

Fixed weather csv P/R
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --device 0 --batch-size 16 --data data/tir_od_test_set.yaml --img-size 640 --verbose --norm-type single_image_percentile_0_1 --input-channels 1 --project test --task test --csv-metadata-path tir_od/tir_tiff_seq_png_3_class_fixed_whether.xlsx --iou-thres 0.6 --conf 0.65

FOG
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --device 0 --batch-size 16 --data data/tir_od_fog_set.yaml --img-size 640 --verbose --norm-type single_image_percentile_0_1 --input-channels 1 --project test --task test --csv-metadata-path tir_od/tir_tiff_seq_png_3_class_fixed_whether.xlsx --conf 0.01 --iou-thres 0.6
------- Error analysis ------------
First run with conf_th=0.0001 to pick the desired threshold, then re-run with that threshold and inspect the images with bboxes drawn at it.
"""
@@ -677,6 +677,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
if self.use_csv_meta_data_file:
    df = load_csv_xls_2_df(self.csv_meta_data_file)
    self.df_metadata = pd.DataFrame(columns=['sensor_type', 'part_in_day', 'weather_condition', 'country', 'train_state', 'tir_frame_image_file_name'])
    # TODO :HK @@ iterate tqdm(zip(self.img_files, self.label_files)) and upon --force-csv-list remove missing entries in the csv from train/test lists!!!
    for ix, fname in enumerate(self.img_files):
        file_name = fname.split('/')[-1]
        if not (df['tir_frame_image_file_name'] == file_name).any():
@@ -720,13 +721,18 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
l = np.array(l, dtype=np.float32)
# if (l[:, 0].max() >= num_cls):
#     print('ka', i, l, lb_file, im_file)
l = np.array([lbl for lbl in l if lbl[0] < num_cls])  # take only labels index upto num of classes and omit others

if len(l):
    assert l.shape[1] == 5, 'labels require 5 columns each'  # @@HK TODO adding truncation ratio increase here: assert l.shape[1] == 6,
    assert (l >= 0).all(), 'negative labels'
    assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
    assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
    l = np.array([lbl for lbl in l if lbl[0] < num_cls])  # take only labels index upto num of classes and omit others
    # assert if (l[:, 0].max() < num_cls), 'class label out of range -- invalid' # max label can't be greater than num of labels
    assert (l[:, 0].max() < num_cls), 'class label out of range -- invalid'  # max label can't be greater than num of labels
    # print(l[:, 0])

else:
    ne += 1  # label empty
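The loader now drops label rows whose class index is at or above num_cls before the assertions run, so an out-of-range class no longer aborts loading. A tiny sketch of the filter and checks, assuming the usual YOLO label layout (class, x, y, w, h, normalized):

import numpy as np

num_cls = 3
# hypothetical label rows: class, x_center, y_center, width, height (normalized)
l = np.array([[0, 0.5, 0.5, 0.2, 0.3],
              [5, 0.4, 0.4, 0.1, 0.1],   # class 5 >= num_cls -> dropped
              [2, 0.7, 0.6, 0.3, 0.2]], dtype=np.float32)

l = np.array([lbl for lbl in l if lbl[0] < num_cls])  # keep only classes below num_cls

assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert (l[:, 0].max() < num_cls), 'class label out of range -- invalid'
print(l.shape)  # (2, 5)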
@@ -200,8 +200,8 @@ def range_p_r_bar_plot(n_bins:int, range_bins_precision_all_classes:dict, range_
x = 100 * np.arange(n_bins) + 100
for (k_p, v_p), (k_r, v_r) in zip(range_bins_precision_all_classes.items(), range_bins_recall_all_classes.items()):
    plt.figure()
    bar1 = plt.bar(x, [x[class_graph] for x in v_p[:n_bins]], bar_width, color='b', label='p')
    bar2 = plt.bar(x+10, [x[class_graph] for x in v_r[:n_bins]], bar_width, color='g', label='r')
    bar1 = plt.bar(x-bar_width, [x[class_graph] for x in v_p[:n_bins]], bar_width, color='b', label='p')
    bar2 = plt.bar(x-bar_width//2, [x[class_graph] for x in v_r[:n_bins]], bar_width, color='g', label='r')

    plt.xticks(x, (x/100).astype('int'))
@@ -210,13 +210,13 @@ def range_p_r_bar_plot(n_bins:int, range_bins_precision_all_classes:dict, range_
    # height = rect.get_height()
    # plt.text(rect.get_x() + rect.get_width() / 2.0, height, 'p/r', ha='center', va='bottom')

    # plt.legend()
    plt.legend()
    # plt.tight_layout()
    # plt.ylim([0.0, 1.05])
    plt.ylabel('P/R')
    plt.xlabel('Range[x100m]')
    plt.grid()
    plt.title('Sensor {}mm Precision/Recall vs. range class: {} conf: {}'.format(k_p, str(names[class_graph]), conf))
    plt.title('Sensor {}mm Precision/Recall vs. range. Class: {} conf: {}'.format(k_p, str(names[class_graph]), conf))
    # plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    plt.savefig(os.path.join(save_dir, 'p_r_distribution_distance_sensor_' + str(k_p) + '_class_' + str(names[class_graph]) + '.png'), dpi=250)
    plt.clf()
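The bar-offset change in this hunk moves the precision and recall bars next to each range tick instead of stacking them at the same x positions. A minimal, self-contained sketch of such a grouped P/R bar plot; the values, bar width and exact offsets below are illustrative assumptions, not the patch's:

import numpy as np
import matplotlib
matplotlib.use('Agg')  # write to file, no display needed
import matplotlib.pyplot as plt

n_bins = 5
bar_width = 20
x = 100 * np.arange(n_bins) + 100                   # range bin centers in meters
precision = np.array([0.9, 0.85, 0.8, 0.7, 0.5])    # hypothetical per-bin precision
recall = np.array([0.95, 0.9, 0.75, 0.6, 0.4])      # hypothetical per-bin recall

plt.figure()
plt.bar(x - bar_width / 2, precision, bar_width, color='b', label='p')  # precision bars left of the tick
plt.bar(x + bar_width / 2, recall, bar_width, color='g', label='r')     # recall bars right of the tick
plt.xticks(x, (x / 100).astype('int'))
plt.legend()
plt.ylabel('P/R')
plt.xlabel('Range[x100m]')
plt.grid()
plt.savefig('p_r_distribution_example.png', dpi=250)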