New TryExcept decorator (#9154)

* New TryExcept decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/9155/head
Glenn Jocher 2022-08-25 14:34:26 +02:00 committed by GitHub
parent f0e5a608f5
commit d07ddc6996
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 59 deletions

View File

@ -3,6 +3,33 @@
utils/initialization
"""
import contextlib
import threading
class TryExcept(contextlib.ContextDecorator):
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg='default message here'):
self.msg = msg
def __enter__(self):
pass
def __exit__(self, exc_type, value, traceback):
if value:
print(f'{self.msg}: {value}')
return True
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def notebook_init(verbose=True):
# Check system software and hardware

View File

@ -15,7 +15,6 @@ import re
import shutil
import signal
import sys
import threading
import time
import urllib
from datetime import datetime
@ -34,6 +33,7 @@ import torch
import torchvision
import yaml
from utils import TryExcept
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
os.chdir(self.cwd)
def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)
return handler
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper
def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
return ''
@try_except
@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5'):
# YOLOv5 status check, recommend 'git pull' if code is out of date
@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
return result
@try_except
@TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:')

View File

@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import TryExcept, threaded
def fitness(x):
# Model fitness as a weighted combination of metrics
@ -184,36 +186,35 @@ class ConfusionMatrix:
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING: ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig = plt.figure(figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
plt.title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close()
except Exception as e:
print(f'WARNING: ConfusionMatrix plot failure: {e}')
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
ax.set_ylabel('True')
ax.set_ylabel('Predicted')
ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close(fig)
def print(self):
for i in range(self.nc + 1):
@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
# Plots ----------------------------------------------------------------------------------------------------------------
@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)
@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)

View File

@ -19,8 +19,9 @@ import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont
from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness
# Settings
@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
plt.savefig(f, dpi=300)
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")