mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Add @threaded
decorator (#7813)
* Add `@threaded` decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
9d8ed37df7
commit
4a295b1a89
4
train.py
4
train.py
@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader
|
|||||||
from utils.downloads import attempt_download
|
from utils.downloads import attempt_download
|
||||||
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
||||||
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
|
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
|
||||||
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
|
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
|
||||||
methods, one_cycle, print_args, print_mutation, strip_optimizer)
|
one_cycle, print_args, print_mutation, strip_optimizer)
|
||||||
from utils.loggers import Loggers
|
from utils.loggers import Loggers
|
||||||
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
||||||
from utils.loss import ComputeLoss
|
from utils.loss import ComputeLoss
|
||||||
|
@ -14,6 +14,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -167,6 +168,16 @@ def try_except(func):
|
|||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def threaded(func):
|
||||||
|
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
return thread
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def methods(instance):
|
def methods(instance):
|
||||||
# Get class/instance methods
|
# Get class/instance methods
|
||||||
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
||||||
|
@ -5,7 +5,6 @@ Logging utils
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import pkg_resources as pkg
|
import pkg_resources as pkg
|
||||||
import torch
|
import torch
|
||||||
@ -109,7 +108,7 @@ class Loggers():
|
|||||||
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
||||||
if ni < 3:
|
if ni < 3:
|
||||||
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
||||||
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
|
plot_images(imgs, targets, paths, f)
|
||||||
if self.wandb and ni == 10:
|
if self.wandb and ni == 10:
|
||||||
files = sorted(self.save_dir.glob('train*.jpg'))
|
files = sorted(self.save_dir.glob('train*.jpg'))
|
||||||
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
||||||
@ -132,7 +131,7 @@ class Loggers():
|
|||||||
|
|
||||||
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
||||||
# Callback runs at the end of each fit (train+val) epoch
|
# Callback runs at the end of each fit (train+val) epoch
|
||||||
x = {k: v for k, v in zip(self.keys, vals)} # dict
|
x = dict(zip(self.keys, vals))
|
||||||
if self.csv:
|
if self.csv:
|
||||||
file = self.save_dir / 'results.csv'
|
file = self.save_dir / 'results.csv'
|
||||||
n = len(x) + 1 # number of cols
|
n = len(x) + 1 # number of cols
|
||||||
@ -171,7 +170,7 @@ class Loggers():
|
|||||||
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
||||||
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results
|
self.wandb.log(dict(zip(self.keys[3:10], results)))
|
||||||
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
||||||
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
|
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
|
||||||
if not self.opt.evolve:
|
if not self.opt.evolve:
|
||||||
|
@ -19,7 +19,7 @@ import torch
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
||||||
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
|
increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
|
||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
@ -32,9 +32,9 @@ class Colors:
|
|||||||
# Ultralytics color palette https://ultralytics.com/
|
# Ultralytics color palette https://ultralytics.com/
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
||||||
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
||||||
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
||||||
self.palette = [self.hex2rgb('#' + c) for c in hex]
|
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
||||||
self.n = len(self.palette)
|
self.n = len(self.palette)
|
||||||
|
|
||||||
def __call__(self, i, bgr=False):
|
def __call__(self, i, bgr=False):
|
||||||
@ -100,7 +100,7 @@ class Annotator:
|
|||||||
if label:
|
if label:
|
||||||
tf = max(self.lw - 1, 1) # font thickness
|
tf = max(self.lw - 1, 1) # font thickness
|
||||||
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
||||||
outside = p1[1] - h - 3 >= 0 # label fits outside box
|
outside = p1[1] - h >= 3
|
||||||
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
||||||
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
||||||
cv2.putText(self.im,
|
cv2.putText(self.im,
|
||||||
@ -184,6 +184,7 @@ def output_to_target(output):
|
|||||||
return np.array(targets)
|
return np.array(targets)
|
||||||
|
|
||||||
|
|
||||||
|
@threaded
|
||||||
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
|
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
|
||||||
# Plot image grid with labels
|
# Plot image grid with labels
|
||||||
if isinstance(images, torch.Tensor):
|
if isinstance(images, torch.Tensor):
|
||||||
@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
|
|||||||
ax = ax.ravel()
|
ax = ax.ravel()
|
||||||
files = list(save_dir.glob('results*.csv'))
|
files = list(save_dir.glob('results*.csv'))
|
||||||
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
||||||
for fi, f in enumerate(files):
|
for f in files:
|
||||||
try:
|
try:
|
||||||
data = pd.read_csv(f)
|
data = pd.read_csv(f)
|
||||||
s = [x.strip() for x in data.columns]
|
s = [x.strip() for x in data.columns]
|
||||||
|
7
val.py
7
val.py
@ -23,7 +23,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -255,10 +254,8 @@ def run(
|
|||||||
|
|
||||||
# Plot images
|
# Plot images
|
||||||
if plots and batch_i < 3:
|
if plots and batch_i < 3:
|
||||||
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
|
plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
|
||||||
Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
|
plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
|
||||||
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
|
|
||||||
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
|
|
||||||
|
|
||||||
callbacks.run('on_val_batch_end')
|
callbacks.run('on_val_batch_end')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user