New TryExcept decorator (#9154)
* New TryExcept decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9155/head
parent
f0e5a608f5
commit
d07ddc6996
|
@ -3,6 +3,33 @@
|
|||
utils/initialization
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
|
||||
def __init__(self, msg='default message here'):
|
||||
self.msg = msg
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
if value:
|
||||
print(f'{self.msg}: {value}')
|
||||
return True
|
||||
|
||||
|
||||
def threaded(func):
|
||||
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
||||
def wrapper(*args, **kwargs):
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def notebook_init(verbose=True):
|
||||
# Check system software and hardware
|
||||
|
|
|
@ -15,7 +15,6 @@ import re
|
|||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib
|
||||
from datetime import datetime
|
||||
|
@ -34,6 +33,7 @@ import torch
|
|||
import torchvision
|
||||
import yaml
|
||||
|
||||
from utils import TryExcept
|
||||
from utils.downloads import gsutil_getsize
|
||||
from utils.metrics import box_iou, fitness
|
||||
|
||||
|
@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|||
os.chdir(self.cwd)
|
||||
|
||||
|
||||
def try_except(func):
|
||||
# try-except function. Usage: @try_except decorator
|
||||
def handler(*args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def threaded(func):
|
||||
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
||||
def wrapper(*args, **kwargs):
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def methods(instance):
|
||||
# Get class/instance methods
|
||||
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
||||
|
@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
|
|||
return ''
|
||||
|
||||
|
||||
@try_except
|
||||
@TryExcept()
|
||||
@WorkingDirectory(ROOT)
|
||||
def check_git_status(repo='ultralytics/yolov5'):
|
||||
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
||||
|
@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
|
|||
return result
|
||||
|
||||
|
||||
@try_except
|
||||
@TryExcept()
|
||||
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
||||
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
|
|
|
@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from utils import TryExcept, threaded
|
||||
|
||||
|
||||
def fitness(x):
|
||||
# Model fitness as a weighted combination of metrics
|
||||
|
@ -184,36 +186,35 @@ class ConfusionMatrix:
|
|||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||
return tp[:-1], fp[:-1] # remove background class
|
||||
|
||||
@TryExcept('WARNING: ConfusionMatrix plot failure')
|
||||
def plot(self, normalize=True, save_dir='', names=()):
|
||||
try:
|
||||
import seaborn as sn
|
||||
import seaborn as sn
|
||||
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
|
||||
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(array,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
"size": 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f',
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=names + ['background FP'] if labels else "auto",
|
||||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
||||
fig.axes[0].set_xlabel('True')
|
||||
fig.axes[0].set_ylabel('Predicted')
|
||||
plt.title('Confusion Matrix')
|
||||
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
||||
plt.close()
|
||||
except Exception as e:
|
||||
print(f'WARNING: ConfusionMatrix plot failure: {e}')
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
"size": 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f',
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=names + ['background FP'] if labels else "auto",
|
||||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
||||
ax.set_ylabel('True')
|
||||
ax.set_ylabel('Predicted')
|
||||
ax.set_title('Confusion Matrix')
|
||||
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
||||
plt.close(fig)
|
||||
|
||||
def print(self):
|
||||
for i in range(self.nc + 1):
|
||||
|
@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
|
|||
# Plots ----------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@threaded
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||
# Precision-recall curve
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
|
@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
|||
ax.set_ylabel('Precision')
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
plt.title('Precision-Recall Curve')
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
ax.set_title('Precision-Recall Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
@threaded
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
||||
# Metric-confidence curve
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
|
@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
|||
ax.set_ylabel(ylabel)
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
plt.title(f'{ylabel}-Confidence Curve')
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close()
|
||||
plt.close(fig)
|
||||
|
|
|
@ -19,8 +19,9 @@ import seaborn as sn
|
|||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from utils import TryExcept, threaded
|
||||
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
|
||||
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
|
||||
is_ascii, xywh2xyxy, xyxy2xywh)
|
||||
from utils.metrics import fitness
|
||||
|
||||
# Settings
|
||||
|
@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
|
|||
plt.savefig(f, dpi=300)
|
||||
|
||||
|
||||
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
def plot_labels(labels, names=(), save_dir=Path('')):
|
||||
# plot dataset labels
|
||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
|
|
Loading…
Reference in New Issue