W&B logging add hyperparameters (#1399)
* W&B logging add hyperparameters
* hyp bug fix and image logging updates
* if plots and wandb:
* cleanup
* wandb/ gitignore add
* cleanup 2
* cleanup 3
* move wandb import to top of file
* wandb evolve
* update import
* wandb.run.finish()
* default anchors: 3

pull/1315/head^2
parent
b7007d03b4
commit
9c91aeae10
|
@ -79,9 +79,11 @@ sdist/
|
|||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
wandb/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
|
|
|
@ -17,7 +17,7 @@ obj: 1.0 # obj loss gain (scale with pixels)
|
|||
obj_pw: 1.0 # obj BCELoss positive_weight
|
||||
iou_t: 0.20 # IoU training threshold
|
||||
anchor_t: 4.0 # anchor-multiple threshold
|
||||
# anchors: 0 # anchors per output grid (0 to ignore)
|
||||
# anchors: 3 # anchors per output layer (0 to ignore)
|
||||
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
||||
|
|
30
test.py
30
test.py
|
@ -75,7 +75,7 @@ def test(data,
|
|||
niou = iouv.numel()
|
||||
|
||||
# Logging
|
||||
log_imgs = min(log_imgs, 100) # ceil
|
||||
log_imgs, wandb = min(log_imgs, 100), None # ceil
|
||||
try:
|
||||
import wandb # Weights & Biases
|
||||
except ImportError:
|
||||
|
@ -132,6 +132,7 @@ def test(data,
|
|||
continue
|
||||
|
||||
# Append to text file
|
||||
path = Path(paths[si])
|
||||
if save_txt:
|
||||
gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh
|
||||
x = pred.clone()
|
||||
|
@ -139,18 +140,18 @@ def test(data,
|
|||
for *xyxy, conf, cls in x:
|
||||
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
||||
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
|
||||
with open(str(save_dir / 'labels' / Path(paths[si]).stem) + '.txt', 'a') as f:
|
||||
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
# W&B logging
|
||||
if len(wandb_images) < log_imgs:
|
||||
if plots and len(wandb_images) < log_imgs:
|
||||
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
||||
"class_id": int(cls),
|
||||
"box_caption": "%s %.3f" % (names[cls], conf),
|
||||
"scores": {"class_score": conf},
|
||||
"domain": "pixel"} for *xyxy, conf, cls in pred.clone().tolist()]
|
||||
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
|
||||
boxes = {"predictions": {"box_data": box_data, "class_labels": names}}
|
||||
wandb_images.append(wandb.Image(img[si], boxes=boxes))
|
||||
wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
|
||||
|
||||
# Clip boxes to image bounds
|
||||
clip_coords(pred, (height, width))
|
||||
|
@ -158,13 +159,13 @@ def test(data,
|
|||
# Append to pycocotools JSON dictionary
|
||||
if save_json:
|
||||
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
|
||||
image_id = Path(paths[si]).stem
|
||||
image_id = int(path.stem) if path.stem.isnumeric() else path.stem
|
||||
box = pred[:, :4].clone() # xyxy
|
||||
scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape
|
||||
box = xyxy2xywh(box) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
for p, b in zip(pred.tolist(), box.tolist()):
|
||||
jdict.append({'image_id': int(image_id) if image_id.isnumeric() else image_id,
|
||||
jdict.append({'image_id': image_id,
|
||||
'category_id': coco91class[int(p[5])] if is_coco else int(p[5]),
|
||||
'bbox': [round(x, 3) for x in b],
|
||||
'score': round(p[4], 5)})
|
||||
|
@ -203,15 +204,11 @@ def test(data,
|
|||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
||||
|
||||
# Plot images
|
||||
if plots and batch_i < 1:
|
||||
if plots and batch_i < 3:
|
||||
f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename
|
||||
plot_images(img, targets, paths, str(f), names) # labels
|
||||
plot_images(img, targets, paths, f, names) # labels
|
||||
f = save_dir / f'test_batch{batch_i}_pred.jpg'
|
||||
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
|
||||
|
||||
# W&B logging
|
||||
if wandb_images:
|
||||
wandb.log({"outputs": wandb_images})
|
||||
plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions
|
||||
|
||||
# Compute statistics
|
||||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
||||
|
@ -223,6 +220,11 @@ def test(data,
|
|||
else:
|
||||
nt = torch.zeros(1)
|
||||
|
||||
# W&B logging
|
||||
if plots and wandb:
|
||||
wandb.log({"Images": wandb_images})
|
||||
wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]})
|
||||
|
||||
# Print results
|
||||
pf = '%20s' + '%12.3g' * 6 # print format
|
||||
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
|
||||
|
|
66
train.py
66
train.py
|
@ -34,6 +34,12 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
wandb = None
|
||||
logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
|
||||
|
||||
|
||||
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||
logger.info(f'Hyperparameters {hyp}')
|
||||
|
@ -54,6 +60,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
yaml.dump(vars(opt), f, sort_keys=False)
|
||||
|
||||
# Configure
|
||||
plots = not opt.evolve # create plots
|
||||
cuda = device.type != 'cpu'
|
||||
init_seeds(2 + rank)
|
||||
with open(opt.data) as f:
|
||||
|
@ -122,6 +129,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
|
||||
# Logging
|
||||
if wandb and wandb.run is None:
|
||||
opt.hyp = hyp # add hyperparameters
|
||||
wandb_run = wandb.init(config=opt, resume="allow",
|
||||
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
||||
name=save_dir.stem,
|
||||
|
@ -164,7 +172,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
||||
logger.info('Using SyncBatchNorm()')
|
||||
|
||||
# Exponential moving average
|
||||
# EMA
|
||||
ema = ModelEMA(model) if rank in [-1, 0] else None
|
||||
|
||||
# DDP mode
|
||||
|
@ -191,10 +199,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
c = torch.tensor(labels[:, 0]) # classes
|
||||
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
||||
# model._initialize_biases(cf.to(device))
|
||||
plot_labels(labels, save_dir=save_dir)
|
||||
if tb_writer:
|
||||
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
|
||||
tb_writer.add_histogram('classes', c, 0)
|
||||
if plots:
|
||||
plot_labels(labels, save_dir=save_dir)
|
||||
if tb_writer:
|
||||
tb_writer.add_histogram('classes', c, 0)
|
||||
if wandb:
|
||||
wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
|
||||
|
||||
# Anchors
|
||||
if not opt.noautoanchor:
|
||||
|
@ -298,14 +308,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
pbar.set_description(s)
|
||||
|
||||
# Plot
|
||||
if ni < 3:
|
||||
f = str(save_dir / f'train_batch{ni}.jpg') # filename
|
||||
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
||||
# if tb_writer and result is not None:
|
||||
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||
if plots and ni < 3:
|
||||
f = save_dir / f'train_batch{ni}.jpg' # filename
|
||||
plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
||||
# if tb_writer:
|
||||
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||
elif plots and ni == 3 and wandb:
|
||||
wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')]})
|
||||
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
|
||||
# Scheduler
|
||||
lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
|
||||
|
@ -325,7 +338,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
single_cls=opt.single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=save_dir,
|
||||
plots=epoch == 0 or final_epoch, # plot first and last
|
||||
plots=plots and final_epoch,
|
||||
log_imgs=opt.log_imgs if wandb else 0)
|
||||
|
||||
# Write
|
||||
|
@ -380,11 +393,16 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
strip_optimizer(f2) # strip optimizer
|
||||
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
|
||||
# Finish
|
||||
if not opt.evolve:
|
||||
if plots:
|
||||
plot_results(save_dir=save_dir) # save as results.png
|
||||
if wandb:
|
||||
wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
|
||||
['results.png', 'precision-recall_curve.png']]})
|
||||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
else:
|
||||
dist.destroy_process_group()
|
||||
|
||||
dist.destroy_process_group() if rank not in [-1, 0] else None
|
||||
wandb.run.finish() if wandb and wandb.run else None
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
||||
|
@ -413,7 +431,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
|
||||
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
||||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
||||
parser.add_argument('--log-imgs', type=int, default=10, help='number of images for W&B logging, max 100')
|
||||
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
|
||||
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
|
||||
parser.add_argument('--project', default='runs/train', help='save to project/name')
|
||||
parser.add_argument('--name', default='exp', help='save to project/name')
|
||||
|
@ -442,7 +460,7 @@ if __name__ == '__main__':
|
|||
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
|
||||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||
opt.name = 'evolve' if opt.evolve else opt.name
|
||||
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
|
||||
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
|
||||
|
||||
# DDP mode
|
||||
device = select_device(opt.device, batch_size=opt.batch_size)
|
||||
|
@ -465,20 +483,10 @@ if __name__ == '__main__':
|
|||
# Train
|
||||
logger.info(opt)
|
||||
if not opt.evolve:
|
||||
tb_writer, wandb = None, None # init loggers
|
||||
tb_writer = None # init loggers
|
||||
if opt.global_rank in [-1, 0]:
|
||||
# Tensorboard
|
||||
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
|
||||
tb_writer = SummaryWriter(opt.save_dir) # runs/train/exp
|
||||
|
||||
# W&B
|
||||
try:
|
||||
import wandb
|
||||
|
||||
assert os.environ.get('WANDB_DISABLED') != 'true'
|
||||
except (ImportError, AssertionError):
|
||||
logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
|
||||
|
||||
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
||||
train(hyp, opt, device, tb_writer, wandb)
|
||||
|
||||
# Evolve hyperparameters (optional)
|
||||
|
@ -553,7 +561,7 @@ if __name__ == '__main__':
|
|||
hyp[k] = round(hyp[k], 5) # significant digits
|
||||
|
||||
# Train mutation
|
||||
results = train(hyp.copy(), opt, device)
|
||||
results = train(hyp.copy(), opt, device, wandb=wandb)
|
||||
|
||||
# Write mutation results
|
||||
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
||||
|
|
|
@ -158,13 +158,13 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
|||
cls = int(classes[j])
|
||||
color = colors[cls % len(colors)]
|
||||
cls = names[cls] if names else cls
|
||||
if labels or conf[j] > 0.3: # 0.3 conf thresh
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
|
||||
plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
|
||||
|
||||
# Draw image filename labels
|
||||
if paths is not None:
|
||||
label = os.path.basename(paths[i])[:40] # trim to 40 char
|
||||
if paths:
|
||||
label = Path(paths[i]).name[:40] # trim to 40 char
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
|
||||
cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
|
||||
lineType=cv2.LINE_AA)
|
||||
|
@ -172,7 +172,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
|||
# Image border
|
||||
cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
|
||||
|
||||
if fname is not None:
|
||||
if fname:
|
||||
r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
|
||||
mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
|
||||
# cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
|
||||
|
|
Loading…
Reference in New Issue