diff --git a/train.py b/train.py index e18864fef..ed1fa914d 100644 --- a/train.py +++ b/train.py @@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, - init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, - methods, one_cycle, print_args, print_mutation, strip_optimizer) + init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, + one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss diff --git a/utils/general.py b/utils/general.py index a8bc0cc17..2dc81ea1c 100755 --- a/utils/general.py +++ b/utils/general.py @@ -14,6 +14,7 @@ import random import re import shutil import signal +import threading import time import urllib from datetime import datetime @@ -167,6 +168,16 @@ def try_except(func): return handler +def threaded(func): + # Multi-threads a target function and returns thread. Usage: @threaded decorator + def wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + + return wrapper + + def methods(instance): # Get class/instance methods return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index ef94615e9..7eda62570 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -5,7 +5,6 @@ Logging utils import os import warnings -from threading import Thread import pkg_resources as pkg import torch @@ -109,7 +108,7 @@ class Loggers(): self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) if ni < 3: f = self.save_dir / f'train_batch{ni}.jpg' # filename - Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() + plot_images(imgs, targets, paths, f) if self.wandb and ni == 10: files = sorted(self.save_dir.glob('train*.jpg')) self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]}) @@ -132,7 +131,7 @@ class Loggers(): def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): # Callback runs at the end of each fit (train+val) epoch - x = {k: v for k, v in zip(self.keys, vals)} # dict + x = dict(zip(self.keys, vals)) if self.csv: file = self.save_dir / 'results.csv' n = len(x) + 1 # number of cols @@ -171,7 +170,7 @@ class Loggers(): self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC') if self.wandb: - self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results + self.wandb.log(dict(zip(self.keys[3:10], results))) self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model if not self.opt.evolve: diff --git a/utils/plots.py b/utils/plots.py index 3ec70b62b..1bbb9c09c 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -19,7 +19,7 @@ import torch from PIL import Image, ImageDraw, ImageFont from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords, - increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh) + increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh) from utils.metrics import fitness # Settings @@ -32,9 +32,9 @@ class Colors: # Ultralytics color palette https://ultralytics.com/ def __init__(self): # hex = matplotlib.colors.TABLEAU_COLORS.values() - hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', - '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') - self.palette = [self.hex2rgb('#' + c) for c in hex] + hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', + '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') + self.palette = [self.hex2rgb(f'#{c}') for c in hexs] self.n = len(self.palette) def __call__(self, i, bgr=False): @@ -100,7 +100,7 @@ class Annotator: if label: tf = max(self.lw - 1, 1) # font thickness w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height - outside = p1[1] - h - 3 >= 0 # label fits outside box + outside = p1[1] - h >= 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled cv2.putText(self.im, @@ -184,6 +184,7 @@ def output_to_target(output): return np.array(targets) +@threaded def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16): # Plot image grid with labels if isinstance(images, torch.Tensor): @@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''): ax = ax.ravel() files = list(save_dir.glob('results*.csv')) assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' - for fi, f in enumerate(files): + for f in files: try: data = pd.read_csv(f) s = [x.strip() for x in data.columns] diff --git a/val.py b/val.py index 0473f1d75..d886cf302 100644 --- a/val.py +++ b/val.py @@ -23,7 +23,6 @@ import json import os import sys from pathlib import Path -from threading import Thread import numpy as np import torch @@ -255,10 +254,8 @@ def run( # Plot images if plots and batch_i < 3: - f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels - Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start() - f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions - Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start() + plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels + plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred callbacks.run('on_val_batch_end')