mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix duplicate plots.py * Fix check_font() * # torch.use_deterministic_algorithms(True) * update doc detect->predict * Resolve precommit for segment/train and segment/val * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit for utils/segment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit min_wh * Resolve precommit utils/segment/plots * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit utils/segment/general * Align NMS-seg closer to NMS * restore deterministic init_seeds code * remove easydict dependency * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restore output_to_target mask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * cleanup * Remove unused ImageFont import * Unified NMS * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * DetectMultiBackend compatibility * segment/predict.py update * update plot colors * fix bbox shifted * sort bbox by confidence * enable overlap by default * Merge detect/segment output_to_target() function * Start segmentation CI * fix plots * Update ci-testing.yml * fix training whitespace * optimize process mask functions (can we merge both?) 
* Update predict/detect * Update plot_images * Update plot_images_and_masks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Add train to CI * fix precommit * fix precommit CI * fix precommit pycocotools * fix val float issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix masks float float issues * suppress errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix no-predictions plotting bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add CSV Logger * fix val len(plot_masks) * speed up evaluation * fix process_mask * fix plots * update segment/utils build_targets * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize utils/segment/general crop() * optimize utils/segment/general crop() 2 * minor updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * torch.where revert * downsample only if different shape * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * loss cleanup * loss cleanup 2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * loss cleanup 3 * update project names * Rename -seg yamls from _underscore to -dash * prepare for yolov5n-seg.pt * precommit space fix * add coco128-seg.yaml * update coco128-seg comments * cleanup val.py * Major val.py cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * precommit fix * precommit fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optional pycocotools * remove CI pip install pycocotools (auto-installed now) * seg yaml fix * optimize mask_iou() 
and masks_iou() * threaded fix * Major train.py update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Major segments/val/process_batch() update * yolov5/val updates from segment * process_batch numpy/tensor fix * opt-in to pycocotools with --save-json * threaded pycocotools ops for 2x speed increase * Avoid permute contiguous if possible * Add max_det=300 argument to both val.py and segment/val.py * fix onnx_dynamic * speed up pycocotools ops * faster process_mask(upsample=True) for predict * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * eliminate permutations for process_mask(upsample=True) * eliminate permute-contiguous in crop(), use native dimension order * cleanup comment * Add Proto() module * fix class count * fix anchor order * broadcast mask_gti in loss for speed * Cleanup seg loss * faster indexing * faster indexing fix * faster indexing fix2 * revert faster indexing * fix validation plotting * Loss cleanup and mxyxy simplification * Loss cleanup and mxyxy simplification 2 * revert validation plotting * replace missing tanh * Eliminate last permutation * delete unneeded .float() * Remove MaskIOULoss and crop(if HWC) * Final v6.3 SegmentationModel architecture updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support for TF export * remove debugger trace * add call * update * update * Merge master * Merge master * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py * Restore CI * Update dataloaders.py * Fix TF/TFLite export for segmentation model * Merge master * Cleanup predict.py mask plotting * cleanup scale_masks() * rename scale_masks to scale_image * cleanup/optimize plot_masks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add 
Annotator.masks() * Annotator.masks() fix * Update plots.py * Annotator mask optimization * Rename crop() to crop_mask() * Do not crop in predict.py * crop always * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Merge master * Add vid-stride from master PR * Update seg model outputs * Update seg model outputs * Add segmentation benchmarks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add segmentation benchmarks * Add segmentation benchmarks * Add segmentation benchmarks * Fix DetectMultiBackend for OpenVINO * update Annotator.masks * fix val plot * revert val plot * clean up * revert pil * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix CI error * fix predict log * remove upsample * update interpolate * fix validation plot logging * Annotator.masks() cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove segmentation_model definition * Restore 0.99999 decimals Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Jiacong Fang <zldrobit@126.com>
144 lines
6.2 KiB
Python
144 lines
6.2 KiB
Python
import contextlib
|
|
import math
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
|
|
from .. import threaded
|
|
from ..general import xywh2xyxy
|
|
from ..plots import Annotator, colors
|
|
|
|
|
|
@threaded
def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg', names=None):
    """Plot up to 16 batch images as a square mosaic with boxes, labels and mask overlays, and save it to `fname`.

    `@threaded` is a project decorator imported from the parent package; presumably it runs this call in a
    background thread (TODO confirm), which is one more reason the caller's arrays must not be modified here.

    Args:
        images: torch.Tensor or np.ndarray unpacked as (batch, channels, height, width). If the first image's
            maximum is <= 1, all values are scaled by 255 on a copy; the caller's data is left untouched.
        targets: torch.Tensor or np.ndarray of per-object rows: column 0 = image index in the batch,
            column 1 = class id, columns 2:6 = box in xywh form (normalized or pixel units), optional
            column 6 = confidence. Rows that carry a confidence are drawn only when it exceeds 0.25.
        masks: torch.Tensor (converted to an int array) or np.ndarray. If any value is > 1 they are treated as
            overlap masks indexed by batch image, where pixel value k + 1 marks that image's k-th target;
            otherwise there is one binary mask per row of `targets`. Masks whose shape differs from the tile
            size are resized to it.
        paths: optional sequence of image paths; the first 40 characters of each file name are drawn per tile.
        fname: path the mosaic image is saved to.
        names: optional class-name lookup indexed by class id; raw ids are shown when omitted.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # Out-of-place on purpose: `.numpy()` of a CPU float tensor (or an ndarray argument) shares memory with
        # the caller, so an in-place `*=` would rescale the caller's batch.
        images = images * 255  # de-normalise (optional)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images[:bs]):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y:y + h, x:x + w, :] = im.transpose(1, 2, 0)  # CHW -> HWC

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(bs):  # exactly one pass per tile placed in the mosaic
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            idx = targets[:, 0] == i
            ti = targets[idx]  # image targets

            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            boxes[[0, 2]] += x  # shift into this tile
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                cls = classes[j]
                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                    annotator.box_label(box, label, color=color)

            # Plot masks
            if len(masks):
                if masks.max() > 1.0:  # overlap masks: one per image, pixel value k + 1 marks target k
                    image_masks = masks[[i]]  # (1, H, W)
                    nl = len(ti)
                    index = np.arange(nl).reshape(nl, 1, 1) + 1
                    image_masks = np.repeat(image_masks, nl, axis=0)
                    image_masks = np.where(image_masks == index, 1.0, 0.0)
                else:
                    image_masks = masks[idx]

                im = np.asarray(annotator.im).copy()
                for j, cls_id in enumerate(classes):
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                        color = colors(cls_id)
                        mask = image_masks[j]
                        mh, mw = mask.shape
                        if mh != h or mw != w:
                            mask = cv2.resize(mask.astype(np.uint8), (w, h))
                        mask = mask.astype(bool)  # builtin bool: the np.bool alias was removed in NumPy 1.24
                        with contextlib.suppress(Exception):  # best effort: skip masks that do not fit the tile
                            tile = im[y:y + h, x:x + w, :]  # view into `im`, so writes land in the mosaic
                            tile[mask] = tile[mask] * 0.4 + np.array(color) * 0.6
                annotator.fromarray(im)
    annotator.im.save(fname)  # save
|
|
|
|
|
|
def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
    """Plot training curves from every results*.csv in a directory and save them to results.png there.

    Example: plot_results_with_masks('path/to/results.csv')

    Args:
        file: path to a results CSV; its parent directory is searched for results*.csv files.
            If falsy, `dir` is searched instead.
        dir: directory to search when `file` is falsy.
        best: if True, each subplot marks and titles the value at the "best" row, i.e. the row maximizing
            0.1*col7 + 0.9*col8 + 0.1*col11 + 0.9*col12 (presumably box/mask mAP columns; verify against the
            CSV writer). If False, the last row is used instead.

    Raises:
        AssertionError: if no results*.csv file is found.
    """
    save_dir = Path(file).parent if file else Path(dir)
    files = list(save_dir.glob("results*.csv"))
    # Check before creating the figure so a failed lookup does not leak an open matplotlib figure.
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
    fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
    ax = ax.ravel()
    try:
        for f in files:
            try:
                data = pd.read_csv(f)
                values = data.values
                index = np.argmax(0.9 * values[:, 8] + 0.1 * values[:, 7] + 0.9 * values[:, 12] +
                                  0.1 * values[:, 11])
                s = [c.strip() for c in data.columns]
                x = values[:, 0]
                for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
                    y = values[:, j]
                    ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
                    if best:
                        # Place the marker at the best row's x value (column 0), like the "last" branch does.
                        ax[i].scatter(x[index], y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
                        ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
                    else:
                        ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
                        ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
            except Exception as e:  # best effort: one malformed CSV should not abort the whole plot
                print(f"Warning: Plotting error for {f}: {e}")
        ax[1].legend()
        fig.savefig(save_dir / "results.png", dpi=200)
    finally:
        plt.close(fig)  # close this figure explicitly, not whatever pyplot considers "current"
|