mirror of https://github.com/WongKinYiu/yolov7.git
datasets.py: 768x1024, random padding, ROI-based scaling. Still an open issue: predictions over the random padding area should be omitted from the calculation
parent
76e6315e99
commit
221429a81b
|
@ -5,6 +5,7 @@ from tqdm import tqdm
|
|||
from torchvision.ops import batched_nms
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import pandas as pd
|
||||
#%%
|
||||
import random
|
||||
import numpy as np
|
||||
|
@ -12,12 +13,13 @@ import onnxruntime as ort
|
|||
from PIL import Image
|
||||
import argparse
|
||||
|
||||
from utils.datasets import create_dataloader
|
||||
from utils.datasets import create_dataloader, create_folder
|
||||
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
||||
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
|
||||
check_requirements, print_mutation, set_logging, one_cycle, colorstr
|
||||
from utils.metrics import ap_per_class
|
||||
from utils.general import box_iou
|
||||
import os
|
||||
from utils.general import xywh2xyxy
|
||||
from collections import defaultdict
|
||||
def compute_iou(box1, box2):
|
||||
|
@ -166,6 +168,10 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleu
|
|||
return im, r, (dw, dh)
|
||||
|
||||
def main(opt):
|
||||
|
||||
create_folder(opt.save_path)
|
||||
|
||||
pred_tgt_acm = list()
|
||||
p_r_iou = 0.5
|
||||
niou = 1
|
||||
# Model
|
||||
|
@ -318,10 +324,10 @@ def main(opt):
|
|||
# Like test.py
|
||||
labels = torch.tensor([[y.item()] + x for x, y in zip(img_gt_boxes_xyxy, img_gt_lbls)])
|
||||
nl = len(labels)
|
||||
if nl == ml_class_id.shape[0]:
|
||||
print('prob all TP')
|
||||
print('path', paths[seen-1])
|
||||
predn[pi, :4]
|
||||
# if nl == ml_class_id.shape[0]:
|
||||
# print('prob all TP')
|
||||
# print('path', paths[seen-1])
|
||||
# predn[pi, :4]
|
||||
tcls = labels[:, 0].tolist() if nl else [] # target class
|
||||
pred = torch.tensor([np.append(np.append(bboxes_, scores_above_th_.item()), ml_class_id_.item()) for bboxes_, ml_class_id_, scores_above_th_ in
|
||||
zip(bboxes, ml_class_id, scores_above_th_val)])
|
||||
|
@ -354,7 +360,7 @@ def main(opt):
|
|||
break
|
||||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(),
|
||||
tcls)) # correct @ IOU=0.5 of pred box with target
|
||||
|
||||
pred_tgt_acm.append({'correct': correct.cpu().numpy(), 'conf': pred[:, 4].cpu().numpy(), 'pred_cls': pred[:, 5].cpu().numpy(), 'tcls': tcls} )
|
||||
|
||||
|
||||
else:
|
||||
|
@ -437,9 +443,12 @@ def main(opt):
|
|||
# print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
|
||||
# Predicted: [(bbox, class_id, confidence)]
|
||||
if is_new_model:
|
||||
df = pd.DataFrame(pred_tgt_acm)
|
||||
df.to_csv(os.path.join(opt.save_path, 'onnx_model_pred_tgt_acm_conf_th_' + str(det_threshold.__format__('.3f')) + '.csv'), index=False)
|
||||
|
||||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
||||
if len(stats) and stats[0].any(): # P, R @ # max F1 index
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=True, v5_metric=False, save_dir=save_path,
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=True, v5_metric=False, save_dir=opt.save_path,
|
||||
names=names) # based on correct @ IOU=0.5 of pred box with target
|
||||
for i, c in enumerate(ap_class):
|
||||
print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
|
||||
|
@ -474,7 +483,6 @@ if __name__ == '__main__':
|
|||
|
||||
opt = parser.parse_args()
|
||||
|
||||
|
||||
main(opt=opt)
|
||||
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
lr0: 0.001 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
|
||||
momentum: 0.937 # SGD momentum/Adam beta1
|
||||
weight_decay: 0.0005 # optimizer weight decay 5e-4; it resolves mAP of overfitting test
|
||||
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # warmup initial momentum
|
||||
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
|
||||
loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
box: 0.05 # box loss gain
|
||||
cls: 0.5 # cls loss gain
|
||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||
obj: 1.0 # obj loss gain (scale with pixels)
|
||||
obj_pw: 1.0 # obj BCELoss positive_weight
|
||||
iou_t: 0.60 # like the default in the code was 0.2 IoU training threshold
|
||||
anchor_t: 4.0 # anchor-multiple threshold
|
||||
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
|
||||
fl_gamma: 1.5 #1.5 # focal loss gamma (efficientDet default gamma=1.5)
|
||||
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
|
||||
degrees: 0 # image rotation (+/- deg)
|
||||
translate: 0.2 #0.2 # image translation (+/- fraction)
|
||||
scale: 0.5 # image scale (+/- gain)
|
||||
shear: 0.0 # image shear (+/- deg)
|
||||
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.3 # image flip up-down (probability)
|
||||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 0.5 # image mosaic (probability)
|
||||
mixup: 0.15 # image mixup (probability)
|
||||
copy_paste: 0.0 # image copy paste (probability)
|
||||
paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout
|
||||
inversion: 0.5 #opposite temperature
|
||||
img_percentile_removal: 0.3
|
||||
beta : 0.3
|
||||
random_perspective : 1
|
||||
scaling_before_mosaic : 1
|
||||
gamma : 80 # percent; 90 percent gives more stability to gamma
|
|
@ -0,0 +1,37 @@
|
|||
lr0: 0.001 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
|
||||
momentum: 0.937 # SGD momentum/Adam beta1
|
||||
weight_decay: 0.005 # optimizer weight decay 5e-3 (note previously said 5e-4); it resolves mAP of overfitting test
|
||||
warmup_epochs: 0.0 # warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # warmup initial momentum
|
||||
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
|
||||
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
box: 0.05 # box loss gain
|
||||
cls: 0.5 # cls loss gain
|
||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||
obj: 1.0 # obj loss gain (scale with pixels)
|
||||
obj_pw: 1.0 # obj BCELoss positive_weight
|
||||
iou_t: 0.60 # like the default in the code was 0.2 IoU training threshold
|
||||
anchor_t: 4.0 # anchor-multiple threshold
|
||||
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
|
||||
fl_gamma: 1.5 #1.5 # focal loss gamma (efficientDet default gamma=1.5)
|
||||
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
|
||||
degrees: 0 # image rotation (+/- deg)
|
||||
translate: 0.2 #0.2 # image translation (+/- fraction)
|
||||
scale: 0.5 # image scale (+/- gain)
|
||||
shear: 0.0 # image shear (+/- deg)
|
||||
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.3 # image flip up-down (probability)
|
||||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 0.5 # image mosaic (probability)
|
||||
mixup: 0.15 # image mixup (probability)
|
||||
copy_paste: 0.0 # image copy paste (probability)
|
||||
paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout
|
||||
inversion: 0.5 #opposite temperature
|
||||
img_percentile_removal: 0.3
|
||||
beta : 0.3
|
||||
random_perspective : 1
|
||||
scaling_before_mosaic : 1
|
||||
gamma : 80 # percent; 90 percent gives more stability to gamma
|
|
@ -28,10 +28,10 @@ mosaic: 0.5 # image mosaic (probability)
|
|||
mixup: 0.15 # image mixup (probability)
|
||||
copy_paste: 0.0 # image copy paste (probability)
|
||||
paste_in: 0.1 # image copy paste (probability), use 0 for faster training : cutout
|
||||
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
inversion: 0.5 #opposite temperature
|
||||
img_percentile_removal: 0.3
|
||||
beta : 0.3
|
||||
random_perspective : 1
|
||||
scaling_before_mosaic : 0
|
||||
scaling_before_mosaic : 1
|
||||
gamma : 80 # percent
|
||||
|
|
|
@ -5,6 +5,7 @@ weight_decay: 0.005 # optimizer weight decay 5e-4 It resolve mAP of overfittin
|
|||
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # warmup initial momentum
|
||||
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
|
||||
loss_ota: 1 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
box: 0.05 # box loss gain
|
||||
cls: 0.5 # cls loss gain
|
||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||
|
@ -28,7 +29,6 @@ mosaic: 0.5 # image mosaic (probability)
|
|||
mixup: 0.15 # image mixup (probability)
|
||||
copy_paste: 0.0 # image copy paste (probability)
|
||||
paste_in: 0.1 # 0.1 # image copy paste (probability), use 0 for faster training : cutout
|
||||
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
|
||||
inversion: 0.5 #opposite temperature
|
||||
img_percentile_removal: 0.3
|
||||
beta : 0.3
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# sudo ln -s ~hanoch/projects/tir_frames_rois /mnt/Data/hanoch/tir_frames_rois
|
||||
path: /mnt/Data/hanoch/tir_frames_rois/yolo7_tir_data_all #/home/hanoch/projects/tir_frames_rois/tir_car_44person_31 #/home/hanochk/tir_frames_rois/yolo7_tir_data
|
||||
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
|
||||
train: ./yolov7/tir_od/training_set_yolov7941_w_center_roi_no_swiss_no_38B_20b_22c_n_png.txt #./yolov7/tir_od/training_set_yolov7941_w_center_roi.txt
|
||||
train: ./yolov7/tir_od/training_set_noTest19G_swiss_w_mp4.txt #./yolov7/tir_od/training_set_yolov7941_w_center_roi.txt
|
||||
val: ./yolov7/tir_od/tir_tiff_w_center_roi_validation_set.txt #./yolov7/tir_od/tir_tiff_car_person_min_size_44_31_validation_set.txt #./yolov7/tir_od/validation_set.txt #./yolov7/tir_od/val_tir_od.txt #./yolov7/tir_od/validation_set.txt # 5000 images
|
||||
#test: ./tir_od/test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ def detect(save_img=False):
|
|||
t0 = time.time()
|
||||
for path, img, im0s, vid_cap in dataset:
|
||||
|
||||
if os.path.basename(path).split('.')[1] == 'tiff':
|
||||
if os.path.basename(path).split('.')[1] == 'tiff': # im0s only for plotting version
|
||||
im0s = np.repeat(im0s[ :, :, np.newaxis], 3, axis=2) # convert GL to RGB by replication
|
||||
im0s = scaling_image(im0s, scaling_type=opt.norm_type)
|
||||
if im0s.max()<=1:
|
||||
|
@ -214,7 +214,7 @@ if __name__ == '__main__':
|
|||
|
||||
parser.add_argument('--tir-channel-expansion', action='store_true', help='drc_per_ch_percentile')
|
||||
|
||||
parser.add_argument('--input-channels', type=int, default=3, help='')
|
||||
parser.add_argument('--input-channels', type=int, default=1, help='')
|
||||
|
||||
parser.add_argument('--save-path', default='', help='save to project/name')
|
||||
|
||||
|
@ -246,6 +246,11 @@ python -u ./yolov7/detect.py --weights ./yolov7/yolov7.pt --conf 0.25 --img-size
|
|||
--weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/yolo7_tir_data_all/TIR10_V50_OCT21_Test46A_ML_RD_IL_2021_08_05_14_48_05_FS_210_XGA_630_922_DENIS_right_roi_210_881.tiff
|
||||
--weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/yolo7_tir_data_all/TIR135_V80_JUL23_Test55A_SY_RD_US_2023_01_18_07_29_38_FS_50_XGA_0001_3562_Shahar_left_roi_50_1348.tiff
|
||||
|
||||
--weights /mnt/Data/hanoch/runs/train/yolov7999/weights/best.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_frames_rois/fog/28_02_2019_16_05_01[1]_04783.tiff
|
||||
|
||||
|
||||
|
||||
|
||||
YOLO model
|
||||
--weights ./yolov7/yolov7.pt --conf 0.25 --img-size 640 --device 0 --save-txt --norm-type single_image_percentile_0_1 --source /home/hanoch/projects/tir_od/Snipaste_2024-09-15_09-00-58_tir_135_TIR135_V80_JUL23_Test55A_SY_RD_US_2023_01_18_07_29_38_FS_50_XGA_0001_3562_Shahar_left_roi_50_1348.png
|
||||
|
||||
|
|
2
test.py
2
test.py
|
@ -321,7 +321,7 @@ def test(data,
|
|||
stats_person_medium = [np.concatenate(x, 0) for x in zip(*stats_person_medium)] # to numpy
|
||||
stats_all_large = [np.concatenate(x, 0) for x in zip(*stats_all_large)] # to numpy
|
||||
|
||||
if len(stats) and stats[0].any(): # P, R @ # max F1 index
|
||||
if len(stats) and stats[0].any(): # P, R @ # max F1 index if any correct prediction
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names) #based on correct @ IOU=0.5 of pred box with target
|
||||
if not training or 1:
|
||||
if bool(stats_person_medium):
|
||||
|
|
23
train.py
23
train.py
|
@ -57,10 +57,11 @@ if clear_ml: # clearml support
|
|||
|
||||
task = Task.init(
|
||||
project_name="TIR_OD",
|
||||
task_name="train yolov7 with dummy test"
|
||||
task_name="train yolov7 with dummy test" # output_uri = True model torch.save will uploaded to file server or =/mnt/myfolder or AWS or Azure
|
||||
)
|
||||
# Task.execute_remotely() will invoke the job immediately over the remote and not DeV
|
||||
task.set_base_docker(docker_image="nvcr.io/nvidia/pytorch:24.09-py3", docker_arguments="--shm-size 8G")
|
||||
# clear_ml can capture graph like tensorboard
|
||||
|
||||
gradient_clip_value = 100.0
|
||||
opt_gradient_clipping = True
|
||||
|
@ -417,7 +418,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
# Start training
|
||||
t0 = time.time()
|
||||
if hyp['warmup_epochs'] !=0: # otherwise it is forced to 1000 iterations
|
||||
nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
|
||||
nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations) # HK@@ bad for overfitting test where few examples i.e itoo few iterations
|
||||
else:
|
||||
nw = 0
|
||||
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
|
||||
|
@ -513,7 +514,6 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
if opt.quad:
|
||||
loss *= 4.
|
||||
|
||||
|
||||
# HK TODO : https://discuss.pytorch.org/t/switching-between-mixed-precision-training-and-full-precision-training-after-training-is-started/132366/4 remove scaler backwards
|
||||
# Backward
|
||||
scaler.scale(loss).backward()
|
||||
|
@ -554,14 +554,14 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
# import tifffile
|
||||
# for ix, img in enumerate(imgs):
|
||||
# print(ix, torch.std(img), torch.quantile(img, 0.5))
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_scl_bef_mosaic' + str(ix)+'.tiff'),
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/outputs', 'img_scl_bef_mosaic' + str(ix)+'.tiff'),
|
||||
# img.cpu().numpy().transpose(1, 2, 0))
|
||||
#
|
||||
|
||||
# Plot
|
||||
if plots and ni < 100:
|
||||
f = save_dir / f'train_batch{ni}.jpg' # filename
|
||||
Thread(target=plot_images, args=(imgs, targets, paths, f, opt.input_channels), daemon=True).start()
|
||||
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
|
||||
# if tb_writer:
|
||||
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph
|
||||
|
@ -811,6 +811,8 @@ if __name__ == '__main__':
|
|||
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
||||
opt.batch_size = opt.total_batch_size // opt.world_size
|
||||
|
||||
defualt_random_pad = True # lazy hyp def
|
||||
|
||||
# clearml support
|
||||
if clear_ml: #clearml support
|
||||
config_file = task.connect_configuration(opt.hyp, name='hyperparameters_cfg')
|
||||
|
@ -825,7 +827,7 @@ if __name__ == '__main__':
|
|||
#defaults for backward compatible hyp files where not set
|
||||
hyp['person_size_small_medium_th'] = hyp.get('person_size_small_medium_th', 32 * 32)
|
||||
hyp['car_size_small_medium_th'] = hyp.get('car_size_small_medium_th', 44 * 44)
|
||||
hyp['random_pad'] = True # lazy hyp def
|
||||
hyp['random_pad'] = hyp.get('random_pad', defualt_random_pad)
|
||||
|
||||
|
||||
# Train
|
||||
|
@ -957,7 +959,14 @@ FT : you need the --cfg of arch yaml because nc-classes are changing
|
|||
--workers 8 --device 0 --batch-size 16 --data data/tir_od.yaml --img 640 640 --weights ./yolov7/yolov7-tiny.pt --cfg cfg/training/yolov7-tiny.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug.yaml --adam --norm-type single_image_mean_std --input-channels 3 --linear-lr --epochs 2
|
||||
|
||||
|
||||
--workers 8 --device 0 --batch-size 32 --data data/tir_od.yaml --img 640 640 --weights /mnt/Data/hanoch/tir_frames_rois/yolov7.pt --cfg cfg/training/yolov7.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 100 --nosave --gamma-aug-prob 0.2 --cache-images
|
||||
--workers 8 --device 0 --batch-size 32 --data data/tir_od_center_roi_aug_list.yaml --img-size 640 --weights /mnt/Data/hanoch/tir_frames_rois/yolov7.pt --cfg cfg/training/yolov7.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug_gamma_scaling_before_mosaic.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 100 --nosave --gamma-aug-prob 0.2 --cache-images
|
||||
|
||||
Extended model for higher resolution
|
||||
# --workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size [768, 1024] --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 2 --gamma-aug-prob 0.3 --cache-images --rect
|
||||
# --workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 10 --gamma-aug-prob 0.3 --cache-images
|
||||
# --workers 1 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 10 --gamma-aug-prob 0.3 --cache-images
|
||||
|
||||
--workers 8 --device 0 --batch-size 8 --data data/tir_od_center_roi_aug_list_full_res.yaml --weights /mnt/Data/hanoch/tir_frames_rois/yolov7-e6.pt --img-size 1024 --cfg cfg/deploy/yolov7-e6.yaml --name yolov7e --hyp hyp.tir_od.aug_gamma_scaling_before_mosaic_rnd_scaling_e6_full_res.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --linear-lr --epochs 150 --gamma-aug-prob 0.3 --cache-images --project runs/train_7e
|
||||
|
||||
class EMA_Clip(EMA):
|
||||
#Exponential moving average
|
||||
|
|
|
@ -71,25 +71,49 @@ def exif_size(img):
|
|||
|
||||
# import warnings
|
||||
# warnings.filterwarnings('error', category=RuntimeWarning)
|
||||
def scaling_image(img, scaling_type, percentile=0.03, beta=0.3):
|
||||
def scaling_image(img, scaling_type, percentile:float =0.03,
|
||||
beta:float =0.3, roi :tuple=(), img_size: int=640):
|
||||
if scaling_type == 'no_norm':
|
||||
if bool(roi):
|
||||
raise
|
||||
img = img
|
||||
|
||||
elif scaling_type == 'standardization': # default by repo
|
||||
if bool(roi):
|
||||
raise
|
||||
img = img/ 255.0
|
||||
|
||||
elif scaling_type =="single_image_0_to_1":
|
||||
if bool(roi):
|
||||
raise
|
||||
max_val = np.max(img.ravel())
|
||||
min_val = np.min(img.ravel())
|
||||
img = np.double(img - min_val) / (np.double(max_val - min_val) + eps)
|
||||
img = np.minimum(np.maximum(img, 0), 1)
|
||||
|
||||
elif scaling_type == 'single_image_mean_std':
|
||||
if bool(roi):
|
||||
raise
|
||||
img = (img - img.ravel().mean()) / img.ravel().std()
|
||||
|
||||
elif scaling_type == 'single_image_percentile_0_1':
|
||||
min_val = np.percentile(img.ravel(), percentile)
|
||||
max_val = np.percentile(img.ravel(), 100-percentile)
|
||||
if bool(roi):
|
||||
dw, dh = img_size[1] - roi[1], img_size[0] - roi[0] # wh padding
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
|
||||
if len(img.shape) == 2:
|
||||
img_crop = img[bottom:-top, :]
|
||||
else:
|
||||
img_crop = img[:, bottom:-top, :]
|
||||
|
||||
min_val = np.percentile(img_crop.ravel(), percentile)
|
||||
max_val = np.percentile(img_crop.ravel(), 100-percentile)
|
||||
else:
|
||||
min_val = np.percentile(img.ravel(), percentile)
|
||||
max_val = np.percentile(img.ravel(), 100-percentile)
|
||||
img = np.double(img - min_val) / (np.double(max_val - min_val) + eps)
|
||||
img = np.minimum(np.maximum(img, 0), 1)
|
||||
|
||||
|
@ -98,6 +122,8 @@ def scaling_image(img, scaling_type, percentile=0.03, beta=0.3):
|
|||
# max_val = np.percentile(img.ravel(), 100 - percentile)
|
||||
# img = np.double(img - min_val) / np.double(max_val - min_val)
|
||||
# img = np.uint8(np.minimum(np.maximum(img, 0), 1)*255)
|
||||
if bool(roi):
|
||||
raise
|
||||
ImgMin = np.percentile(img, percentile)
|
||||
ImgMax = np.percentile(img, 100-percentile)
|
||||
ImgDRC = (np.double(img - ImgMin) / (np.double(ImgMax - ImgMin)) * 255 + eps)
|
||||
|
@ -107,6 +133,8 @@ def scaling_image(img, scaling_type, percentile=0.03, beta=0.3):
|
|||
|
||||
|
||||
elif scaling_type == 'remove+global_outlier_0_1':
|
||||
if bool(roi):
|
||||
raise
|
||||
img = np.double(img - img.min()*(beta))/np.double(img.max()*(1-beta) - img.min()*(beta)) # beta in [percentile]
|
||||
img = np.double(np.minimum(np.maximum(img, 0), 1))
|
||||
elif scaling_type == 'normalization_uint16':
|
||||
|
@ -310,7 +338,7 @@ class LoadImages: # for inference
|
|||
plt.hist(img.ravel(), bins=128)
|
||||
plt.savefig(os.path.join('/home/hanoch/projects/tir_od/outputs', os.path.basename(path).split('.')[0]+ 'pre'))
|
||||
|
||||
file_type = os.path.basename(self.img_files[index]).split('.')[-1].lower()
|
||||
file_type = os.path.basename(path).split('.')[-1].lower()
|
||||
|
||||
if (file_type !='tiff' and file_type != 'png'):
|
||||
print('!!!!!!!!!!!!!!!! index : {} {} unrecognized '.format(index, self.img_files[index]))
|
||||
|
@ -499,7 +527,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
self.hyp = hyp
|
||||
self.image_weights = image_weights
|
||||
self.rect = False if image_weights else rect
|
||||
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training) @@ HK TODO: disable mosaic implicitly by prob mosaic =0
|
||||
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
|
||||
self.mosaic_border = [-img_size // 2, -img_size // 2]
|
||||
self.stride = stride
|
||||
self.path = path
|
||||
|
@ -734,6 +762,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
img = (img * r + img2 * (1 - r)).astype(img.dtype)#.astype(np.uint8)
|
||||
labels = np.concatenate((labels, labels2), 0)
|
||||
|
||||
|
||||
else:
|
||||
# Load image
|
||||
img, (h0, w0), (h, w) = load_image(self, index)
|
||||
|
@ -741,13 +770,15 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
# Letterbox
|
||||
shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
|
||||
# img, ratio, pad = letterbox(img, shape, color=(img.mean(), img.mean(), img.mean()), auto=False, scaleup=self.augment)
|
||||
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
|
||||
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment, random_pad=self.random_pad)
|
||||
|
||||
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
|
||||
|
||||
labels = self.labels[index].copy()
|
||||
if labels.size: # normalized xywh to pixel xyxy format
|
||||
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
|
||||
|
||||
|
||||
if self.tir_channel_expansion: # HK @@ according to the paper this CE is a sort of augmentation hence no need to preliminary augment. One of the channels are inversion hence avoid channel inversion aug
|
||||
img = np.repeat(img[np.newaxis, :, :], 3, axis=0) # convert GL to RGB by replication
|
||||
img_ce = np.zeros_like(img).astype('float64')
|
||||
|
@ -791,7 +822,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
# GL gain/attenuation
|
||||
# Squeeze pdf (x-mu)*scl+mu
|
||||
#img, labels = self.albumentations(img, labels)
|
||||
img = self.albumentations_gamma_contrast(img)
|
||||
img = self.albumentations_gamma_contrast(img) # apply RandomBrightnessContrast only since it has buggy response
|
||||
|
||||
if random.random() < hyp['gamma_liklihood']:
|
||||
if img.dtype == np.uint16 or img.dtype == np.uint8:
|
||||
|
@ -865,7 +896,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_ce.tiff'), 255*img.transpose(1,2,0).astype('uint8'))
|
||||
if not self.tir_channel_expansion:
|
||||
if self.is_tir_signal:
|
||||
img = np.repeat(img[np.newaxis, :, :], self.input_channels, axis=0) #convert GL to RGB by replication
|
||||
if len(img.shape) == 2:
|
||||
img = np.repeat(img[np.newaxis, :, :], self.input_channels, axis=0) #convert GL to 3-ch if any RGB by replication
|
||||
print('Warning , TIR image should be 3dim by now (w,h,1)', 100*'*')
|
||||
else:
|
||||
img = np.repeat(img.transpose(2, 0, 1), self.input_channels, axis=0)
|
||||
else:
|
||||
# Convert
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||
|
@ -874,23 +909,42 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hist(img.ravel(), bins=128)
|
||||
plt.savefig(os.path.join('/home/hanoch/projects/tir_od/outputs', os.path.basename(self.img_files[index]).split('.')[0]+ 'pre_' +str(self.scaling_type)))
|
||||
plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0]+ 'pre_' +str(self.scaling_type)))
|
||||
|
||||
# import tifffile
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_before_last_scaling.tiff'), img.transpose(1,2,0))
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded_before_scaling_' + '_' +str(str(img.max())) + '_' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'),
|
||||
# (img.transpose(1, 2, 0)))
|
||||
|
||||
|
||||
# In case of a mosaic of mixed PNG and TIFF, the TIFF is pre-scaled while the PNG shouldn't be
|
||||
if file_type != 'png':
|
||||
img_orig, _, _ = load_image(self, index)
|
||||
loaded_img_shape = img_orig.shape[:2]
|
||||
new_shape = self.img_size
|
||||
if isinstance(self.img_size, int): # if list then the 2d dim is embedded
|
||||
new_shape = (new_shape, new_shape)
|
||||
if new_shape != loaded_img_shape:
|
||||
roi = loaded_img_shape
|
||||
img_size = new_shape
|
||||
else: # don't do nothing normaliza the entire image
|
||||
roi = ()
|
||||
img_size = loaded_img_shape
|
||||
|
||||
if self.rect:
|
||||
raise ValueError('not supported')
|
||||
|
||||
img = scaling_image(img, scaling_type=self.scaling_type,
|
||||
percentile=self.percentile, beta=self.beta)
|
||||
percentile=self.percentile, beta=self.beta,
|
||||
roi=roi, img_size=img_size)
|
||||
else:
|
||||
img = scaling_image(img, scaling_type='single_image_0_to_1') # safer in case double standartiozation one before mosaic and her the last one since mosaic is random based occurance
|
||||
|
||||
# print('ka')
|
||||
if 0:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
# plt.figure()
|
||||
plt.hist(img.ravel(), bins=128)
|
||||
plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0] + 'post_'+ str(self.scaling_type)))
|
||||
plt.savefig(os.path.join('/home/hanoch/projects/tir_od/output', os.path.basename(self.img_files[index]).split('.')[0] + '_hist_post_scaling_'+ str(self.scaling_type)))
|
||||
# aa1 = np.repeat(img[1,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32')
|
||||
# cv2.imwrite('test/exp40/test_batch88_labels__1.jpg', aa1*255)
|
||||
# aa1 = np.repeat(img.transpose(1,2,0), 3, axis=2).astype('float32')
|
||||
|
@ -898,10 +952,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
if np.isnan(img).any():
|
||||
print('img {} index : {} is nan fin'.format(self.img_files[index], index))
|
||||
# raise
|
||||
# tag='png'
|
||||
# import tifffile
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded__' + tag +'__' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'),
|
||||
# img.transpose(1, 2, 0))
|
||||
# try:
|
||||
# tag='full_rect'
|
||||
# import tifffile
|
||||
# tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od/output', 'img_loaded__' + tag +'__' +str(self.img_files[index].split('/')[-1].split('.tiff')[0]) + '.tiff'),
|
||||
# (img.transpose(1, 2, 0)*2**16).astype('uint16'))
|
||||
# except Exception as e:
|
||||
# print(f'\nfailed reading: due to {str(e)}')
|
||||
|
||||
# #
|
||||
img = np.ascontiguousarray(img)
|
||||
# print('\n 2nd', img.shape)
|
||||
|
@ -1288,6 +1346,7 @@ def init_image_plane(self, img, s, n_div=2):
|
|||
else:
|
||||
img4 = np.full((s * n_div, s * n_div, img.shape[2]), img.mean(),
|
||||
dtype=img.dtype) # base image with 4 tiles fill with 0.5 in [0 1] equals to 114 in [0 255]
|
||||
img4 = img4[:s*n_div, :s*n_div] # in case rectangle shape, AR>1, than crop the padding plane according to the right final shape
|
||||
return img4
|
||||
|
||||
|
||||
|
@ -1382,8 +1441,9 @@ def replicate(img, labels):
|
|||
return img, labels
|
||||
|
||||
|
||||
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
|
||||
# Resize and pad image while meeting stride-multiple constraints
|
||||
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114),
|
||||
auto=True, scaleFill=False, scaleup=True, stride=32, random_pad=False):
|
||||
# Resize and pad image while meeting stride-multiple constraints i.e. 32
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
@ -1409,9 +1469,21 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
|
|||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
||||
if random_pad and dh>0: # rectangular image with padding is expected
|
||||
img_plane = init_random_image_plane(img, s=max(img.shape), n_div=1)
|
||||
# img_plane = img_plane[:img.shape[0], :img.shape[1]] # in case rectangle shape, AR>1, than crop the padding plane according to the right final shape
|
||||
img_plane[bottom:-top, :] = img
|
||||
img = img_plane
|
||||
# img[:bottom, :] = img_plane[:bottom, :]
|
||||
# img[-top:, :] = img_plane[-top:, :]
|
||||
else:
|
||||
n_ch = img.shape[-1]
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
||||
if n_ch == 1 and len(img.shape) == 2: # fixing bug in cv2 where n_ch==1 no explicit consideration
|
||||
img = img[..., None]
|
||||
return img, ratio, (dw, dh)
|
||||
|
||||
|
||||
|
@ -1457,11 +1529,16 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s
|
|||
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
|
||||
if is_fill_by_mean_img:
|
||||
filling_value = int(img.mean()+1) # filling value can only be an integer; hence, when the signal is scaled to [0,1] before mosaic, the only possible filling values in the random perspective are 0 or 1
|
||||
n_ch = img.shape[-1]
|
||||
|
||||
if perspective:
|
||||
img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(filling_value, filling_value, filling_value))
|
||||
else: # affine
|
||||
img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(filling_value, filling_value, filling_value))
|
||||
|
||||
if n_ch == 1 and len(img.shape) == 2: # fixing bug in cv2 where n_ch==1 no explicit consideration
|
||||
img = img[..., None]
|
||||
|
||||
# import tifffile
|
||||
# unique_run_name = str(int(time.time_ns()))
|
||||
#
|
||||
|
@ -1472,7 +1549,8 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s
|
|||
|
||||
pad_w = int((width - np.round(width * s)) // 2)
|
||||
pad_h = int((height - np.round(height * s)) // 2)
|
||||
img_plane = init_random_image_plane(img, s=img.shape[0], n_div=1)
|
||||
img_plane = init_random_image_plane(img, s=max(img.shape), n_div=1)
|
||||
img_plane = img_plane[:img.shape[0], :img.shape[1]] # in case of a rectangular shape (AR>1), then crop the padding plane according to the right final shape
|
||||
|
||||
if pad_w + int(T[0, 2] - width/2) >0:
|
||||
# Left padding
|
||||
|
|
|
@ -210,8 +210,12 @@ def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=(), precisions_of_i
|
|||
ax.plot(recall_of_interest_per_class, precisions_of_interest, '*', color='green')
|
||||
for k in range(len(precisions_of_interest)):
|
||||
ax.plot(recall_of_interest_per_class[k], precisions_of_interest[k], '*', color='green')
|
||||
ax.text(x=recall_of_interest_per_class[k], y=precisions_of_interest[k], fontsize=12,
|
||||
s=f"th={conf_at_precision_of_iterest[k]:.2f}")
|
||||
try:
|
||||
ax.text(x=recall_of_interest_per_class[k], y=precisions_of_interest[k], fontsize=12,
|
||||
s=f"th={conf_at_precision_of_iterest[k]:.2f}")
|
||||
except Exception as e:
|
||||
print(f'WARNING: cant plot recall of interest too few data or so: {e}')
|
||||
|
||||
# ax.text(x=0.6, y=precisions_of_interest[i], fontsize=12, s=f" R/P {names[i]}[ {recall_of_interest_per_class[i]:.3f} {precisions_of_interest[i]:.3f}]")
|
||||
# ax.text(x=0.6, y=max(0.9-0.2*i, 0), fontsize=12, s=f" R/P {names[i]}[ {recall_of_interest_per_class[i]:.3f} {precisions_of_interest[i]:.3f}]")
|
||||
if k == 0:
|
||||
|
|
|
@ -116,7 +116,7 @@ def output_to_target(output):
|
|||
return np.array(targets)
|
||||
|
||||
|
||||
def plot_images(images, targets, paths=None, fname='images.jpg', input_channels=3, names=None, max_size=640, max_subplots=16):
|
||||
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
|
||||
# Plot image grid with labels
|
||||
|
||||
if isinstance(images, torch.Tensor):
|
||||
|
@ -150,8 +150,11 @@ def plot_images(images, targets, paths=None, fname='images.jpg', input_channels=
|
|||
block_y = int(h * (i % ns))
|
||||
|
||||
img = img.transpose(1, 2, 0)
|
||||
n_ch = img.shape[-1]
|
||||
if scale_factor < 1:
|
||||
img = cv2.resize(img, (w, h))
|
||||
if n_ch == 1:
|
||||
img = img[..., None]
|
||||
|
||||
if img.shape[2] > 1: # GL no permute
|
||||
# Convert
|
||||
|
|
Loading…
Reference in New Issue