import argparse
import logging
import math
import os
import random
import re
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
try:
    # from yolov7_main.models.common import Conv, DWConv
    # from yolov7_main.utils.google_utils import attempt_download
    from models.experimental import attempt_load
except ImportError:
    print(100 * '==')
    print(os.getcwd())
    import sys
    sys.path.append('/home/hanoch/projects/tir_od')
    from tir_od.yolov7.models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)

from clearml import Task, Logger

task = Task.init(
    project_name="TIR_OD",
    task_name="train yolov7 with dummy test"
)

gradient_clip_value = 100.0


def find_clipped_gradient_within_layer(model, gradient_clip_value):
    # Diagnostic: when the total gradient norm exceeds the clipping threshold,
    # report which layer carries the largest per-layer gradient norm.
    # clip_grad_norm_() is used here only to obtain the total norm; its
    # threshold equals the monitored value, so it clips at the same point.
    margin_from_sum_abs = 1 / 3
    total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)
    if total_grad_norm > gradient_clip_value:
        max_grad_temp = -100.0
        name_grad_temp = 'None'
        for name, param in model.named_parameters():
            if param.grad is not None:
                norm_layer = torch.unsqueeze(torch.norm(param.grad.detach(), 2.0), 0)
                for u in norm_layer:
                    if (u > gradient_clip_value * margin_from_sum_abs).any():
                        if u[u > gradient_clip_value * margin_from_sum_abs] > max_grad_temp:
                            max_grad_temp = u[u > gradient_clip_value * margin_from_sum_abs]
                            name_grad_temp = name
        print("layer {} with max gradient {}".format(name_grad_temp, max_grad_temp))


def compare_models_basic(model1, model2):
    # Return False on the first parameter tensor that differs between the two models.
    for ix, (p1, p2) in enumerate(zip(model1.parameters(), model2.parameters())):
        if p1.data.ne(p2.data).sum() > 0:
            print('Models are different', ix, p1.data.ne(p2.data).sum())
            return False
    return True


def compare_models(model1, model2):
    # Iterate through named layers and parameters of both models
    for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()):
        if name1 != name2:
            print(f"Layer names differ: {name1} vs {name2}")
        # Compare the parameters
        if not torch.equal(param1, param2):
            print('Difference found in layer {} {}'.format(name1, param1.data.ne(param2.data).sum()))
            return
    # print("No differences found in any layer.")
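
# Assumed debugging workflow for the helpers above: verify that the EMA weights
# have actually drifted from the live model during training, e.g.
#   compare_models(ema.ema, model)        # reports first differing layers
#   compare_models_basic(model, model2)   # True only if all parameters match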


def train(hyp, opt, device, tb_writer=None):
    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)
    # Compare (major, minor) tuples; concatenating the digits (e.g. '2.4.0' -> 240)
    # would misorder versions such as 2.13 vs 2.4.
    _torch_ver = re.search(r'(\d+)\.(\d+)', torch.__version__)
    is_torch_240 = (int(_torch_ver.group(1)), int(_torch_ver.group(2))) >= (2, 4)
    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    is_coco = opt.data.endswith('coco.yaml')
    with open(save_dir / 'data.yaml', 'w') as f:
        yaml.dump(data_dict, f, sort_keys=False)

    # Logging- Doing this before checking the dataset. Might update data_dict
    loggers = {'wandb': None}  # loggers dict
    if rank in [-1, 0]:
        opt.hyp = hyp  # add hyperparameters
        run_id = torch.load(weights, map_location=device).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
        loggers['wandb'] = wandb_logger.wandb
        data_dict = wandb_logger.data_dict
        if wandb_logger.wandb:
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming

    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(opt.cfg or ckpt['model'].yaml, ch=opt.input_channels, nc=nc, anchors=hyp.get('anchors')).to(device)  # create model structure according to the yaml rather than the checkpoint
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=opt.input_channels, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    images_parent_folder = data_dict['path']

    # Freeze
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases (weight decay is set to zero below)
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay
        # implicit-knowledge modules (ImplicitA/ImplicitM variants): no decay
        for attr in ('im', 'imc', 'imb', 'imo', 'ia'):
            if hasattr(v, attr):
                member = getattr(v, attr)
                if hasattr(member, 'implicit'):
                    pg0.append(member.implicit)
                else:
                    pg0.extend(iv.implicit for iv in member)
        if hasattr(v, 'attn'):
            for attr in ('logit_scale', 'q_bias', 'v_bias', 'relative_position_bias_table'):
                if hasattr(v.attn, attr):
                    pg0.append(getattr(v.attn, attr))
        if hasattr(v, 'rbr_dense'):
            for attr in ('weight_rbr_origin', 'weight_rbr_avg_conv', 'weight_rbr_pfir_conv',
                         'weight_rbr_1x1_kxk_idconv1', 'weight_rbr_1x1_kxk_conv2',
                         'weight_rbr_gconv_dw', 'weight_rbr_gconv_pw', 'vector'):
                if hasattr(v.rbr_dense, attr):
                    pg0.append(getattr(v.rbr_dense, attr))

    if opt.adam:  # @@ HK: AdamW() is a fix for Adam's coupled weight-decay/L2 loss bug
        optimizer = optim.AdamW(pg0, lr=hyp['lr0'], weight_decay=0, betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], weight_decay=0, momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2, 'weight_decay': 0})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))

    # validate that we considered every parameter
    # param_dict = {pn: p for pn, p in model.named_parameters()}
    # inter_params = set(pg0) & set(pg1) & set(pg2)
    # union_params = set(pg0) | set(pg1) | set(pg2)
    # assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
    # assert len(param_dict.keys() - union_params) == 0, \
    #     "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)

    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
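    # For reference, one_cycle(1, lrf, epochs) from utils.general is a cosine ramp:
    # lf(x) = (1 - cos(x * pi / epochs)) / 2 * (lrf - 1) + 1, i.e. it starts at 1
    # for epoch 0 and decays smoothly to hyp['lrf'] at the final epoch.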
    # plot_lr_scheduler(optimizer, scheduler, epochs)
    # from utils.plots import plot_lr_scheduler
    # plot_lr_scheduler(optimizer, scheduler, epochs, save_dir='/home/hanoch/projects/tir_od')

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None
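    # ModelEMA keeps an exponential moving average of the weights; the EMA copy
    # (ema.ema) is what gets evaluated and checkpointed below, since averaged
    # weights typically generalize slightly better than the raw final weights.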

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Results
        if ckpt.get('training_results') is not None:
            results_file.write_text(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '),
                                            rel_path_images=images_parent_folder, num_cls=data_dict['nc'])
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        testloader, test_dataset = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                                     hyp=hyp, cache=opt.cache_images and not opt.notest, rect=False, rank=-1,  # @@@ rect was True, why?
                                                     world_size=opt.world_size, workers=opt.workers,
                                                     pad=0.5, prefix=colorstr('val: '),
                                                     rel_path_images=images_parent_folder, num_cls=data_dict['nc'])
        mlc = np.concatenate(test_dataset.labels, 0)[:, 0].max()  # max label class
        assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                # plot_labels(labels, names, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision (the original `if opt.amp or 1` gate was always true)

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
                    # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
                    find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    if hyp['warmup_epochs'] != 0:  # otherwise warmup would be forced to 1000 iterations
        nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    else:
        nw = 0
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    # AMP is currently forced on whenever CUDA is available; the opt.amp-gated,
    # version-aware variant is kept below for reference.
    scaler = amp.GradScaler(enabled=cuda)
    # scaler = torch.amp.GradScaler("cuda", enabled=opt.amp) if is_torch_240 else torch.cuda.amp.GradScaler(enabled=opt.amp)

    compute_loss_ota = ComputeLossOTA(model)  # init loss class
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    if not opt.nosave:
        torch.save(model, wdir / 'init.pt')

    # from pympler import tracker
    # the_tracker = tracker.SummaryTracker()
    # the_tracker.print_diff()
    torch.autograd.set_detect_anomaly(True)  # HK TODO: remove later; anomaly detection slows training considerably

    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            # imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0 @@HK TODO is that standardization?
            imgs = imgs.to(device, non_blocking=True).float()  # /255 scaling disabled; normalization is assumed to be handled by the dataloader per --norm-type

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
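                # Worked example (values assumed): with nw=1000 and nbs/total_batch_size=4,
                # `accumulate` ramps linearly from 1 to 4 over the first 1000 integrated
                # batches while each group's lr is interpolated toward initial_lr * lf(epoch).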

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + gs)) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                # with amp.autocast(enabled=(cuda and opt.amp)):
                pred = model(imgs)  # forward
                if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
                    loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
                else:
                    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # HK TODO: https://discuss.pytorch.org/t/switching-between-mixed-precision-training-and-full-precision-training-after-training-is-started/132366/4 remove scaler backwards
            # Backward
            scaler.scale(loss).backward()

            # Gradient-norm monitoring: clipping only kicks in when the total norm
            # exceeds gradient_clip_value, so in practice this mostly just monitors.
            # find_clipped_gradient_within_layer(model, gradient_clip_value)
            if ni > nw and rank in [-1, 0]:
                total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)
                if tb_writer:
                    tb_writer.add_scalar('Grad norm', total_grad_norm, ni)
                # if total_grad_norm > gradient_clip_value:
                #     print("Gradient {} was clipped to {}".format(total_grad_norm, gradient_clip_value))

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 10:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f, opt.input_channels), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  # add model graph
                elif plots and ni == 10 and wandb_logger.wandb:
                    wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                  save_dir.glob('train*.jpg') if x.exists()]})
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        # print("Lr : ", 10 * '+', lr)
        scheduler.step()

        plots = True  # @@ HK: force plotting (the original `if 1:` gate was always true)

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                wandb_logger.current_epoch = epoch + 1
                results, maps, times = test.test(data_dict,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 save_json=opt.save_json,
                                                 model=ema.ema,
                                                 iou_thres=hyp['iou_t'],
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 wandb_logger=wandb_logger,
                                                 compute_loss=compute_loss,
                                                 is_coco=is_coco,
                                                 v5_metric=opt.v5_metric)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb_logger.wandb:
                    wandb_logger.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
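            # fitness() in utils.general weights these metrics as [0.0, 0.0, 0.1, 0.9],
            # so checkpoint selection is driven almost entirely by mAP@.5:.95 with a
            # small mAP@.5 contribution.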
            if fi > best_fitness:
                best_fitness = fi
            wandb_logger.end_epoch(best_result=best_fitness == fi)

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),  # HK TODO: half() is only appropriate when AMP is enabled
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (best_fitness == fi) and (epoch >= 200):
                    torch.save(ckpt, wdir / 'best_{:03d}.pt'.format(epoch))
                if epoch == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif ((epoch + 1) % 25) == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif epoch >= (epochs - 5):
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                if wandb_logger.wandb:
                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                        wandb_logger.log_model(
                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb_logger.wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                              if (save_dir / f).exists()]})
        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for m in (last, best) if best.exists() else (last,):  # speed, mAP tests; note the one-element tuple
                results, _, _ = test.test(opt.data,
                                          batch_size=batch_size * 2,
                                          imgsz=imgsz_test,
                                          conf_thres=0.001,
                                          iou_thres=0.7,
                                          model=attempt_load(m, device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=True,
                                          plots=False,
                                          is_coco=is_coco,
                                          v5_metric=opt.v5_metric)

        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
        if wandb_logger.wandb and not opt.evolve:  # Log the stripped model
            wandb_logger.wandb.log_artifact(str(final), type='model',
                                            name='run_' + wandb_logger.wandb_run.id + '_model',
                                            aliases=['last', 'best', 'stripped'])
        wandb_logger.finish_run()
    else:
        dist.destroy_process_group()
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolo7.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--save-json', action='store_true', help='save test results to a COCO-compatible JSON file')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.AdamW() optimizer (Adam with decoupled weight decay)')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
    parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation')
    parser.add_argument('--norm-type', type=str, default='standardization',
                        choices=['standardization', 'single_image_0_to_1', 'single_image_mean_std', 'single_image_percentile_0_255',
                                 'single_image_percentile_0_1', 'remove+global_outlier_0_1'],
                        help='normalization approach')
    parser.add_argument('--no-tir-signal', action='store_true', help='')
    parser.add_argument('--tir-channel-expansion', action='store_true', help='expand single-channel TIR input to 3 channels (drc_per_ch_percentile)')
    parser.add_argument('--input-channels', type=int, default=3, help='number of input channels')
    parser.add_argument('--save-path', default='/mnt/Data/hanoch', help='base directory for saving runs (currently unused; see the commented save_dir logic below)')
    parser.add_argument('--gamma-aug-prob', type=float, default=0.1, help='probability of gamma augmentation')
    parser.add_argument('--amp', action='store_true', help='enable torch AMP mixed precision (currently AMP is forced on; see scaler setup)')
    opt = parser.parse_args()

    if opt.tir_channel_expansion:  # operates over 3 channels
        opt.input_channels = 3
    if opt.tir_channel_expansion and opt.norm_type != 'single_image_percentile_0_1':  # operates over 3 channels
        print('Warning: --tir-channel-expansion is intended to be combined with --norm-type single_image_percentile_0_1')

    # Set DDP variables
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    # if opt.global_rank in [-1, 0]:
    #     check_git_status()
    #     check_requirements()

    # Resume
    wandb_run = check_wandb_resume(opt)
    if opt.resume and not wandb_run:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # replace
        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        # if opt.save_path == '':
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run
        # else:
        #     opt.save_dir = os.path.join(opt.save_path,
        #                                 increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))

    # DDP mode
    opt.total_batch_size = opt.batch_size
    device = select_device(opt.device, batch_size=opt.batch_size)
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        opt.batch_size = opt.total_batch_size // opt.world_size

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

    # Train
    logger.info(opt)
    if not opt.evolve:
        tb_writer = None  # init loggers
        if opt.global_rank in [-1, 0]:
            prefix = colorstr('tensorboard: ')
            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
        train(hyp, opt, device, tb_writer)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0),  # image mixup (probability)
                'copy_paste': (1, 0.0, 1.0),  # segment copy-paste (probability)
                'paste_in': (1, 0.0, 1.0)}  # segment copy-paste (probability)

        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([x[0] for x in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')

"""
TODO
Anchors:
hyp['anchor_t'] = 4 allows aspect ratio AR <= 4 => TODO check if valid
Anchors were reduced to 2 per level: anchors: 2
Sampler: torch_weighted: WeightedRandomSampler

PP-YOLO bumps the batch size up from 64 to 192. Of course, this is hard to implement if you have GPU memory constraints.

****** DON'T FORGET to delete cache files upon changing data ************

python train.py --workers 8 --device 'cpu' --batch-size 32 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'v7' --name yolov7 --hyp data/hyp.scratch.p5.yaml

--workers 8 --device cpu --batch-size 32 --data data/tir_od.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'v7' --name yolov7 --cache-images --hyp data/hyp.tir_od.tiny.yaml --adam --norm-type single_image_percentile_0_1
--workers 8 --device cpu --batch-size 32 --data data/tir_od.yaml --img 640 640 --cfg cfg/training/yolov7-tiny.yaml --weights 'v7' --name yolov7 --cache-images --hyp data/hyp.tir_od.tiny.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 1 --multi-scale

--multi-scale training with resized image resolution is not good for TIR

Training based on a given model, without a prototype yaml passed via --cfg:
--workers 8 --device 0 --batch-size 16 --data data/coco_2_tir.yaml --img 640 640 --weights ./yolov7/yolov7.pt --name yolov7 --hyp data/hyp.tir_od.tiny.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 3 --linear-lr --noautoanchor
--workers 8 --device 0 --batch-size 16 --data data/tir_od.yaml --img 640 640 --weights ./yolov7/yolov7-tiny.pt --name yolov7 --hyp data/hyp.tir_od.tiny.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 3 --linear-lr --noautoanchor
===========================================================================
Fine-tuning: you need the --cfg of the arch yaml because nc (number of classes) changes
--workers 8 --device 0 --batch-size 16 --data data/tir_od.yaml --img 640 640 --weights ./yolov7/yolov7-tiny.pt --cfg cfg/training/yolov7-tiny.yaml --name yolov7 --hyp data/hyp.tir_od.tiny.yaml --adam --norm-type single_image_percentile_0_1 --input-channels 3 --linear-lr
--workers 8 --device 0 --batch-size 16 --data data/tir_od.yaml --img 640 640 --weights ./yolov7/yolov7-tiny.pt --cfg cfg/training/yolov7-tiny.yaml --name yolov7 --hyp hyp.tir_od.tiny_aug.yaml --adam --norm-type single_image_mean_std --input-channels 3 --linear-lr --epochs 2
"""