mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Improved W&B integration (#2125)
* Init Commit * new wandb integration * Update * Use data_dict in test * Updates * Update: scope of log_img * Update: scope of log_img * Update * Update: Fix logging conditions * Add tqdm bar, support for .txt dataset format * Improve Result table Logger * Init Commit * new wandb integration * Update * Use data_dict in test * Updates * Update: scope of log_img * Update: scope of log_img * Update * Update: Fix logging conditions * Add tqdm bar, support for .txt dataset format * Improve Result table Logger * Add dataset creation in training script * Change scope: self.wandb_run * Add wandb-artifact:// natively you can now use --resume with wandb run links * Add suuport for logging dataset while training * Cleanup * Fix: Merge conflict * Fix: CI tests * Automatically use wandb config * Fix: Resume * Fix: CI * Enhance: Using val_table * More resume enhancement * FIX : CI * Add alias * Get useful opt config data * train.py cleanup * Cleanup train.py * more cleanup * Cleanup| CI fix * Reformat using PEP8 * FIX:CI * rebase * remove uneccesary changes * remove uneccesary changes * remove uneccesary changes * remove unecessary chage from test.py * FIX: resume from local checkpoint * FIX:resume * FIX:resume * Reformat * Performance improvement * Fix local resume * Fix local resume * FIX:CI * Fix: CI * Imporve image logging * (:(:Redo CI tests:):) * Remember epochs when resuming * Remember epochs when resuming * Update DDP location Potential fix for #2405 * PEP8 reformat * 0.25 confidence threshold * reset train.py plots syntax to previous * reset epochs completed syntax to previous * reset space to previous * remove brackets * reset comment to previous * Update: is_coco check, remove unused code * Remove redundant print statement * Remove wandb imports * remove dsviz logger from test.py * Remove redundant change from test.py * remove redundant changes from train.py * reformat and improvements * Fix typo * Add tqdm tqdm progress when scanning files, naming improvements Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
ed2c74218d
commit
e8fc97aa38
@ -278,7 +278,7 @@ class Detections:
|
||||
def print(self):
|
||||
self.display(pprint=True) # print results
|
||||
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
|
||||
tuple(self.t))
|
||||
tuple(self.t))
|
||||
|
||||
def show(self):
|
||||
self.display(show=True) # show results
|
||||
|
49
test.py
49
test.py
@ -35,8 +35,9 @@ def test(data,
|
||||
save_hybrid=False, # for hybrid auto-labelling
|
||||
save_conf=False, # save auto-label confidences
|
||||
plots=True,
|
||||
log_imgs=0, # number of logged images
|
||||
compute_loss=None):
|
||||
wandb_logger=None,
|
||||
compute_loss=None,
|
||||
is_coco=False):
|
||||
# Initialize/load model and set device
|
||||
training = model is not None
|
||||
if training: # called by train.py
|
||||
@ -66,21 +67,19 @@ def test(data,
|
||||
|
||||
# Configure
|
||||
model.eval()
|
||||
is_coco = data.endswith('coco.yaml') # is COCO dataset
|
||||
with open(data) as f:
|
||||
data = yaml.load(f, Loader=yaml.SafeLoader) # model dict
|
||||
if isinstance(data, str):
|
||||
is_coco = data.endswith('coco.yaml')
|
||||
with open(data) as f:
|
||||
data = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
check_dataset(data) # check
|
||||
nc = 1 if single_cls else int(data['nc']) # number of classes
|
||||
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
||||
niou = iouv.numel()
|
||||
|
||||
# Logging
|
||||
log_imgs, wandb = min(log_imgs, 100), None # ceil
|
||||
try:
|
||||
import wandb # Weights & Biases
|
||||
except ImportError:
|
||||
log_imgs = 0
|
||||
|
||||
log_imgs = 0
|
||||
if wandb_logger and wandb_logger.wandb:
|
||||
log_imgs = min(wandb_logger.log_imgs, 100)
|
||||
# Dataloader
|
||||
if not training:
|
||||
if device.type != 'cpu':
|
||||
@ -147,15 +146,17 @@ def test(data,
|
||||
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
# W&B logging
|
||||
if plots and len(wandb_images) < log_imgs:
|
||||
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
||||
"class_id": int(cls),
|
||||
"box_caption": "%s %.3f" % (names[cls], conf),
|
||||
"scores": {"class_score": conf},
|
||||
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
|
||||
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
|
||||
wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
|
||||
# W&B logging - Media Panel Plots
|
||||
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation
|
||||
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
|
||||
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
||||
"class_id": int(cls),
|
||||
"box_caption": "%s %.3f" % (names[cls], conf),
|
||||
"scores": {"class_score": conf},
|
||||
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
|
||||
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
|
||||
wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
|
||||
wandb_logger.log_training_progress(predn, path, names) # logs dsviz tables
|
||||
|
||||
# Append to pycocotools JSON dictionary
|
||||
if save_json:
|
||||
@ -239,9 +240,11 @@ def test(data,
|
||||
# Plots
|
||||
if plots:
|
||||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
||||
if wandb and wandb.run:
|
||||
val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
|
||||
wandb.log({"Images": wandb_images, "Validation": val_batches}, commit=False)
|
||||
if wandb_logger and wandb_logger.wandb:
|
||||
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
|
||||
wandb_logger.log({"Validation": val_batches})
|
||||
if wandb_images:
|
||||
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})
|
||||
|
||||
# Save JSON
|
||||
if save_json and len(jdict):
|
||||
|
116
train.py
116
train.py
@ -1,3 +1,4 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
@ -33,11 +34,12 @@ from utils.google_utils import attempt_download
|
||||
from utils.loss import ComputeLoss
|
||||
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
||||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
||||
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
def train(hyp, opt, device, tb_writer=None):
|
||||
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
||||
@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
init_seeds(2 + rank)
|
||||
with open(opt.data) as f:
|
||||
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
||||
with torch_distributed_zero_first(rank):
|
||||
check_dataset(data_dict) # check
|
||||
train_path = data_dict['train']
|
||||
test_path = data_dict['val']
|
||||
is_coco = opt.data.endswith('coco.yaml')
|
||||
|
||||
# Logging- Doing this before checking the dataset. Might update data_dict
|
||||
if rank in [-1, 0]:
|
||||
opt.hyp = hyp # add hyperparameters
|
||||
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
||||
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
|
||||
data_dict = wandb_logger.data_dict
|
||||
if wandb_logger.wandb:
|
||||
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
||||
loggers = {'wandb': wandb_logger.wandb} # loggers dict
|
||||
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
||||
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
||||
@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
||||
else:
|
||||
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
||||
with torch_distributed_zero_first(rank):
|
||||
check_dataset(data_dict) # check
|
||||
train_path = data_dict['train']
|
||||
test_path = data_dict['val']
|
||||
|
||||
# Freeze
|
||||
freeze = [] # parameter names to freeze (full or partial)
|
||||
@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
||||
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
||||
|
||||
# Logging
|
||||
if rank in [-1, 0] and wandb and wandb.run is None:
|
||||
opt.hyp = hyp # add hyperparameters
|
||||
wandb_run = wandb.init(config=opt, resume="allow",
|
||||
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
||||
name=save_dir.stem,
|
||||
entity=opt.entity,
|
||||
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
|
||||
loggers = {'wandb': wandb} # loggers dict
|
||||
|
||||
# EMA
|
||||
ema = ModelEMA(model) if rank in [-1, 0] else None
|
||||
|
||||
@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
# if tb_writer:
|
||||
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||
elif plots and ni == 10 and wandb:
|
||||
wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
|
||||
if x.exists()]}, commit=False)
|
||||
elif plots and ni == 10 and wandb_logger.wandb:
|
||||
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
|
||||
save_dir.glob('train*.jpg') if x.exists()]})
|
||||
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
@ -343,8 +346,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
results, maps, times = test.test(opt.data,
|
||||
batch_size=batch_size * 2,
|
||||
wandb_logger.current_epoch = epoch + 1
|
||||
results, maps, times = test.test(data_dict,
|
||||
batch_size=total_batch_size,
|
||||
imgsz=imgsz_test,
|
||||
model=ema.ema,
|
||||
single_cls=opt.single_cls,
|
||||
@ -352,8 +356,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
save_dir=save_dir,
|
||||
verbose=nc < 50 and final_epoch,
|
||||
plots=plots and final_epoch,
|
||||
log_imgs=opt.log_imgs if wandb else 0,
|
||||
compute_loss=compute_loss)
|
||||
wandb_logger=wandb_logger,
|
||||
compute_loss=compute_loss,
|
||||
is_coco=is_coco)
|
||||
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
||||
if tb_writer:
|
||||
tb_writer.add_scalar(tag, x, epoch) # tensorboard
|
||||
if wandb:
|
||||
wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
|
||||
if wandb_logger.wandb:
|
||||
wandb_logger.log({tag: x}) # W&B
|
||||
|
||||
# Update best mAP
|
||||
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
|
||||
@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
'ema': deepcopy(ema.ema).half(),
|
||||
'updates': ema.updates,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'wandb_id': wandb_run.id if wandb else None}
|
||||
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, last)
|
||||
if best_fitness == fi:
|
||||
torch.save(ckpt, best)
|
||||
if wandb_logger.wandb:
|
||||
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
||||
wandb_logger.log_model(
|
||||
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
||||
del ckpt
|
||||
|
||||
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
||||
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
# end training
|
||||
|
||||
if rank in [-1, 0]:
|
||||
# Strip optimizers
|
||||
final = best if best.exists() else last # final model
|
||||
for f in last, best:
|
||||
if f.exists():
|
||||
strip_optimizer(f)
|
||||
if opt.bucket:
|
||||
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
||||
|
||||
# Plots
|
||||
if plots:
|
||||
plot_results(save_dir=save_dir) # save as results.png
|
||||
if wandb:
|
||||
if wandb_logger.wandb:
|
||||
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
||||
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
||||
if (save_dir / f).exists()]})
|
||||
if opt.log_artifacts:
|
||||
wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
|
||||
|
||||
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
|
||||
if (save_dir / f).exists()]})
|
||||
# Test best.pt
|
||||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
||||
@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
dataloader=testloader,
|
||||
save_dir=save_dir,
|
||||
save_json=True,
|
||||
plots=False)
|
||||
plots=False,
|
||||
is_coco=is_coco)
|
||||
|
||||
# Strip optimizers
|
||||
final = best if best.exists() else last # final model
|
||||
for f in last, best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if opt.bucket:
|
||||
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
||||
if wandb_logger.wandb: # Log the stripped model
|
||||
wandb_logger.wandb.log_artifact(str(final), type='model',
|
||||
name='run_' + wandb_logger.wandb_run.id + '_model',
|
||||
aliases=['last', 'best', 'stripped'])
|
||||
else:
|
||||
dist.destroy_process_group()
|
||||
|
||||
wandb.run.finish() if wandb and wandb.run else None
|
||||
torch.cuda.empty_cache()
|
||||
wandb_logger.finish_run()
|
||||
return results
|
||||
|
||||
|
||||
@ -464,8 +473,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
|
||||
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
||||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
||||
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
|
||||
parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
|
||||
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
|
||||
parser.add_argument('--project', default='runs/train', help='save to project/name')
|
||||
parser.add_argument('--entity', default=None, help='W&B entity')
|
||||
@ -473,6 +480,10 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
||||
parser.add_argument('--quad', action='store_true', help='quad dataloader')
|
||||
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
|
||||
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
|
||||
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
||||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||
opt = parser.parse_args()
|
||||
|
||||
# Set DDP variables
|
||||
@ -484,7 +495,8 @@ if __name__ == '__main__':
|
||||
check_requirements()
|
||||
|
||||
# Resume
|
||||
if opt.resume: # resume an interrupted run
|
||||
wandb_run = resume_and_get_id(opt)
|
||||
if opt.resume and not wandb_run: # resume an interrupted run
|
||||
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
||||
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
||||
apriori = opt.global_rank, opt.local_rank
|
||||
@ -517,18 +529,12 @@ if __name__ == '__main__':
|
||||
|
||||
# Train
|
||||
logger.info(opt)
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
wandb = None
|
||||
prefix = colorstr('wandb: ')
|
||||
logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
|
||||
if not opt.evolve:
|
||||
tb_writer = None # init loggers
|
||||
if opt.global_rank in [-1, 0]:
|
||||
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
|
||||
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
||||
train(hyp, opt, device, tb_writer, wandb)
|
||||
train(hyp, opt, device, tb_writer)
|
||||
|
||||
# Evolve hyperparameters (optional)
|
||||
else:
|
||||
@ -602,7 +608,7 @@ if __name__ == '__main__':
|
||||
hyp[k] = round(hyp[k], 5) # significant digits
|
||||
|
||||
# Train mutation
|
||||
results = train(hyp.copy(), opt, device, wandb=wandb)
|
||||
results = train(hyp.copy(), opt, device)
|
||||
|
||||
# Write mutation results
|
||||
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
||||
|
@ -12,20 +12,7 @@ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
||||
def create_dataset_artifact(opt):
|
||||
with open(opt.data) as f:
|
||||
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
||||
logger = WandbLogger(opt, '', None, data, job_type='create_dataset')
|
||||
nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
|
||||
names = {k: v for k, v in enumerate(names)} # to index dictionary
|
||||
logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
|
||||
logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset
|
||||
|
||||
# Update data.yaml with artifact links
|
||||
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
|
||||
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val')
|
||||
path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path
|
||||
data.pop('download', None) # download via artifact instead of predefined field 'download:'
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
print("New Config file => ", path)
|
||||
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -33,7 +20,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
|
||||
parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml')
|
||||
opt = parser.parse_args()
|
||||
|
||||
create_dataset_artifact(opt)
|
||||
|
@ -1,13 +1,18 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import torch
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
|
||||
from utils.general import colorstr, xywh2xyxy
|
||||
from utils.datasets import LoadImagesAndLabels
|
||||
from utils.datasets import img2label_paths
|
||||
from utils.general import colorstr, xywh2xyxy, check_dataset
|
||||
|
||||
try:
|
||||
import wandb
|
||||
@ -22,87 +27,183 @@ def remove_prefix(from_string, prefix):
|
||||
return from_string[len(prefix):]
|
||||
|
||||
|
||||
def check_wandb_config_file(data_config_file):
|
||||
wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
|
||||
if Path(wandb_config).is_file():
|
||||
return wandb_config
|
||||
return data_config_file
|
||||
|
||||
|
||||
def resume_and_get_id(opt):
|
||||
# It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
|
||||
if isinstance(opt.resume, str):
|
||||
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
||||
run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
|
||||
run_id = run_path.stem
|
||||
project = run_path.parent.stem
|
||||
model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
|
||||
assert wandb, 'install wandb to resume wandb runs'
|
||||
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
|
||||
run = wandb.init(id=run_id, project=project, resume='allow')
|
||||
opt.resume = model_artifact_name
|
||||
return run
|
||||
return None
|
||||
|
||||
|
||||
class WandbLogger():
|
||||
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
|
||||
self.wandb = wandb
|
||||
self.wandb_run = wandb.init(config=opt, resume="allow",
|
||||
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
||||
name=name,
|
||||
job_type=job_type,
|
||||
id=run_id) if self.wandb else None
|
||||
# Pre-training routine --
|
||||
self.job_type = job_type
|
||||
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
|
||||
if self.wandb:
|
||||
self.wandb_run = wandb.init(config=opt,
|
||||
resume="allow",
|
||||
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
||||
name=name,
|
||||
job_type=job_type,
|
||||
id=run_id) if not wandb.run else wandb.run
|
||||
if self.job_type == 'Training':
|
||||
if not opt.resume:
|
||||
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
|
||||
# Info useful for resuming from artifacts
|
||||
self.wandb_run.config.opt = vars(opt)
|
||||
self.wandb_run.config.data_dict = wandb_data_dict
|
||||
self.data_dict = self.setup_training(opt, data_dict)
|
||||
if self.job_type == 'Dataset Creation':
|
||||
self.data_dict = self.check_and_upload_dataset(opt)
|
||||
|
||||
if job_type == 'Training':
|
||||
self.setup_training(opt, data_dict)
|
||||
if opt.bbox_interval == -1:
|
||||
opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
|
||||
if opt.save_period == -1:
|
||||
opt.save_period = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
|
||||
def check_and_upload_dataset(self, opt):
|
||||
assert wandb, 'Install wandb to upload dataset'
|
||||
check_dataset(self.data_dict)
|
||||
config_path = self.log_dataset_artifact(opt.data,
|
||||
opt.single_cls,
|
||||
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
|
||||
print("Created dataset config file ", config_path)
|
||||
with open(config_path) as f:
|
||||
wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
return wandb_data_dict
|
||||
|
||||
def setup_training(self, opt, data_dict):
|
||||
self.log_dict = {}
|
||||
self.train_artifact_path, self.trainset_artifact = \
|
||||
self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
|
||||
self.test_artifact_path, self.testset_artifact = \
|
||||
self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
|
||||
self.result_artifact, self.result_table, self.weights = None, None, None
|
||||
if self.train_artifact_path is not None:
|
||||
train_path = Path(self.train_artifact_path) / 'data/images/'
|
||||
data_dict['train'] = str(train_path)
|
||||
if self.test_artifact_path is not None:
|
||||
test_path = Path(self.test_artifact_path) / 'data/images/'
|
||||
data_dict['val'] = str(test_path)
|
||||
self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
|
||||
self.bbox_interval = opt.bbox_interval
|
||||
if isinstance(opt.resume, str):
|
||||
modeldir, _ = self.download_model_artifact(opt)
|
||||
if modeldir:
|
||||
self.weights = Path(modeldir) / "last.pt"
|
||||
config = self.wandb_run.config
|
||||
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
|
||||
self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
|
||||
config.opt['hyp']
|
||||
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
|
||||
if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
|
||||
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
|
||||
opt.artifact_alias)
|
||||
self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
|
||||
opt.artifact_alias)
|
||||
self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
|
||||
if self.train_artifact_path is not None:
|
||||
train_path = Path(self.train_artifact_path) / 'data/images/'
|
||||
data_dict['train'] = str(train_path)
|
||||
if self.val_artifact_path is not None:
|
||||
val_path = Path(self.val_artifact_path) / 'data/images/'
|
||||
data_dict['val'] = str(val_path)
|
||||
self.val_table = self.val_artifact.get("val")
|
||||
self.map_val_table_path()
|
||||
if self.val_artifact is not None:
|
||||
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
|
||||
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
|
||||
if opt.resume_from_artifact:
|
||||
modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
|
||||
if modeldir:
|
||||
self.weights = Path(modeldir) / "best.pt"
|
||||
opt.weights = self.weights
|
||||
if opt.bbox_interval == -1:
|
||||
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
|
||||
return data_dict
|
||||
|
||||
def download_dataset_artifact(self, path, alias):
|
||||
if path.startswith(WANDB_ARTIFACT_PREFIX):
|
||||
dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
|
||||
assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
|
||||
datadir = dataset_artifact.download()
|
||||
labels_zip = Path(datadir) / "data/labels.zip"
|
||||
shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
|
||||
print("Downloaded dataset to : ", datadir)
|
||||
return datadir, dataset_artifact
|
||||
return None, None
|
||||
|
||||
def download_model_artifact(self, name):
|
||||
model_artifact = wandb.use_artifact(name + ":latest")
|
||||
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
|
||||
modeldir = model_artifact.download()
|
||||
print("Downloaded model to : ", modeldir)
|
||||
return modeldir, model_artifact
|
||||
def download_model_artifact(self, opt):
|
||||
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
||||
model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
|
||||
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
|
||||
modeldir = model_artifact.download()
|
||||
epochs_trained = model_artifact.metadata.get('epochs_trained')
|
||||
total_epochs = model_artifact.metadata.get('total_epochs')
|
||||
assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
|
||||
total_epochs)
|
||||
return modeldir, model_artifact
|
||||
return None, None
|
||||
|
||||
def log_model(self, path, opt, epoch):
|
||||
datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
def log_model(self, path, opt, epoch, fitness_score, best_model=False):
|
||||
model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
|
||||
'original_url': str(path),
|
||||
'epoch': epoch + 1,
|
||||
'epochs_trained': epoch + 1,
|
||||
'save period': opt.save_period,
|
||||
'project': opt.project,
|
||||
'datetime': datetime_suffix
|
||||
'total_epochs': opt.epochs,
|
||||
'fitness_score': fitness_score
|
||||
})
|
||||
model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
|
||||
model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
|
||||
wandb.log_artifact(model_artifact)
|
||||
wandb.log_artifact(model_artifact,
|
||||
aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
|
||||
print("Saving model artifact on epoch ", epoch + 1)
|
||||
|
||||
def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
|
||||
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
|
||||
with open(data_file) as f:
|
||||
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
||||
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
|
||||
names = {k: v for k, v in enumerate(names)} # to index dictionary
|
||||
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
|
||||
data['train']), names, name='train') if data.get('train') else None
|
||||
self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
|
||||
data['val']), names, name='val') if data.get('val') else None
|
||||
if data.get('train'):
|
||||
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
|
||||
if data.get('val'):
|
||||
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
|
||||
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
|
||||
data.pop('download', None)
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
if self.job_type == 'Training': # builds correct artifact pipeline graph
|
||||
self.wandb_run.use_artifact(self.val_artifact)
|
||||
self.wandb_run.use_artifact(self.train_artifact)
|
||||
self.val_artifact.wait()
|
||||
self.val_table = self.val_artifact.get('val')
|
||||
self.map_val_table_path()
|
||||
else:
|
||||
self.wandb_run.log_artifact(self.train_artifact)
|
||||
self.wandb_run.log_artifact(self.val_artifact)
|
||||
return path
|
||||
|
||||
def map_val_table_path(self):
|
||||
self.val_table_map = {}
|
||||
print("Mapping dataset")
|
||||
for i, data in enumerate(tqdm(self.val_table.data)):
|
||||
self.val_table_map[data[3]] = data[0]
|
||||
|
||||
def create_dataset_table(self, dataset, class_to_id, name='dataset'):
|
||||
# TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
|
||||
artifact = wandb.Artifact(name=name, type="dataset")
|
||||
image_path = dataset.path
|
||||
artifact.add_dir(image_path, name='data/images')
|
||||
table = wandb.Table(columns=["id", "train_image", "Classes"])
|
||||
for img_file in tqdm([dataset.path]) if Path(dataset.path).is_dir() else tqdm(dataset.img_files):
|
||||
if Path(img_file).is_dir():
|
||||
artifact.add_dir(img_file, name='data/images')
|
||||
labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
|
||||
artifact.add_dir(labels_path, name='data/labels')
|
||||
else:
|
||||
artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
|
||||
label_file = Path(img2label_paths([img_file])[0])
|
||||
artifact.add_file(str(label_file),
|
||||
name='data/labels/' + label_file.name) if label_file.exists() else None
|
||||
table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
|
||||
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
|
||||
for si, (img, labels, paths, shapes) in enumerate(dataset):
|
||||
for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
|
||||
height, width = shapes[0]
|
||||
labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
|
||||
labels[:, 2:] *= torch.Tensor([width, height, width, height])
|
||||
box_data = []
|
||||
img_classes = {}
|
||||
labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
|
||||
box_data, img_classes = [], {}
|
||||
for cls, *xyxy in labels[:, 1:].tolist():
|
||||
cls = int(cls)
|
||||
box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
||||
@ -112,34 +213,52 @@ class WandbLogger():
|
||||
"domain": "pixel"})
|
||||
img_classes[cls] = class_to_id[cls]
|
||||
boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
|
||||
table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
|
||||
table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
|
||||
Path(paths).name)
|
||||
artifact.add(table, name)
|
||||
labels_path = 'labels'.join(image_path.rsplit('images', 1))
|
||||
zip_path = Path(labels_path).parent / (name + '_labels.zip')
|
||||
if not zip_path.is_file(): # make_archive won't check if file exists
|
||||
shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
|
||||
artifact.add_file(str(zip_path), name='data/labels.zip')
|
||||
wandb.log_artifact(artifact)
|
||||
print("Saving data to W&B...")
|
||||
return artifact
|
||||
|
||||
def log_training_progress(self, predn, path, names):
|
||||
if self.val_table and self.result_table:
|
||||
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
|
||||
box_data = []
|
||||
total_conf = 0
|
||||
for *xyxy, conf, cls in predn.tolist():
|
||||
if conf >= 0.25:
|
||||
box_data.append(
|
||||
{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
||||
"class_id": int(cls),
|
||||
"box_caption": "%s %.3f" % (names[cls], conf),
|
||||
"scores": {"class_score": conf},
|
||||
"domain": "pixel"})
|
||||
total_conf = total_conf + conf
|
||||
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
|
||||
id = self.val_table_map[Path(path).name]
|
||||
self.result_table.add_data(self.current_epoch,
|
||||
id,
|
||||
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
|
||||
total_conf / max(1, len(box_data))
|
||||
)
|
||||
|
||||
def log(self, log_dict):
|
||||
if self.wandb_run:
|
||||
for key, value in log_dict.items():
|
||||
self.log_dict[key] = value
|
||||
|
||||
def end_epoch(self):
|
||||
if self.wandb_run and self.log_dict:
|
||||
def end_epoch(self, best_result=False):
|
||||
if self.wandb_run:
|
||||
wandb.log(self.log_dict)
|
||||
self.log_dict = {}
|
||||
self.log_dict = {}
|
||||
if self.result_artifact:
|
||||
train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
|
||||
self.result_artifact.add(train_results, 'result')
|
||||
wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
|
||||
('best' if best_result else '')])
|
||||
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
|
||||
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
|
||||
|
||||
def finish_run(self):
|
||||
if self.wandb_run:
|
||||
if self.result_artifact:
|
||||
print("Add Training Progress Artifact")
|
||||
self.result_artifact.add(self.result_table, 'result')
|
||||
train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
|
||||
self.result_artifact.add(train_results, 'joined_result')
|
||||
wandb.log_artifact(self.result_artifact)
|
||||
if self.log_dict:
|
||||
wandb.log(self.log_dict)
|
||||
wandb.run.finish()
|
||||
|
Loading…
x
Reference in New Issue
Block a user