New TryExcept decorator (#9154)

* New TryExcept decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/9155/head
Glenn Jocher 2022-08-25 14:34:26 +02:00 committed by GitHub
parent f0e5a608f5
commit d07ddc6996
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 59 deletions

View File

@ -3,6 +3,33 @@
utils/initialization utils/initialization
""" """
import contextlib
import threading
class TryExcept(contextlib.ContextDecorator):
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg='default message here'):
self.msg = msg
def __enter__(self):
pass
def __exit__(self, exc_type, value, traceback):
if value:
print(f'{self.msg}: {value}')
return True
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def notebook_init(verbose=True): def notebook_init(verbose=True):
# Check system software and hardware # Check system software and hardware

View File

@ -15,7 +15,6 @@ import re
import shutil import shutil
import signal import signal
import sys import sys
import threading
import time import time
import urllib import urllib
from datetime import datetime from datetime import datetime
@ -34,6 +33,7 @@ import torch
import torchvision import torchvision
import yaml import yaml
from utils import TryExcept
from utils.downloads import gsutil_getsize from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness from utils.metrics import box_iou, fitness
@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
os.chdir(self.cwd) os.chdir(self.cwd)
def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)
return handler
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def methods(instance): def methods(instance):
# Get class/instance methods # Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
return '' return ''
@try_except @TryExcept()
@WorkingDirectory(ROOT) @WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5'): def check_git_status(repo='ultralytics/yolov5'):
# YOLOv5 status check, recommend 'git pull' if code is out of date # YOLOv5 status check, recommend 'git pull' if code is out of date
@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
return result return result
@try_except @TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()): def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages) # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:') prefix = colorstr('red', 'bold', 'requirements:')

View File

@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
from utils import TryExcept, threaded
def fitness(x): def fitness(x):
# Model fitness as a weighted combination of metrics # Model fitness as a weighted combination of metrics
@ -184,20 +186,21 @@ class ConfusionMatrix:
# fn = self.matrix.sum(0) - tp # false negatives (missed detections) # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING: ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()): def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig = plt.figure(figsize=(12, 9), tight_layout=True) fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array, sn.heatmap(array,
ax=ax,
annot=nc < 30, annot=nc < 30,
annot_kws={ annot_kws={
"size": 8}, "size": 8},
@ -207,13 +210,11 @@ class ConfusionMatrix:
vmin=0.0, vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto", xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True') ax.set_ylabel('True')
fig.axes[0].set_ylabel('Predicted') ax.set_ylabel('Predicted')
plt.title('Confusion Matrix') ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close() plt.close(fig)
except Exception as e:
print(f'WARNING: ConfusionMatrix plot failure: {e}')
def print(self): def print(self):
for i in range(self.nc + 1): for i in range(self.nc + 1):
@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
# Plots ---------------------------------------------------------------------------------------------------------------- # Plots ----------------------------------------------------------------------------------------------------------------
@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve # Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision') ax.set_ylabel('Precision')
ax.set_xlim(0, 1) ax.set_xlim(0, 1)
ax.set_ylim(0, 1) ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve') ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250) fig.savefig(save_dir, dpi=250)
plt.close() plt.close(fig)
@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'): def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve # Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.set_xlim(0, 1) ax.set_xlim(0, 1)
ax.set_ylim(0, 1) ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve') ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250) fig.savefig(save_dir, dpi=250)
plt.close() plt.close(fig)

View File

@ -19,8 +19,9 @@ import seaborn as sn
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path, from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh) is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
plt.savefig(f, dpi=300) plt.savefig(f, dpi=300)
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395 @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')): def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels # plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")