New TryExcept decorator (#9154)
* New TryExcept decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9155/head
parent
f0e5a608f5
commit
d07ddc6996
|
@ -3,6 +3,33 @@
|
||||||
utils/initialization
|
utils/initialization
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class TryExcept(contextlib.ContextDecorator):
|
||||||
|
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
|
||||||
|
def __init__(self, msg='default message here'):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, value, traceback):
|
||||||
|
if value:
|
||||||
|
print(f'{self.msg}: {value}')
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def threaded(func):
|
||||||
|
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
return thread
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def notebook_init(verbose=True):
|
def notebook_init(verbose=True):
|
||||||
# Check system software and hardware
|
# Check system software and hardware
|
||||||
|
|
|
@ -15,7 +15,6 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -34,6 +33,7 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from utils import TryExcept
|
||||||
from utils.downloads import gsutil_getsize
|
from utils.downloads import gsutil_getsize
|
||||||
from utils.metrics import box_iou, fitness
|
from utils.metrics import box_iou, fitness
|
||||||
|
|
||||||
|
@ -195,27 +195,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
||||||
os.chdir(self.cwd)
|
os.chdir(self.cwd)
|
||||||
|
|
||||||
|
|
||||||
def try_except(func):
|
|
||||||
# try-except function. Usage: @try_except decorator
|
|
||||||
def handler(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
def threaded(func):
|
|
||||||
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
|
||||||
thread.start()
|
|
||||||
return thread
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def methods(instance):
|
def methods(instance):
|
||||||
# Get class/instance methods
|
# Get class/instance methods
|
||||||
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
||||||
|
@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
@try_except
|
@TryExcept()
|
||||||
@WorkingDirectory(ROOT)
|
@WorkingDirectory(ROOT)
|
||||||
def check_git_status(repo='ultralytics/yolov5'):
|
def check_git_status(repo='ultralytics/yolov5'):
|
||||||
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
||||||
|
@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@try_except
|
@TryExcept()
|
||||||
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
||||||
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
|
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
|
||||||
prefix = colorstr('red', 'bold', 'requirements:')
|
prefix = colorstr('red', 'bold', 'requirements:')
|
||||||
|
|
|
@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from utils import TryExcept, threaded
|
||||||
|
|
||||||
|
|
||||||
def fitness(x):
|
def fitness(x):
|
||||||
# Model fitness as a weighted combination of metrics
|
# Model fitness as a weighted combination of metrics
|
||||||
|
@ -184,20 +186,21 @@ class ConfusionMatrix:
|
||||||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||||
return tp[:-1], fp[:-1] # remove background class
|
return tp[:-1], fp[:-1] # remove background class
|
||||||
|
|
||||||
|
@TryExcept('WARNING: ConfusionMatrix plot failure')
|
||||||
def plot(self, normalize=True, save_dir='', names=()):
|
def plot(self, normalize=True, save_dir='', names=()):
|
||||||
try:
|
|
||||||
import seaborn as sn
|
import seaborn as sn
|
||||||
|
|
||||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||||
|
|
||||||
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||||
nc, nn = self.nc, len(names) # number of classes, names
|
nc, nn = self.nc, len(names) # number of classes, names
|
||||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||||
sn.heatmap(array,
|
sn.heatmap(array,
|
||||||
|
ax=ax,
|
||||||
annot=nc < 30,
|
annot=nc < 30,
|
||||||
annot_kws={
|
annot_kws={
|
||||||
"size": 8},
|
"size": 8},
|
||||||
|
@ -207,13 +210,11 @@ class ConfusionMatrix:
|
||||||
vmin=0.0,
|
vmin=0.0,
|
||||||
xticklabels=names + ['background FP'] if labels else "auto",
|
xticklabels=names + ['background FP'] if labels else "auto",
|
||||||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
||||||
fig.axes[0].set_xlabel('True')
|
ax.set_ylabel('True')
|
||||||
fig.axes[0].set_ylabel('Predicted')
|
ax.set_ylabel('Predicted')
|
||||||
plt.title('Confusion Matrix')
|
ax.set_title('Confusion Matrix')
|
||||||
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
||||||
plt.close()
|
plt.close(fig)
|
||||||
except Exception as e:
|
|
||||||
print(f'WARNING: ConfusionMatrix plot failure: {e}')
|
|
||||||
|
|
||||||
def print(self):
|
def print(self):
|
||||||
for i in range(self.nc + 1):
|
for i in range(self.nc + 1):
|
||||||
|
@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
|
||||||
# Plots ----------------------------------------------------------------------------------------------------------------
|
# Plots ----------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@threaded
|
||||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||||
# Precision-recall curve
|
# Precision-recall curve
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
|
@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||||
ax.set_ylabel('Precision')
|
ax.set_ylabel('Precision')
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||||
plt.title('Precision-Recall Curve')
|
ax.set_title('Precision-Recall Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close()
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
@threaded
|
||||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
||||||
# Metric-confidence curve
|
# Metric-confidence curve
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
|
@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
||||||
ax.set_ylabel(ylabel)
|
ax.set_ylabel(ylabel)
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||||
plt.title(f'{ylabel}-Confidence Curve')
|
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close()
|
plt.close(fig)
|
||||||
|
|
|
@ -19,8 +19,9 @@ import seaborn as sn
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
from utils import TryExcept, threaded
|
||||||
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
|
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
|
||||||
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
|
is_ascii, xywh2xyxy, xyxy2xywh)
|
||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
|
@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
|
||||||
plt.savefig(f, dpi=300)
|
plt.savefig(f, dpi=300)
|
||||||
|
|
||||||
|
|
||||||
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
|
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||||
def plot_labels(labels, names=(), save_dir=Path('')):
|
def plot_labels(labels, names=(), save_dir=Path('')):
|
||||||
# plot dataset labels
|
# plot dataset labels
|
||||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||||
|
|
Loading…
Reference in New Issue