mirror of https://github.com/WongKinYiu/yolov7.git
modify TIR channel expansion to be w/o augmentation
parent abdcce0e70
commit e7c36bab68
@@ -30,7 +30,8 @@ copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
tir_channel_expansion: 0.3 #[0, 0.2, 0.5]
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
gamma_liklihood: 0.0
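These hyp files are plain YAML consumed by train.py through its --hyp argument. A minimal sketch of how the keys above reach the code (the filename here is assumed for illustration):

```python
import yaml

# Load a hyperparameter file the way YOLOv7-style train.py does (--hyp <file>).
with open('data/hyp.tir_od.yaml') as f:  # filename assumed
    hyp = yaml.load(f, Loader=yaml.SafeLoader)

# The TIR-specific keys added by this commit sit alongside the stock YOLO ones:
hyp['tir_channel_expansion']  # probability of the 3-channel TIR expansion, e.g. 0.3
hyp['gamma_liklihood']        # probability of gamma augmentation (spelling as in the code)
```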
@@ -0,0 +1,36 @@
lr0: 0.005 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.005 # optimizer weight decay (default 5e-4); resolves overfitting seen in the mAP test
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.60 # IoU training threshold (code default was 0.2)
anchor_t: 4.0 # anchor-multiple threshold
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
fl_gamma: 1.5 #1.5 # focal loss gamma (EfficientDet default gamma=1.5)
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
degrees: 0 # image rotation (+/- deg)
translate: 0 # image translation (+/- fraction)
scale: 0 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.3 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 0.5 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
@@ -0,0 +1,36 @@
lr0: 0.005 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.005 # optimizer weight decay (default 5e-4); resolves overfitting seen in the mAP test
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold (code default 0.2)
anchor_t: 4.0 # anchor-multiple threshold
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
fl_gamma: 1.5 #1.5 # focal loss gamma (EfficientDet default gamma=1.5)
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
degrees: 0 # image rotation (+/- deg)
translate: 0 # image translation (+/- fraction)
scale: 0 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.3 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 0.5 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
@@ -0,0 +1,36 @@
lr0: 0.005 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.005 # optimizer weight decay (default 5e-4); resolves overfitting seen in the mAP test
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.40 # IoU training threshold (code default 0.2)
anchor_t: 4.0 # anchor-multiple threshold
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
fl_gamma: 1.5 #1.5 # focal loss gamma (EfficientDet default gamma=1.5)
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
degrees: 0 # image rotation (+/- deg)
translate: 0 # image translation (+/- fraction)
scale: 0 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.3 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 0.5 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
@@ -0,0 +1,36 @@
lr0: 0.005 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.005 # optimizer weight decay (default 5e-4); resolves overfitting seen in the mAP test
warmup_epochs: 0.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.60 # IoU training threshold (code default was 0.2)
anchor_t: 4.0 # anchor-multiple threshold
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
fl_gamma: 1.5 #1.5 # focal loss gamma (EfficientDet default gamma=1.5)
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
degrees: 0 # image rotation (+/- deg)
translate: 0 # image translation (+/- fraction)
scale: 0 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.3 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 0.5 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
@@ -0,0 +1,37 @@
lr0: 0.005 #0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.005 # optimizer weight decay (default 5e-4); resolves overfitting seen in the mAP test
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.001 #0.001 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
anchors: 2 # anchors per output layer (0 to ignore) @@HK was 3
fl_gamma: 1.5 #1.5 # focal loss gamma (EfficientDet default gamma=1.5)
hsv_h: 0.0 # image HSV-Hue augmentation (fraction)
hsv_s: 0.0 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.0 # image HSV-Value augmentation (fraction)
degrees: 0 # image rotation (+/- deg)
translate: 0 # image translation (+/- fraction)
scale: 0 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.3 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 0.5 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.0 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
gamma_liklihood: 0.0
@@ -30,7 +30,8 @@ copy_paste: 0.0 # image copy paste (probability)
paste_in: 0.3 # image copy paste (probability), use 0 for faster training : cutout
loss_ota: 0 #1 # use ComputeLossOTA, use 0 for faster training
inversion: 0 # opposite temperature
tir_channel_expansion: 0.3 #[0, 0.2, 0.5]
img_percentile_removal: 0.3
beta : 0.3
random_perspective : 0
gamma : 80 # percent
gamma_liklihood: 0.25
@@ -2,12 +2,15 @@ import os
import pandas as pd
from argparse import ArgumentParser
import yaml
import warnings
warnings.warn = lambda *args, **kwargs: None  # silence all warnings globally

def process_class_stats(file_path):
    columns = ['class_name', 'num_files', 'num_objects', 'precision', 'recall', 'map50', 'map']

    # Read the text file into a pandas DataFrame
    # df = pd.read_csv(file_path, delim_whitespace=True)
    # warnings.simplefilter / warnings.catch_warnings
    df = pd.read_csv(file_path, delim_whitespace=True, names=columns, header=None)

    # Find the index where the last repetition of 'all' starts
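The function is truncated here at the "last repetition of 'all'" step. A hedged sketch of that continuation (assumed, not part of the diff): the results file apparently concatenates per-run tables whose summary rows are labelled 'all', and only the final block is wanted.

```python
import pandas as pd

def last_results_block(df: pd.DataFrame) -> pd.DataFrame:
    # Keep rows from the final 'all' summary row onward (assumed semantics).
    last_all = df[df['class_name'] == 'all'].index[-1]
    return df.loc[last_all:]
```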
@ -1,5 +1,6 @@
|
|||
#!/bin/bash
|
||||
source /home/hanoch/.virtualenvs/tir_od/bin/activate
|
||||
cp -rp /home/hanoch/projects/tir_od/runs/train/* /mnt/Data/hanoch/runs/train/
|
||||
if [ -z $1 ] ; then
|
||||
python -u /home/hanoch/projects/tir_od/yolov7/tools/merge_results.py
|
||||
else
|
||||
|
|
train.py
@@ -56,6 +56,7 @@ task = Task.init(
    task_name="train yolov7 with dummy test"
)

task.set_base_docker(docker_image="nvcr.io/nvidia/pytorch:24.09-py3")
gradient_clip_value = 100.0
def find_clipped_gradient_within_layer(model, gradient_clip_value):
    margin_from_sum_abs = 1 / 3
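The body of find_clipped_gradient_within_layer is cut off in this hunk. A hedged sketch of what such a helper might do, given gradient_clip_value = 100.0 and margin_from_sum_abs = 1/3 (a reconstruction under those assumptions, not the repo's actual implementation):

```python
import torch

def find_clipped_gradients(model: torch.nn.Module,
                           clip_value: float = 100.0,
                           margin: float = 1 / 3) -> None:
    # Report layers whose summed |grad| comes within a margin of the clip threshold.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        grad_sum = param.grad.detach().abs().sum().item()
        if grad_sum > clip_value * (1 - margin):
            print(f'{name}: sum|grad| = {grad_sum:.1f} is near clip value {clip_value}')
```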
@@ -7,6 +7,7 @@ import os
import random
import shutil
import time
import warnings
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
@@ -19,6 +20,9 @@ import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
# from torchvision.transforms.functional import adjust_gamma
from skimage.exposure import adjust_gamma
import albumentations as A

import pickle
from copy import deepcopy
@@ -110,6 +114,11 @@ def scaling_image(img, scaling_type, percentile=0.03, beta=0.3):
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix='', rel_path_images='', num_cls=-1):
    # Make sure only the first process in DDP processes the dataset first, so the others can use the cache
    if opt.gamma_aug_prob > 0:
        hyp['gamma_liklihood'] = opt.gamma_aug_prob
        print("", 100 * '==')
        print('gamma_liklihood was overridden by optional value ', opt.gamma_aug_prob)

    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
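The added block gives the command-line value precedence over the YAML default. Isolated below with a stand-in for the parsed args (the actual argparse flag spelling is not shown in this diff):

```python
hyp = {'gamma_liklihood': 0.0}    # value loaded from the hyp YAML

class Opt:                        # stand-in for the parsed argparse namespace
    gamma_aug_prob = 0.5

opt = Opt()
if opt.gamma_aug_prob > 0:        # a positive CLI value wins over the YAML default
    hyp['gamma_liklihood'] = opt.gamma_aug_prob
assert hyp['gamma_liklihood'] == 0.5
```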
@@ -459,6 +468,9 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
        self.is_tir_signal = not (no_tir_signal)

        # self.albumentations = Albumentations() if augment else None
        self.albumentations_gamma_contrast = Albumentations_gamma_contrast(alb_prob=hyp['gamma_liklihood'],
                                                                           gamma_limit=[hyp['gamma'],
                                                                                        100 + 100 - hyp['gamma']])

        try:
            f = []  # image files
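The gamma_limit arithmetic is symmetric around 100: with gamma: 80 as in the hyp files above, [hyp['gamma'], 100 + 100 - hyp['gamma']] evaluates to [80, 120], which is exactly Albumentations' default RandomGamma range.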
@@ -474,7 +486,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                    if bool(rel_path_images):
                        f += [os.path.join(rel_path_images, x.replace('./', '')).rstrip() if x.startswith('./') else x for x in t]  # local to global path
                    else:
-                       f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
+                       f += [x.replace('./', parent).rstrip() if x.startswith('./') else x for x in t]  # local to global path

                    # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
@@ -639,7 +651,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
-       mosaic = self.mosaic and random.random() < hyp['mosaic']
+       mosaic = self.mosaic and random.random() < hyp['mosaic'] and not (self.tir_channel_expansion)
        if mosaic:
            # Load mosaic
            if random.random() < 0.8:
@@ -671,6 +683,24 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.tir_channel_expansion:  # HK @@ per the paper this channel expansion is itself a kind of augmentation, so no prior augmentation is needed; one channel is an inversion, so the inversion augmentation should be avoided
                img = np.repeat(img[np.newaxis, :, :], 3, axis=0)  # convert GL to RGB by replication
                img_ce = np.zeros_like(img).astype('float64')

                # CH1 hist equalization
                img_chan = scaling_image(img[0, :, :], scaling_type=self.scaling_type,
                                         percentile=0, beta=self.beta)
                img_ce[0, :, :] = img_chan.astype('float64')

                img_chan = scaling_image(img[1, :, :], scaling_type=self.scaling_type,
                                         percentile=self.percentile, beta=self.beta)

                img_ce[1, :, :] = img_chan.astype('float64')

                img_chan = inversion_aug(img_ce[1, :, :])  # invert the DRC one
                img_ce[2, :, :] = img_chan.astype('float64')
                img = img_ce

        if self.augment:
            # Augment imagespace
            if not mosaic:
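For reference, the channel-expansion step above condensed into a self-contained sketch. scaling_image() and inversion_aug() are repo functions defined outside this hunk; the stand-ins below only approximate them (percentile clipping plus min-max normalization, and a polarity flip) so the snippet runs on its own:

```python
import numpy as np

def scaling_image(x, scaling_type=None, percentile=0.0, beta=0.3):
    # Stand-in for the repo's scaling_image(): clip to percentiles, min-max to [0, 1].
    lo, hi = np.percentile(x, [100 * percentile, 100 * (1 - percentile)])
    x = np.clip(x, lo, hi).astype('float64')
    return (x - x.min()) / max(x.max() - x.min(), 1e-9)

def inversion_aug(x):
    # Stand-in for the repo's inversion_aug(): flip the polarity of a [0, 1] image.
    return 1.0 - x

def tir_channel_expand(img, scaling_type=None, percentile=0.03, beta=0.3):
    # One grey-level TIR frame -> 3 channels: full-range scaling, percentile
    # (DRC) scaling, and an inversion of the DRC channel, as in the diff above.
    img = np.repeat(img[np.newaxis, :, :], 3, axis=0)
    img_ce = np.zeros_like(img).astype('float64')
    img_ce[0] = scaling_image(img[0], scaling_type, percentile=0, beta=beta)
    img_ce[1] = scaling_image(img[1], scaling_type, percentile=percentile, beta=beta)
    img_ce[2] = inversion_aug(img_ce[1])
    return img_ce

frame = np.random.randint(0, 2 ** 16, (64, 80)).astype('float64')  # fake 16-bit TIR frame
print(tir_channel_expand(frame).shape)  # (3, 64, 80)
```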
@@ -688,6 +718,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing


            # img, labels = self.albumentations(img, labels)
            img = self.albumentations_gamma_contrast(img)

            if hyp['hsv_h'] > 0 or hyp['hsv_s'] > 0 or hyp['hsv_v'] > 0:
                # Augment colorspace
                augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
@@ -727,29 +759,20 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        # if random.random() < hyp['gamma_liklihood']:
        #     # if img.dtype == np.uint16 or img.dtype == np.uint8:
        #     #     img = img/np.iinfo(img.dtype).max
        #     #     if img.max() > 1.0:
        #     #         warnings.warn("gamma correction operates over standartized images [0-1]!!!")
        #
        #     img = adjust_gamma(img, hyp['gamma'], gain=1)

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        if self.tir_channel_expansion:
            img = np.repeat(img[np.newaxis, :, :], 3, axis=0)  # convert GL to RGB by replication
            img_ce = np.zeros_like(img).astype('float64')

            # CH1 hist equalization
            img_chan = scaling_image(img[0, :, :], scaling_type=self.scaling_type,
                                     percentile=0, beta=self.beta)
            img_ce[0, :, :] = img_chan.astype('float64')

            img_chan = scaling_image(img[1, :, :], scaling_type=self.scaling_type,
                                     percentile=self.percentile, beta=self.beta)

            img_ce[1, :, :] = img_chan.astype('float64')

            img_chan = inversion_aug(img_ce[1, :, :])  # invert the DRC one
            img_ce[2, :, :] = img_chan.astype('float64')
            img = img_ce
            # tifffile.imwrite(os.path.join('/home/hanoch/projects/tir_od', 'img_ce.tiff'), 255*img.transpose(1,2,0).astype('uint8'))
-       else:
+       if not self.tir_channel_expansion:
            if self.is_tir_signal:
                img = np.repeat(img[np.newaxis, :, :], self.input_channels, axis=0)  # convert GL to RGB by replication
            else:
@@ -1395,6 +1418,32 @@ def pastein(image, labels, sample_labels, sample_images, sample_masks):

    return labels


import albumentations as A

class Albumentations_gamma_contrast:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self, alb_prob=0.01, gamma_limit=[80, 120]):
        self.transform = None

        self.transform = A.Compose([
            # A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=alb_prob),
            A.RandomGamma(gamma_limit=gamma_limit, p=alb_prob)])
            # A.Blur(p=0.01),
            # A.MedianBlur(p=0.01),
            # A.ToGray(p=0.01),
            # A.ImageCompression(quality_lower=75, p=0.01),],
            # bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

        # logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))

    def __call__(self, im, p=1.0):
        if self.transform and random.random() < p:
            new = self.transform(image=im)  # transformed
            im = new['image']
        return im

class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self):
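A usage sketch for the class above, with the values wired in by this commit (gamma: 80 and gamma_liklihood: 0.25, hence gamma_limit=[80, 120]); the input is assumed to be an HWC uint8 image at this point in the pipeline:

```python
import numpy as np

aug = Albumentations_gamma_contrast(alb_prob=0.25, gamma_limit=[80, 120])
im = (np.random.rand(64, 80, 3) * 255).astype(np.uint8)  # fake HWC image
im = aug(im)  # each transform fires with probability 0.25
```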
@@ -18,7 +18,7 @@ def gsutil_getsize(url=''):

def attempt_download(file, repo='WongKinYiu/yolov7'):
    # Attempt file download if does not exist
-   file = Path(str(file).strip().replace("'", '').lower())
+   file = Path(str(file).strip().replace("'", ''))

    if not file.exists():
        try:
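Why dropping .lower() matters (illustration, not from the diff): on a case-sensitive filesystem the lowercased name no longer matches the weights file on disk, so an existing checkpoint looks missing and triggers a re-download.

```python
from pathlib import Path

file = Path("Yolov7-TIR.pt".lower())  # -> 'yolov7-tir.pt' (hypothetical filename)
print(file.exists())                  # False even if 'Yolov7-TIR.pt' is present
```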