Refactor train.py and val.py `loggers` (#4137)
* Update loggers
* Config
* Update val.py
* cleanup
* fix1
* fix2
* fix3 and reformat
* format sweep.py
* Logger() class
* cleanup
* cleanup2
* wandb package import fix
* wandb package import fix2
* txt fix
* fix4
* fix5
* fix6
* drop wandb into utils/loggers
* fix 7
* rename loggers/wandb_logging to loggers/wandb
* Update message
* Update message
* Update message
* cleanup
* Fix x axis bug
* fix rank 0 issue
* cleanup
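The core of the change: train.py no longer builds an ad-hoc `loggers` dict with separate TensorBoard and W&B branches; it creates a single `Loggers` object and calls its callbacks at fixed points in the training loop. Below is a minimal sketch of that flow, assuming the YOLOv5 repo root is on `sys.path`; the `opt` namespace, paths and weights file are placeholder values for illustration, not the real argparse options that train.py passes.

```python
import logging
from pathlib import Path
from types import SimpleNamespace

from utils.loggers import Loggers  # new module introduced by this PR

# Placeholder option namespace; train.py passes its real argparse Namespace here.
opt = SimpleNamespace(evolve=False, resume=False, hyp={}, save_period=-1)
save_dir = Path('runs/train/exp')  # assumed run directory
save_dir.mkdir(parents=True, exist_ok=True)

loggers = Loggers(save_dir, save_dir / 'results.txt', 'yolov5s.pt', opt, opt.hyp,
                  data_dict={'nc': 80, 'names': []}, logger=logging.getLogger(__name__)).start()

# The training loop then only talks to these callbacks (see the train.py hunks below):
#   loggers.on_train_batch_end(ni, model, imgs)                               # mosaics + TB graph
#   loggers.on_train_epoch_end(epoch)                                         # advance W&B epoch counter
#   loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)  # txt/TB/W&B scalars
#   loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)         # periodic W&B checkpoint
#   loggers.on_train_end(last, best)                                          # result plots + model artifact
```

If wandb is not installed, or its setup raises, start() falls back to self.wandb = None, so every W&B-specific callback degrades to a no-op.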
parent 63dd65e7ed
commit efe60b5681

train.py (87 changed lines)
@@ -10,7 +10,6 @@ import os
 import random
 import sys
 import time
-import warnings
 from copy import deepcopy
 from pathlib import Path
 from threading import Thread
@@ -24,7 +23,6 @@ import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam, SGD, lr_scheduler
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 FILE = Path(__file__).absolute()
@@ -42,8 +40,9 @@ from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
-from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
+from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
+from utils.loggers import Loggers

 LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -76,37 +75,23 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(save_dir / 'opt.yaml', 'w') as f:
         yaml.safe_dump(vars(opt), f, sort_keys=False)

-    # Configure
+    # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
     with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict

-    # Loggers
-    loggers = {'wandb': None, 'tb': None}  # loggers dict
-    if RANK in [-1, 0]:
-        # TensorBoard
-        if plots:
-            prefix = colorstr('tensorboard: ')
-            LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(str(save_dir))
-
-        # W&B
-        opt.hyp = hyp  # add hyperparameters
-        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-        run_id = run_id if opt.resume else None  # start fresh run if transfer learning
-        wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
-        loggers['wandb'] = wandb_logger.wandb
-        if loggers['wandb']:
-            data_dict = wandb_logger.data_dict
-            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update values if resuming
-
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
     is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset

+    # Loggers
+    if RANK in [-1, 0]:
+        loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+        if loggers.wandb and resume:
+            weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
+
     # Model
     pretrained = weights.endswith('.pt')
     if pretrained:
@@ -351,16 +336,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 pbar.set_description(s)

                 # Plot
-                if plots and ni < 3:
-                    f = save_dir / f'train_batch{ni}.jpg'  # filename
-                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                    if loggers['tb'] and ni == 0:  # TensorBoard
-                        with warnings.catch_warnings():
-                            warnings.simplefilter('ignore')  # suppress jit trace warning
-                            loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-                elif plots and ni == 10 and loggers['wandb']:
-                    wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
-                                                  save_dir.glob('train*.jpg') if x.exists()]})
+                if plots:
+                    if ni < 3:
+                        f = save_dir / f'train_batch{ni}.jpg'  # filename
+                        Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+                    loggers.on_train_batch_end(ni, model, imgs)

             # end batch ------------------------------------------------------------------------------------------------
@@ -368,13 +348,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()

         # DDP process 0 or single-GPU
         if RANK in [-1, 0]:
             # mAP
+            loggers.on_train_epoch_end(epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP
-                wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
@@ -385,29 +364,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,
                                            plots=plots and final_epoch,
-                                           wandb_logger=wandb_logger,
+                                           loggers=loggers,
                                            compute_loss=compute_loss)

-            # Write
-            with open(results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
-
-            # Log
-            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                    'x/lr0', 'x/lr1', 'x/lr2']  # params
-            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if loggers['tb']:
-                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if loggers['wandb']:
-                    wandb_logger.log({tag: x})  # W&B
-
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            wandb_logger.end_epoch(best_result=best_fitness == fi)
+            loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -418,16 +382,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
+                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}

                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if loggers['wandb']:
-                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
+                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)

         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
@@ -435,10 +397,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if loggers['wandb']:
-                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
-                                              if (save_dir / f).exists()]})

         if not evolve:
             if is_coco:  # COCO dataset
@@ -458,11 +416,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-                if loggers['wandb']:  # Log the stripped model
-                    loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
-                                                  name='run_' + wandb_logger.wandb_run.id + '_model',
-                                                  aliases=['latest', 'best', 'stripped'])
-        wandb_logger.finish_run()
+
+        loggers.on_train_end(last, best)

     torch.cuda.empty_cache()
     return results
utils/loggers/__init__.py (new file)

@@ -0,0 +1,129 @@
# YOLOv5 experiment logging utils

import warnings

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.torch_utils import de_parallel

LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
except (ImportError, AssertionError):
    wandb = None


class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
                 data_dict=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
        self.results_file = results_file
        self.weights = weights
        self.opt = opt
        self.hyp = hyp
        self.data_dict = data_dict
        self.logger = logger  # for printing results to console
        self.include = include
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary

    def start(self):
        self.txt = True  # always log to txt

        # Message
        try:
            import wandb
        except ImportError:
            prefix = colorstr('Weights & Biases: ')
            s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
            print(emojis(s))

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include and not self.opt.evolve:
            prefix = colorstr('TensorBoard: ')
            self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))

        # W&B
        try:
            assert 'wandb' in self.include and wandb
            run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
            self.opt.hyp = self.hyp  # add hyperparameters
            self.wandb = WandbLogger(self.opt, s.stem, run_id, self.data_dict)
        except:
            self.wandb = None

        return self

    def on_train_batch_end(self, ni, model, imgs):
        # Callback runs on train batch end
        if ni == 0:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress jit trace warning
                self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
        if self.wandb and ni == 10:
            files = sorted(self.save_dir.glob('train*.jpg'))
            self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

    def on_train_epoch_end(self, epoch):
        # Callback runs on train epoch end
        if self.wandb:
            self.wandb.current_epoch = epoch + 1

    def on_val_batch_end(self, pred, predn, path, names, im):
        # Callback runs on val batch end
        if self.wandb:
            self.wandb.val_one_image(pred, predn, path, names, im)

    def on_val_end(self):
        # Callback runs on val end
        if self.wandb:
            files = sorted(self.save_dir.glob('val*.jpg'))
            self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

    def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
        # Callback runs on validation end during training
        vals = list(mloss[:-1]) + list(results) + lr
        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                'x/lr0', 'x/lr1', 'x/lr2']  # params
        if self.txt:
            with open(self.results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
        if self.tb:
            for x, tag in zip(vals, tags):
                self.tb.add_scalar(tag, x, epoch)  # TensorBoard
        if self.wandb:
            self.wandb.log({k: v for k, v in zip(tags, vals)})
            self.wandb.end_epoch(best_result=best_fitness == fi)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        # Callback runs on model save event
        if self.wandb:
            if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
                self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

    def on_train_end(self, last, best):
        # Callback runs on training end
        files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
        files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
        if self.wandb:
            wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
            wandb.log_artifact(str(best if best.exists() else last), type='model',
                               name='run_' + self.wandb.wandb_run.id + '_model',
                               aliases=['latest', 'best', 'stripped'])
            self.wandb.finish_run()

    def log_images(self, paths):
        # Log images
        if self.wandb:
            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
utils/loggers/wandb/sweep.py (moved from utils/wandb_logging/sweep.py)

@@ -1,12 +1,12 @@
 import sys
 from pathlib import Path

 import wandb

 FILE = Path(__file__).absolute()
 sys.path.append(FILE.parents[2].as_posix())  # add utils/ to path

 from train import train, parse_opt
 import test
 from utils.general import increment_path
 from utils.torch_utils import select_device
utils/loggers/wandb/sweep.yaml

@@ -14,7 +14,7 @@
 # You can use grid, bayesian and hyperopt search strategy
 # For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration

-program: utils/wandb_logging/sweep.py
+program: utils/loggers/wandb/sweep.py
 method: random
 metric:
   name: metrics/mAP_0.5
utils/loggers/wandb/wandb_utils.py (moved from utils/wandb_logging/wandb_utils.py)

@@ -1,4 +1,5 @@
 """Utilities and tools for tracking runs with Weights & Biases."""
+
 import logging
 import os
 import sys
@@ -8,15 +9,18 @@ from pathlib import Path
 import yaml
 from tqdm import tqdm

-sys.path.append(str(Path(__file__).parent.parent.parent))  # add utils/ to path
+FILE = Path(__file__).absolute()
+sys.path.append(FILE.parents[3].as_posix())  # add yolov5/ to path

 from utils.datasets import LoadImagesAndLabels
 from utils.datasets import img2label_paths
-from utils.general import colorstr, check_dataset, check_file
+from utils.general import check_dataset, check_file

 try:
     import wandb
-    from wandb import init, finish
-except ImportError:
+
+    assert hasattr(wandb, '__version__')  # verify package import not local dir
+except (ImportError, AssertionError):
     wandb = None

 RANK = int(os.getenv('RANK', -1))
@@ -106,7 +110,7 @@ class WandbLogger():
         self.data_dict = data_dict
         self.bbox_media_panel_images = []
         self.val_table_path_map = None
-        self.max_imgs_to_log = 16
+        self.max_imgs_to_log = 16
         # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -134,13 +138,11 @@ class WandbLogger():
                 if not opt.resume:
                     wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
                     # Info useful for resuming from artifacts
-                    self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict}, allow_val_change=True)
+                    self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
+                                                 allow_val_change=True)
                 self.data_dict = self.setup_training(opt, data_dict)
             if self.job_type == 'Dataset Creation':
                 self.data_dict = self.check_and_upload_dataset(opt)
-        else:
-            prefix = colorstr('wandb: ')
-            print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

     def check_and_upload_dataset(self, opt):
         assert wandb, 'Install wandb to upload dataset'
@@ -169,7 +171,7 @@ class WandbLogger():
                                                                            opt.artifact_alias)
             self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                        opt.artifact_alias)

         if self.train_artifact_path is not None:
             train_path = Path(self.train_artifact_path) / 'data/images/'
             data_dict['train'] = str(train_path)
@@ -177,7 +179,6 @@ class WandbLogger():
             val_path = Path(self.val_artifact_path) / 'data/images/'
             data_dict['val'] = str(val_path)

-
         if self.val_artifact is not None:
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
             self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
@@ -315,9 +316,9 @@ class WandbLogger():
                                    )

     def val_one_image(self, pred, predn, path, names, im):
-        if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
+        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
             self.log_training_progress(predn, path, names)
-        else: # Default to bbox media panelif Val artifact not found
+        else:  # Default to bbox media panelif Val artifact not found
             if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
                 if self.current_epoch % self.bbox_interval == 0:
                     box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
@@ -328,7 +329,6 @@ class WandbLogger():
                     boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
                     self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))

-
     def log(self, log_dict):
         if self.wandb_run:
             for key, value in log_dict.items():
utils/plots.py

@@ -327,9 +327,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     plt.close()

     # loggers
-    for k, v in loggers.items() or {}:
-        if k == 'wandb' and v:
-            v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
+    if loggers:
+        loggers.log_images(save_dir.glob('*labels*.jpg'))


 def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
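For reference, a hypothetical call site for the updated plot_labels signature: when train.py hands over its Loggers instance, the generated '*labels*.jpg' files are forwarded to W&B through log_images(), while the default loggers=None simply skips the new if loggers: branch. The labels, names, save_dir and loggers names below are assumed to come from the training setup, not defined here:

```python
from utils.plots import plot_labels

# Hypothetical: labels/names/save_dir come from dataset loading, loggers from train.py.
plot_labels(labels, names=names, save_dir=save_dir, loggers=loggers)  # logs '*labels*.jpg' if W&B is active
```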
val.py (10 changed lines)
@@ -26,6 +26,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_sync
+from utils.loggers import Loggers


 def save_one_txt(predn, save_conf, shape, file):
@@ -97,7 +98,7 @@ def run(data,
         dataloader=None,
         save_dir=Path(''),
         plots=True,
-        wandb_logger=None,
+        loggers=Loggers(),
         compute_loss=None,
         ):
     # Initialize/load model and set device
@@ -215,8 +216,7 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            if wandb_logger and wandb_logger.wandb_run:
-                wandb_logger.val_one_image(pred, predn, path, names, img[si])
+            loggers.on_val_batch_end(pred, predn, path, names, img[si])

         # Plot images
         if plots and batch_i < 3:
@@ -253,9 +253,7 @@ def run(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-        if wandb_logger and wandb_logger.wandb:
-            val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
-            wandb_logger.log({"Validation": val_batches})
+        loggers.on_val_end()

     # Save JSON
     if save_json and len(jdict):
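A side effect of the new val.py signature worth noting: the default loggers=Loggers() is a bare, never-started instance, so running val.py standalone leaves every backend at None and the two callbacks above silently do nothing. A minimal sketch of that behaviour; the argument values passed to the callback are dummies for illustration only:

```python
from utils.loggers import Loggers

loggers = Loggers()          # val.py default: __init__ leaves txt/tb/wandb as None, start() is never called
print(loggers.wandb)         # -> None
loggers.on_val_batch_end(None, None, None, None, None)  # guarded by `if self.wandb:` -> no-op
loggers.on_val_end()                                     # same guard -> no-op
```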