From 221429a81b980495014f3add47c4f01e5c0236d6 Mon Sep 17 00:00:00 2001 From: hanoch Date: Mon, 6 Jan 2025 17:26:49 +0200 Subject: [PATCH] datasets.py: 768x1024, random padding, roi based scaling . still open issue the prediction over the random padding area should be omitted from calculation --- YOLOv7onnx.py | 24 ++-- ...before_mosaic_rnd_scaling_e6_full_res.yaml | 37 ++++++ ...c_rnd_scaling_e6_full_res_OVERFITTING.yaml | 37 ++++++ ...tir_od.tiny_aug_gamma_rnd_perspective.yaml | 4 +- ...mma_scaling_before_mosaic_rnd_scaling.yaml | 2 +- data/tir_od_center_roi_aug_list.yaml | 2 +- detect.py | 9 +- test.py | 2 +- train.py | 23 +++- utils/datasets.py | 120 +++++++++++++++--- utils/metrics.py | 8 +- utils/plots.py | 5 +- 12 files changed, 227 insertions(+), 46 deletions(-) create mode 100644 data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml create mode 100644 data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res_OVERFITTING.yaml diff --git a/YOLOv7onnx.py b/YOLOv7onnx.py index 07e20ef..515a06b 100644 --- a/YOLOv7onnx.py +++ b/YOLOv7onnx.py @@ -5,6 +5,7 @@ from tqdm import tqdm from torchvision.ops import batched_nms import matplotlib.pyplot as plt import cv2 +import pandas as pd #%% import random import numpy as np @@ -12,12 +13,13 @@ import onnxruntime as ort from PIL import Image import argparse -from utils.datasets import create_dataloader +from utils.datasets import create_dataloader, create_folder from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ check_requirements, print_mutation, set_logging, one_cycle, colorstr from utils.metrics import ap_per_class from utils.general import box_iou +import os from utils.general import xywh2xyxy from collections import defaultdict def compute_iou(box1, box2): @@ -166,6 +168,10 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleu return im, r, (dw, dh) def main(opt): + + create_folder(opt.save_path) + + pred_tgt_acm = list() p_r_iou = 0.5 niou = 1 # Model @@ -318,10 +324,10 @@ def main(opt): # Like test.py labels = torch.tensor([[y.item()] + x for x, y in zip(img_gt_boxes_xyxy, img_gt_lbls)]) nl = len(labels) - if nl == ml_class_id.shape[0]: - print('prob all TP') - print('path', paths[seen-1]) - predn[pi, :4] + # if nl == ml_class_id.shape[0]: + # print('prob all TP') + # print('path', paths[seen-1]) + # predn[pi, :4] tcls = labels[:, 0].tolist() if nl else [] # target class pred = torch.tensor([np.append(np.append(bboxes_, scores_above_th_.item()), ml_class_id_.item()) for bboxes_, ml_class_id_, scores_above_th_ in zip(bboxes, ml_class_id, scores_above_th_val)]) @@ -354,7 +360,7 @@ def main(opt): break stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # correct @ IOU=0.5 of pred box with target - + pred_tgt_acm.append({'correct': correct.cpu().numpy(), 'conf': pred[:, 4].cpu().numpy(), 'pred_cls': pred[:, 5].cpu().numpy(), 'tcls': tcls} ) else: @@ -437,9 +443,12 @@ def main(opt): # print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) # Predicted: [(bbox, class_id, confidence)] if is_new_model: + df = pd.DataFrame(pred_tgt_acm) + df.to_csv(os.path.join(opt.save_path, 'onnx_model_pred_tgt_acm_conf_th_' + str(det_threshold.__format__('.3f')) + '.csv'), index=False) + stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy if len(stats) and stats[0].any(): # P, R @ # max F1 index - p, r, ap, f1, ap_class = ap_per_class(*stats, plot=True, v5_metric=False, save_dir=save_path, + p, r, ap, f1, ap_class = ap_per_class(*stats, plot=True, v5_metric=False, save_dir=opt.save_path, names=names) # based on correct @ IOU=0.5 of pred box with target for i, c in enumerate(ap_class): print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) @@ -474,7 +483,6 @@ if __name__ == '__main__': opt = parser.parse_args() - main(opt=opt) """ diff --git a/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml b/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml new file mode 100644 index 0000000..f10aa13 --- /dev/null +++ b/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml @@ -0,0 +1,37 @@ +lr0: 0.001 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.0005 # optimizer weight decay 5e-4 It resolve mAP of overfitting test +warmup_epochs: 3.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr +loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training +box: 0.05 # box loss gain +cls: 0.5 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 1.0 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.60 # like the default in the code was 0.2 IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3 +fl_gamma: 1.5 #1.5 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.0 # image HSV-Hue augmentation (fraction) +hsv_s: 0.0 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.0 # image HSV-Value augmentation (fraction) +degrees: 0 # image rotation (+/- deg) +translate: 0.2 #0.2 # image translation (+/- fraction) +scale: 0.5 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.3 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 0.5 # image mosaic (probability) +mixup: 0.15 # image mixup (probability) +copy_paste: 0.0 # image copy paste (probability) +paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout +inversion: 0.5 #opposite temperature +img_percentile_removal: 0.3 +beta : 0.3 +random_perspective : 1 +scaling_before_mosaic : 1 +gamma : 80 # percent 90 percente more stability to gamma diff --git a/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res_OVERFITTING.yaml b/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res_OVERFITTING.yaml new file mode 100644 index 0000000..40ed673 --- /dev/null +++ b/data/hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res_OVERFITTING.yaml @@ -0,0 +1,37 @@ +lr0: 0.001 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.005 # optimizer weight decay 5e-4 It resolve mAP of overfitting test +warmup_epochs: 0.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr +loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training +box: 0.05 # box loss gain +cls: 0.5 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 1.0 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.60 # like the default in the code was 0.2 IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3 +fl_gamma: 1.5 #1.5 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.0 # image HSV-Hue augmentation (fraction) +hsv_s: 0.0 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.0 # image HSV-Value augmentation (fraction) +degrees: 0 # image rotation (+/- deg) +translate: 0.2 #0.2 # image translation (+/- fraction) +scale: 0.5 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.3 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 0.5 # image mosaic (probability) +mixup: 0.15 # image mixup (probability) +copy_paste: 0.0 # image copy paste (probability) +paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout +inversion: 0.5 #opposite temperature +img_percentile_removal: 0.3 +beta : 0.3 +random_perspective : 1 +scaling_before_mosaic : 1 +gamma : 80 # percent 90 percente more stability to gamma diff --git a/data/hyp.tir_od.tiny_aug_gamma_rnd_perspective.yaml b/data/hyp.tir_od.tiny_aug_gamma_rnd_perspective.yaml index ab42798..b2a824f 100644 --- a/data/hyp.tir_od.tiny_aug_gamma_rnd_perspective.yaml +++ b/data/hyp.tir_od.tiny_aug_gamma_rnd_perspective.yaml @@ -28,10 +28,10 @@ mosaic: 0.5 # image mosaic (probability) mixup: 0.15 # image mixup (probability) copy_paste: 0.0 # image copy paste (probability) paste_in: 0.1 # image copy paste (probability), use 0 for faster training : cutout -loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training +loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training inversion: 0.5 #opposite temperature img_percentile_removal: 0.3 beta : 0.3 random_perspective : 1 -scaling_before_mosaic : 0 +scaling_before_mosaic : 1 gamma : 80 # percent diff --git a/data/hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic_rnd_scaling.yaml b/data/hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic_rnd_scaling.yaml index 3e6cc26..f89792c 100644 --- a/data/hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic_rnd_scaling.yaml +++ b/data/hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic_rnd_scaling.yaml @@ -5,6 +5,7 @@ weight_decay: 0.005 # optimizer weight decay 5e-4 It resolve mAP of overfittin warmup_epochs: 3.0 # warmup epochs (fractions ok) warmup_momentum: 0.8 # warmup initial momentum warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr +loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training box: 0.05 # box loss gain cls: 0.5 # cls loss gain cls_pw: 1.0 # cls BCELoss positive_weight @@ -28,7 +29,6 @@ mosaic: 0.5 # image mosaic (probability) mixup: 0.15 # image mixup (probability) copy_paste: 0.0 # image copy paste (probability) paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout -loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training inversion: 0.5 #opposite temperature img_percentile_removal: 0.3 beta : 0.3 diff --git a/data/tir_od_center_roi_aug_list.yaml b/data/tir_od_center_roi_aug_list.yaml index 9453753..dbaa814 100644 --- a/data/tir_od_center_roi_aug_list.yaml +++ b/data/tir_od_center_roi_aug_list.yaml @@ -5,7 +5,7 @@ # sudo ln -s ~hanoch/projects/tir_frames_rois /mnt/Data/hanoch/tir_frames_rois path: /mnt/Data/hanoch/tir_frames_rois/yolo7_tir_data_all #/home/hanoch/projects/tir_frames_rois/tir_car_44person_31 #/home/hanochk/tir_frames_rois/yolo7_tir_data # train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] -train: ./yolov7/tir_od/training_set_yolov7941_w_center_roi_no_swiss_no_38B_20b_22c_n_png.txt #./yolov7/tir_od/training_set_yolov7941_w_center_roi.txt +train: ./yolov7/tir_od/training_set_noTest19G_swiss_w_mp4.txt #./yolov7/tir_od/training_set_yolov7941_w_center_roi.txt val: ./yolov7/tir_od/tir_tiff_w_center_roi_validation_set.txt #./yolov7/tir_od/tir_tiff_car_person_min_size_44_31_validation_set.txt #./yolov7/tir_od/validation_set.txt #./yolov7/tir_od/val_tir_od.txt #./yolov7/tir_od/validation_set.txt # 5000 images #test: ./tir_od/test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 diff --git a/detect.py b/detect.py index 4b1e6e2..e341211 100644 --- a/detect.py +++ b/detect.py @@ -76,7 +76,7 @@ def detect(save_img=False): t0 = time.time() for path, img, im0s, vid_cap in dataset: - if os.path.basename(path).split('.')[1] == 'tiff': + if os.path.basename(path).split('.')[1] == 'tiff': # im0s only for plotting version im0s = np.repeat(im0s[ :, :, np.newaxis], 3, axis=2) # convert GL to RGB by replication im0s = scaling_image(im0s, scaling_type=opt.norm_type) if im0s.max()<=1: @@ -214,7 +214,7 @@ if __name__ == '__main__': parser.add_argument('--tir-channel-expansion', action='store_true', help='drc_per_ch_percentile') - parser.add_argument('--input-channels', type=int, default=3, help='') + parser.add_argument('--input-channels', type=int, default=1, help='') parser.add_argument('--save-path', default='', help='save to project/name') @@ -246,6 +246,11 @@ python -u ./yolov7/detect.py --weights ./yolov7/yolov7.pt --conf 0.25 --img-size --weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/yolo7_tir_data_all/TIR10_V50_OCT21_Test46A_ML_RD_IL_2021_08_05_14_48_05_FS_210_XGA_630_922_DENIS_right_roi_210_881.tiff --weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/yolo7_tir_data_all/TIR135_V80_JUL23_Test55A_SY_RD_US_2023_01_18_07_29_38_FS_50_XGA_0001_3562_Shahar_left_roi_50_1348.tiff +--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/fog/28_02_2019_16_05_01[1]_04783.tiff + + + + YOLO model --weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_od/Snipaste_2024-09-15_09-00-58_tir_135_TIR135_V80_JUL23_Test55A_SY_RD_US_2023_01_18_07_29_38_FS_50_XGA_0001_3562_Shahar_left_roi_50_1348.png diff --git a/test.py b/test.py index b3af6da..b90a4dc 100644 --- a/test.py +++ b/test.py @@ -321,7 +321,7 @@ def test(data, stats_person_medium = [np.concatenate(x, 0) for x in zip(*stats_person_medium)] # to numpy stats_all_large = [np.concatenate(x, 0) for x in zip(*stats_all_large)] # to numpy - if len(stats) and stats[0].any(): # P, R @ # max F1 index + if len(stats) and stats[0].any(): # P, R @ # max F1 index if any correct prediction p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names) #based on correct @ IOU=0.5 of pred box with target if not training or 1: if bool(stats_person_medium): diff --git a/train.py b/train.py index 8195ffb..6e7ced1 100644 --- a/train.py +++ b/train.py @@ -57,10 +57,11 @@ if clear_ml: # clearml support task = Task.init( project_name="TIR_OD", - task_name="train yolov7 with dummy test" + task_name="train yolov7 with dummy test" # output_uri = True model torch.save will uploaded to file server or =/mnt/myfolder or AWS or Azure ) # Task.execute_remotely() will invoke the job immidiately over the remote and not DeV task.set_base_docker(docker_image="nvcr.io/nvidia/pytorch:24.09-py3", docker_arguments="--shm-size 8G") +# clear_ml can capture graph like tensorboard gradient_clip_value = 100.0 opt_gradient_clipping = True @@ -417,7 +418,7 @@ def train(hyp, opt, device, tb_writer=None): # Start training t0 = time.time() if hyp['warmup_epochs'] !=0: # otherwise it is forced to 1000 iterations - nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations) + nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations) # HK@@ bad for overfitting test where few examples i.e itoo few iterations else: nw = 0 # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training @@ -513,7 +514,6 @@ def train(hyp, opt, device, tb_writer=None): if opt.quad: loss *= 4. - # HK TODO : https://discuss.pytorch.org/t/switching-between-mixed-precision-training-and-full-precision-training-after-training-is-started/132366/4 remove scaler backwards # Backward scaler.scale(loss).backward() @@ -554,14 +554,14 @@ def train(hyp, opt, device, tb_writer=None): # import tifffile # for ix, img in enumerate(imgs): # print(ix, torch.std(img), torch.quantile(img, 0.5)) - # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_scl_bef_mosaic' + str(ix)+'.tiff'), + # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/outputs', 'img_scl_bef_mosaic' + str(ix)+'.tiff'), # img.cpu().numpy().transpose(1, 2, 0)) # # Plot if plots and ni < 100: f = save_dir / f'train_batch{ni}.jpg' # filename - Thread(target=plot_images, args=(imgs, targets, paths, f, opt.input_channels), daemon=True).start() + Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() # if tb_writer: # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph @@ -811,6 +811,8 @@ if __name__ == '__main__': assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' opt.batch_size = opt.total_batch_size // opt.world_size + defualt_random_pad = True # lazy hyp def + # clearml support if clear_ml: #clearml support config_file = task.connect_configuration(opt.hyp, name='hyperparameters_cfg') @@ -825,7 +827,7 @@ if __name__ == '__main__': #defaults for backward compatible hyp files whree not set hyp['person_size_small_medium_th'] = hyp.get('person_size_small_medium_th', 32 * 32) hyp['car_size_small_medium_th'] = hyp.get('car_size_small_medium_th', 44 * 44) - hyp['random_pad'] = True # lazy hyp def + hyp['random_pad'] = hyp.get('random_pad', defualt_random_pad) # Train @@ -957,7 +959,14 @@ FT : you need the --cfg of arch yaml because nc-classes are changing --workers 8 --device 0 --batch-size 16 --data data/tir_od.yaml --img 640 640 --weights ./yolov7/yolov7-tiny.pt --cfg cfg/training/yolov7-tiny.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug.yaml --adam --norm-type single_image_mean_std --input-channels 3 --linear-lr --epochs 2 ---workers 8 --device 0 --batch-size 32 --data data/tir_od.yaml --img 640 640 --weights /mnt/Data/hanoch/tir_frames_rois/yolov7.pt --cfg cfg/training/yolov7.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 100 --nosave --gamma-aug-prob 0.2 --cache-images +--workers 8 --device 0 --batch-size 32 --data data/tir_od_center_roi_aug_list.yaml --img-size 640 --weights /mnt/Data/hanoch/tir_frames_rois/yolov7.pt --cfg cfg/training/yolov7.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 100 --nosave --gamma-aug-prob 0.2 --cache-images + +Extended model for higher resolution +# --workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size [768, 1024] --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 2 --gamma-aug-prob 0.3 --cache-images --rect +# --workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 10 --gamma-aug-prob 0.3 --cache-images +# --workers 1 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 10 --gamma-aug-prob 0.3 --cache-images + +--workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 150 --gamma-aug-prob 0.3 --cache-images --project runs/train_7e class EMA_Clip(EMA): #Exponential moving average diff --git a/utils/datasets.py b/utils/datasets.py index 5914c79..4a3b886 100644 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -71,25 +71,49 @@ def exif_size(img): # import warnings # warnings.filterwarnings('error', category=RuntimeWarning) -def scaling_image(img, scaling_type, percentile=0.03, beta=0.3): +def scaling_image(img, scaling_type, percentile:float =0.03, + beta:float =0.3, roi :tuple=(), img_size: int=640): if scaling_type == 'no_norm': + if bool(roi): + raise img = img elif scaling_type == 'standardization': # default by repo + if bool(roi): + raise img = img/ 255.0 elif scaling_type =="single_image_0_to_1": + if bool(roi): + raise max_val = np.max(img.ravel()) min_val = np.min(img.ravel()) img = np.double(img - min_val) / (np.double(max_val - min_val) + eps) img = np.minimum(np.maximum(img, 0), 1) elif scaling_type == 'single_image_mean_std': + if bool(roi): + raise img = (img - img.ravel().mean()) / img.ravel().std() elif scaling_type == 'single_image_percentile_0_1': - min_val = np.percentile(img.ravel(), percentile) - max_val = np.percentile(img.ravel(), 100-percentile) + if bool(roi): + dw, dh = img_size[1] - roi[1], img_size[0] - roi[0] # wh padding + dw /= 2 # divide padding into 2 sides + dh /= 2 + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + + if len(img.shape) == 2: + img_crop = img[bottom:-top, :] + else: + img_crop = img[:, bottom:-top, :] + + min_val = np.percentile(img_crop.ravel(), percentile) + max_val = np.percentile(img_crop.ravel(), 100-percentile) + else: + min_val = np.percentile(img.ravel(), percentile) + max_val = np.percentile(img.ravel(), 100-percentile) img = np.double(img - min_val) / (np.double(max_val - min_val) + eps) img = np.minimum(np.maximum(img, 0), 1) @@ -98,6 +122,8 @@ def scaling_image(img, scaling_type, percentile=0.03, beta=0.3): # max_val = np.percentile(img.ravel(), 100 - percentile) # img = np.double(img - min_val) / np.double(max_val - min_val) # img = np.uint8(np.minimum(np.maximum(img, 0), 1)*255) + if bool(roi): + raise ImgMin = np.percentile(img, percentile) ImgMax = np.percentile(img, 100-percentile) ImgDRC = (np.double(img - ImgMin) / (np.double(ImgMax - ImgMin)) * 255 + eps) @@ -107,6 +133,8 @@ def scaling_image(img, scaling_type, percentile=0.03, beta=0.3): elif scaling_type == 'remove+global_outlier_0_1': + if bool(roi): + raise img = np.double(img - img.min()*(beta))/np.double(img.max()*(1-beta) - img.min()*(beta)) # beta in [percentile] img = np.double(np.minimum(np.maximum(img, 0), 1)) elif scaling_type == 'normalization_uint16': @@ -310,7 +338,7 @@ class LoadImages: # for inference plt.hist(img.ravel(), bins=128) plt.savefig(os.path.join('/home/hanoch/projects/tir_od/outputs', os.path.basename(path).split('.')[0]+ 'pre')) - file_type = os.path.basename(self.img_files[index]).split('.')[-1].lower() + file_type = os.path.basename(path).split('.')[-1].lower() if (file_type !='tiff' and file_type != 'png'): print('!!!!!!!!!!!!!!!! index : {} {} unrecognized '.format(index, self.img_files[index])) @@ -499,7 +527,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing self.hyp = hyp self.image_weights = image_weights self.rect = False if image_weights else rect - self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training) @@ HK TODO: disable mosaic implicitly by prob mosaic =0 + self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training) self.mosaic_border = [-img_size // 2, -img_size // 2] self.stride = stride self.path = path @@ -734,6 +762,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing img = (img * r + img2 * (1 - r)).astype(img.dtype)#.astype(np.uint8) labels = np.concatenate((labels, labels2), 0) + else: # Load image img, (h0, w0), (h, w) = load_image(self, index) @@ -741,13 +770,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing # Letterbox shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape # img, ratio, pad = letterbox(img, shape, color=(img.mean(), img.mean(), img.mean()), auto=False, scaleup=self.augment) - img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment) + img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment, random_pad=self.random_pad) + shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling labels = self.labels[index].copy() if labels.size: # normalized xywh to pixel xyxy format labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1]) + if self.tir_channel_expansion: # HK @@ according to the paper this CE is a sort of augmentation hence no need to preliminary augment. One of the channels are inversion hence avoid channel inversion aug img = np.repeat(img[np.newaxis, :, :], 3, axis=0) # convert GL to RGB by replication img_ce = np.zeros_like(img).astype('float64') @@ -791,7 +822,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing # GL gain/attenuation # Squeeze pdf (x-mu)*scl+mu #img, labels = self.albumentations(img, labels) - img = self.albumentations_gamma_contrast(img) + img = self.albumentations_gamma_contrast(img) # apply RandomBrightnessContrast only since it has buggy response if random.random() < hyp['gamma_liklihood']: if img.dtype == np.uint16 or img.dtype == np.uint8: @@ -865,7 +896,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_ce.tiff'), 255*img.transpose(1,2,0).astype('uint8')) if not self.tir_channel_expansion: if self.is_tir_signal: - img = np.repeat(img[np.newaxis, :, :], self.input_channels, axis=0) #convert GL to RGB by replication + if len(img.shape) == 2: + img = np.repeat(img[np.newaxis, :, :], self.input_channels, axis=0) #convert GL to 3-ch if any RGB by replication + print('Warning , TIR image should be 3dim by now (w,h,1)', 100*'*') + else: + img = np.repeat(img.transpose(2, 0, 1), self.input_channels, axis=0) else: # Convert img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 @@ -874,23 +909,42 @@ class LoadImagesAndLabels(Dataset): # for training/testing import matplotlib.pyplot as plt plt.figure() plt.hist(img.ravel(), bins=128) - plt.savefig(os.path.join('/home/hanoch/projects/tir_od/outputs', os.path.basename(self.img_files[index]).split('.')[0]+ 'pre_' +str(self.scaling_type))) + plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0]+ 'pre_' +str(self.scaling_type))) # import tifffile - # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_before_last_scaling.tiff'), img.transpose(1,2,0)) + # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded_before_scaling_' + '_' +str(str(img.max())) + '_' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'), + # (img.transpose(1, 2, 0))) + + # In case moasaic of mixed PNG and TIFF the TIFF is pre scaled while the PNG shouldn;t if file_type != 'png': + img_orig, _, _ = load_image(self, index) + loaded_img_shape = img_orig.shape[:2] + new_shape = self.img_size + if isinstance(self.img_size, int): # if list then the 2d dim is embedded + new_shape = (new_shape, new_shape) + if new_shape != loaded_img_shape: + roi = loaded_img_shape + img_size = new_shape + else: # don't do nothing normaliza the entire image + roi = () + img_size = loaded_img_shape + + if self.rect: + raise ValueError('not supported') + img = scaling_image(img, scaling_type=self.scaling_type, - percentile=self.percentile, beta=self.beta) + percentile=self.percentile, beta=self.beta, + roi=roi, img_size=img_size) else: img = scaling_image(img, scaling_type='single_image_0_to_1') # safer in case double standartiozation one before mosaic and her the last one since mosaic is random based occurance # print('ka') if 0: import matplotlib.pyplot as plt - plt.figure() + # plt.figure() plt.hist(img.ravel(), bins=128) - plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0] + 'post_'+ str(self.scaling_type))) + plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0] + '_hist_post_scaling_'+ str(self.scaling_type))) # aa1 = np.repeat(img[1,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32') # cv2.imwrite('test/exp40/test_batch88_labels__1.jpg', aa1*255) # aa1 = np.repeat(img.transpose(1,2,0), 3, axis=2).astype('float32') @@ -898,10 +952,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing if np.isnan(img).any(): print('img {} index : {} is nan fin'.format(self.img_files[index], index)) # raise - # tag='png' - # import tifffile - # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded__' + tag +'__' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'), - # img.transpose(1, 2, 0)) + # try: + # tag='full_rect' + # import tifffile + # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded__' + tag +'__' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'), + # (img.transpose(1, 2, 0)*2**16).astype('uint16')) + # except Exception as e: + # print(f'\nfailed reading: due to {str(e)}') + # # img = np.ascontiguousarray(img) # print('\n 2nd', img.shape) @@ -1288,6 +1346,7 @@ def init_image_plane(self, img, s, n_div=2): else: img4 = np.full((s * n_div, s * n_div, img.shape[2]), img.mean(), dtype=img.dtype) # base image with 4 tiles fill with 0.5 in [0 1] equals to 114 in [0 255] + img4 = img4[:s*n_div, :s*n_div] # in case rectangle shape, AR>1, than crop the padding plane according to the right final shape return img4 @@ -1382,8 +1441,9 @@ def replicate(img, labels): return img, labels -def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): - # Resize and pad image while meeting stride-multiple constraints +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), + auto=True, scaleFill=False, scaleup=True, stride=32, random_pad=False): + # Resize and pad image while meeting stride-multiple constraints i.e. 32 shape = img.shape[:2] # current shape [height, width] if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -1409,9 +1469,21 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale if shape[::-1] != new_unpad: # resize img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + if random_pad and dh>0: # recatangle image with padding is expected + img_plane = init_random_image_plane(img, s=max(img.shape), n_div=1) + # img_plane = img_plane[:img.shape[0], :img.shape[1]] # in case rectangle shape, AR>1, than crop the padding plane according to the right final shape + img_plane[bottom:-top, :] = img + img = img_plane + # img[:bottom, :] = img_plane[:bottom, :] + # img[-top:, :] = img_plane[-top:, :] + else: + n_ch = img.shape[-1] + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + if n_ch == 1 and len(img.shape) == 2: # fixing bug in cv2 where n_ch==1 no explicit consideration + img = img[..., None] return img, ratio, (dw, dh) @@ -1457,11 +1529,16 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed if is_fill_by_mean_img: filling_value = int(img.mean()+1) # filling value can be only an integer hance when scaling before mosaic signal is [0,1] then in the random perspective the posibilities for filling values are 0 or 1 + n_ch = img.shape[-1] + if perspective: img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(filling_value, filling_value, filling_value)) else: # affine img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(filling_value, filling_value, filling_value)) + if n_ch == 1 and len(img.shape) == 2: # fixing bug in cv2 where n_ch==1 no explicit consideration + img = img[..., None] + # import tifffile # unique_run_name = str(int(time.time_ns())) # @@ -1472,7 +1549,8 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s pad_w = int((width - np.round(width * s)) // 2) pad_h = int((height - np.round(height * s)) // 2) - img_plane = init_random_image_plane(img, s=img.shape[0], n_div=1) + img_plane = init_random_image_plane(img, s=max(img.shape), n_div=1) + img_plane = img_plane[:img.shape[0], :img.shape[1]] # in case rectangle shape, AR>1, than crop the padding plane according to the right final shape if pad_w + int(T[0, 2] - width/2) >0: # Left padding diff --git a/utils/metrics.py b/utils/metrics.py index ff49637..97aaf64 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -210,8 +210,12 @@ def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=(), precisions_of_i ax.plot(recall_of_interest_per_class, precisions_of_interest, '*', color='green') for k in range(len(precisions_of_interest)): ax.plot(recall_of_interest_per_class[k], precisions_of_interest[k], '*', color='green') - ax.text(x=recall_of_interest_per_class[k], y=precisions_of_interest[k], fontsize=12, - s=f"th={conf_at_precision_of_iterest[k]:.2f}") + try: + ax.text(x=recall_of_interest_per_class[k], y=precisions_of_interest[k], fontsize=12, + s=f"th={conf_at_precision_of_iterest[k]:.2f}") + except Exception as e: + print(f'WARNING: cant plot recall of interest too few data or so: {e}') + # ax.text(x=0.6, y=precisions_of_interest[i], fontsize=12, s=f" R/P {names[i]}[ {recall_of_interest_per_class[i]:.3f} {precisions_of_interest[i]:.3f}]") # ax.text(x=0.6, y=max(0.9-0.2*i, 0), fontsize=12, s=f" R/P {names[i]}[ {recall_of_interest_per_class[i]:.3f} {precisions_of_interest[i]:.3f}]") if k == 0: diff --git a/utils/plots.py b/utils/plots.py index 5938842..15cc4db 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -116,7 +116,7 @@ def output_to_target(output): return np.array(targets) -def plot_images(images, targets, paths=None, fname='images.jpg', input_channels=3, names=None, max_size=640, max_subplots=16): +def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): # Plot image grid with labels if isinstance(images, torch.Tensor): @@ -150,8 +150,11 @@ def plot_images(images, targets, paths=None, fname='images.jpg', input_channels= block_y = int(h * (i % ns)) img = img.transpose(1, 2, 0) + n_ch = img.shape[-1] if scale_factor < 1: img = cv2.resize(img, (w, h)) + if n_ch == 1: + img = img[..., None] if img.shape[2] > 1: # GL no permute # Convert