Add @threaded decorator (#7813)

* Add `@threaded` decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-05-14 16:12:08 +02:00 committed by GitHub
parent 9d8ed37df7
commit 4a295b1a89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 17 deletions

View File

@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
methods, one_cycle, print_args, print_mutation, strip_optimizer) one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss

View File

@ -14,6 +14,7 @@ import random
import re import re
import shutil import shutil
import signal import signal
import threading
import time import time
import urllib import urllib
from datetime import datetime from datetime import datetime
@ -167,6 +168,16 @@ def try_except(func):
return handler return handler
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def methods(instance): def methods(instance):
# Get class/instance methods # Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]

View File

@ -5,7 +5,6 @@ Logging utils
import os import os
import warnings import warnings
from threading import Thread
import pkg_resources as pkg import pkg_resources as pkg
import torch import torch
@ -109,7 +108,7 @@ class Loggers():
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if ni < 3: if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename f = self.save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() plot_images(imgs, targets, paths, f)
if self.wandb and ni == 10: if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg')) files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]}) self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
@ -132,7 +131,7 @@ class Loggers():
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch # Callback runs at the end of each fit (train+val) epoch
x = {k: v for k, v in zip(self.keys, vals)} # dict x = dict(zip(self.keys, vals))
if self.csv: if self.csv:
file = self.save_dir / 'results.csv' file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols n = len(x) + 1 # number of cols
@ -171,7 +170,7 @@ class Loggers():
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC') self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
if self.wandb: if self.wandb:
self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results self.wandb.log(dict(zip(self.keys[3:10], results)))
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
if not self.opt.evolve: if not self.opt.evolve:

View File

@ -19,7 +19,7 @@ import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords, from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh) increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
@ -32,9 +32,9 @@ class Colors:
# Ultralytics color palette https://ultralytics.com/ # Ultralytics color palette https://ultralytics.com/
def __init__(self): def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values() # hex = matplotlib.colors.TABLEAU_COLORS.values()
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb('#' + c) for c in hex] self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
self.n = len(self.palette) self.n = len(self.palette)
def __call__(self, i, bgr=False): def __call__(self, i, bgr=False):
@ -100,7 +100,7 @@ class Annotator:
if label: if label:
tf = max(self.lw - 1, 1) # font thickness tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im, cv2.putText(self.im,
@ -184,6 +184,7 @@ def output_to_target(output):
return np.array(targets) return np.array(targets)
@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16): def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
# Plot image grid with labels # Plot image grid with labels
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
ax = ax.ravel() ax = ax.ravel()
files = list(save_dir.glob('results*.csv')) files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for fi, f in enumerate(files): for f in files:
try: try:
data = pd.read_csv(f) data = pd.read_csv(f)
s = [x.strip() for x in data.columns] s = [x.strip() for x in data.columns]

7
val.py
View File

@ -23,7 +23,6 @@ import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
import numpy as np import numpy as np
import torch import torch
@ -255,10 +254,8 @@ def run(
# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start() plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
callbacks.run('on_val_batch_end') callbacks.run('on_val_batch_end')