mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Update train.py (#2290)
* Update train.py * Update train.py * Update train.py * Update train.py * Create train.py
This commit is contained in:
parent
0070995bd5
commit
ca5b10b759
35
train.py
35
train.py
@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
|
||||
# Results
|
||||
if ckpt.get('training_results') is not None:
|
||||
with open(results_file, 'w') as file:
|
||||
file.write(ckpt['training_results']) # write results.txt
|
||||
results_file.write_text(ckpt['training_results']) # write results.txt
|
||||
|
||||
# Epochs
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
||||
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
|
||||
if len(opt.name) and opt.bucket:
|
||||
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
||||
|
||||
@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
best_fitness = fi
|
||||
|
||||
# Save model
|
||||
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
||||
if save:
|
||||
with open(results_file, 'r') as f: # create checkpoint
|
||||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': f.read(),
|
||||
'model': ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict(),
|
||||
'wandb_id': wandb_run.id if wandb else None}
|
||||
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
|
||||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': results_file.read_text(),
|
||||
'model': ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict(),
|
||||
'wandb_id': wandb_run.id if wandb else None}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, last)
|
||||
@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
if rank in [-1, 0]:
|
||||
# Strip optimizers
|
||||
final = best if best.exists() else last # final model
|
||||
for f in [last, best]:
|
||||
for f in last, best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
strip_optimizer(f)
|
||||
if opt.bucket:
|
||||
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
||||
|
||||
@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
# Test best.pt
|
||||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
||||
for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
|
||||
for m in (last, best) if best.exists() else (last): # speed, mAP tests
|
||||
results, _, _ = test.test(opt.data,
|
||||
batch_size=batch_size * 2,
|
||||
imgsz=imgsz_test,
|
||||
conf_thres=conf,
|
||||
iou_thres=iou,
|
||||
model=attempt_load(final, device).half(),
|
||||
conf_thres=0.001,
|
||||
iou_thres=0.7,
|
||||
model=attempt_load(m, device).half(),
|
||||
single_cls=opt.single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=save_dir,
|
||||
save_json=save_json,
|
||||
save_json=True,
|
||||
plots=False)
|
||||
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user