mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix duplicate plots.py * Fix check_font() * # torch.use_deterministic_algorithms(True) * update doc detect->predict * Resolve precommit for segment/train and segment/val * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit for utils/segment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit min_wh * Resolve precommit utils/segment/plots * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve precommit utils/segment/general * Align NMS-seg closer to NMS * restore deterministic init_seeds code * remove easydict dependency * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restore output_to_target mask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * cleanup * Remove unused ImageFont import * Unified NMS * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * DetectMultiBackend compatibility * segment/predict.py update * update plot colors * fix bbox shifted * sort bbox by confidence * enable overlap by default * Merge detect/segment output_to_target() function * Start segmentation CI * fix plots * Update ci-testing.yml * fix training whitespace * optimize process mask functions (can we merge both?) 
* Update predict/detect * Update plot_images * Update plot_images_and_masks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Add train to CI * fix precommit * fix precommit CI * fix precommit pycocotools * fix val float issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix masks float float issues * suppress errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix no-predictions plotting bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add CSV Logger * fix val len(plot_masks) * speed up evaluation * fix process_mask * fix plots * update segment/utils build_targets * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize utils/segment/general crop() * optimize utils/segment/general crop() 2 * minor updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * torch.where revert * downsample only if different shape * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * loss cleanup * loss cleanup 2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * loss cleanup 3 * update project names * Rename -seg yamls from _underscore to -dash * prepare for yolov5n-seg.pt * precommit space fix * add coco128-seg.yaml * update coco128-seg comments * cleanup val.py * Major val.py cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * precommit fix * precommit fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optional pycocotools * remove CI pip install pycocotools (auto-installed now) * seg yaml fix * optimize mask_iou() 
and masks_iou() * threaded fix * Major train.py update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Major segments/val/process_batch() update * yolov5/val updates from segment * process_batch numpy/tensor fix * opt-in to pycocotools with --save-json * threaded pycocotools ops for 2x speed increase * Avoid permute contiguous if possible * Add max_det=300 argument to both val.py and segment/val.py * fix onnx_dynamic * speed up pycocotools ops * faster process_mask(upsample=True) for predict * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * eliminate permutations for process_mask(upsample=True) * eliminate permute-contiguous in crop(), use native dimension order * cleanup comment * Add Proto() module * fix class count * fix anchor order * broadcast mask_gti in loss for speed * Cleanup seg loss * faster indexing * faster indexing fix * faster indexing fix2 * revert faster indexing * fix validation plotting * Loss cleanup and mxyxy simplification * Loss cleanup and mxyxy simplification 2 * revert validation plotting * replace missing tanh * Eliminate last permutation * delete unneeded .float() * Remove MaskIOULoss and crop(if HWC) * Final v6.3 SegmentationModel architecture updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support for TF export * remove debugger trace * add call * update * update * Merge master * Merge master * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py * Restore CI * Update dataloaders.py * Fix TF/TFLite export for segmentation model * Merge master * Cleanup predict.py mask plotting * cleanup scale_masks() * rename scale_masks to scale_image * cleanup/optimize plot_masks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add 
Annotator.masks() * Annotator.masks() fix * Update plots.py * Annotator mask optimization * Rename crop() to crop_mask() * Do not crop in predict.py * crop always * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Merge master * Add vid-stride from master PR * Update seg model outputs * Update seg model outputs * Add segmentation benchmarks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add segmentation benchmarks * Add segmentation benchmarks * Add segmentation benchmarks * Fix DetectMultiBackend for OpenVINO * update Annotator.masks * fix val plot * revert val plot * clean up * revert pil * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix CI error * fix predict log * remove upsample * update interpolate * fix validation plot logging * Annotator.masks() cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove segmentation_model definition * Restore 0.99999 decimals Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Jiacong Fang <zldrobit@126.com>
105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
"""
|
|
Image augmentation functions
|
|
"""
|
|
|
|
import math
|
|
import random
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from ..augmentations import box_candidates
|
|
from ..general import resample_segments, segment2box
|
|
|
|
|
|
def mixup(im, labels, segments, im2, labels2, segments2):
    """Blend two samples with MixUp (https://arxiv.org/pdf/1710.09412.pdf).

    A single blend ratio ``ratio`` is drawn from Beta(32.0, 32.0) using the global
    NumPy RNG. The output image is ``im * ratio + im2 * (1 - ratio)`` cast to uint8
    (truncation, not rounding).

    Args:
        im, im2: images to blend; they must be broadcast-compatible
            (presumably identical HWC shapes — confirm at the caller).
        labels, labels2: label arrays, stacked along axis 0 with ``labels`` first.
        segments, segments2: segment arrays, stacked along axis 0 with ``segments`` first.

    Returns:
        Tuple ``(blended_image, merged_labels, merged_segments)``.
    """
    # Beta(32, 32) has mean 0.5 and a small spread, so both images contribute substantially.
    ratio = np.random.beta(32.0, 32.0)
    blended = (im * ratio + im2 * (1 - ratio)).astype(np.uint8)
    merged_labels = np.concatenate((labels, labels2), axis=0)
    merged_segments = np.concatenate((segments, segments2), axis=0)
    return blended, merged_labels, merged_segments
|
|
|
|
|
|
def random_perspective(im,
                       targets=(),
                       segments=(),
                       degrees=10,
                       translate=.1,
                       scale=.1,
                       shear=10,
                       perspective=0.0,
                       border=(0, 0)):
    """Apply a random perspective/affine warp to an image and its polygon labels.

    Roughly the torchvision ``RandomAffine(degrees=(-10, 10), translate=(.1, .1),
    scale=(.9, 1.1), shear=(-10, 10))`` recipe, plus an optional projective term.
    The transform is composed (applied right to left) as:
    centre -> perspective -> rotate/scale -> shear -> translate.

    Args:
        im: image array indexed as ``im.shape[0]`` = height, ``im.shape[1]`` = width.
        targets: array of rows ``[cls, x1, y1, x2, y2]``; columns 1:5 are rewritten.
        segments: polygons, each an array of (x, y) points, one per target row
            (assumed aligned with ``targets`` — confirm at the caller).
        degrees: rotation angle is drawn uniformly from [-degrees, degrees].
        translate: centre offset drawn from [0.5 - translate, 0.5 + translate] of the output size.
        scale: zoom factor drawn uniformly from [1 - scale, 1 + scale].
        shear: x and y shear angles (degrees) drawn independently from [-shear, shear].
        perspective: range of the two projective coefficients. When truthy,
            ``cv2.warpPerspective`` plus a homogeneous divide is used; otherwise
            ``cv2.warpAffine`` is used.
        border: (h, w) margin. The output canvas is ``im`` size + 2 * border
            (negative values shrink it).

    Returns:
        Tuple ``(image, targets, segments)``. If there are no targets, ``targets`` is
        returned untouched and ``segments`` is an empty list. Otherwise only targets
        kept by ``box_candidates`` remain, their boxes are recomputed from the warped
        polygons, and ``segments`` is an ndarray of the kept warped polygons.
    """
    src_h, src_w = im.shape[0], im.shape[1]
    out_h = src_h + border[0] * 2
    out_w = src_w + border[1] * 2

    # NOTE: the eight random draws below must stay in this exact order so that
    # seeded runs reproduce the same augmentation.

    # Move the image centre to the origin.
    center_m = np.eye(3)
    center_m[0, 2] = -src_w / 2
    center_m[1, 2] = -src_h / 2

    # Projective terms: x perspective (about y), then y perspective (about x).
    persp_m = np.eye(3)
    persp_m[2, 0] = random.uniform(-perspective, perspective)
    persp_m[2, 1] = random.uniform(-perspective, perspective)

    # Rotation about the origin combined with isotropic scaling.
    angle = random.uniform(-degrees, degrees)
    zoom = random.uniform(1 - scale, 1 + scale)
    rot_m = np.eye(3)
    rot_m[:2] = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=zoom)

    # Shear in x, then in y (angles in degrees, converted to a slope via tan).
    shear_m = np.eye(3)
    shear_m[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)
    shear_m[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)

    # Move the origin back to a (jittered) centre of the output canvas, in pixels.
    trans_m = np.eye(3)
    trans_m[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * out_w
    trans_m[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * out_h

    # Right-to-left composition; the multiplication order matters.
    M = trans_m @ shear_m @ rot_m @ persp_m @ center_m

    # Skip the warp entirely when it would be an identity on an unchanged canvas.
    needs_warp = border[0] != 0 or border[1] != 0 or (M != np.eye(3)).any()
    if needs_warp:
        fill = (114, 114, 114)
        if perspective:
            im = cv2.warpPerspective(im, M, dsize=(out_w, out_h), borderValue=fill)
        else:
            im = cv2.warpAffine(im, M[:2], dsize=(out_w, out_h), borderValue=fill)

    count = len(targets)
    warped = []
    if not count:
        return im, targets, warped

    boxes = np.zeros((count, 4))
    for idx, poly in enumerate(resample_segments(segments)):  # upsampled polygons
        pts = np.ones((len(poly), 3))  # homogeneous coordinates
        pts[:, :2] = poly
        pts = pts @ M.T
        # Projective warps need the homogeneous divide; affine ones keep x, y as-is.
        pts = pts[:, :2] / pts[:, 2:3] if perspective else pts[:, :2]
        # Box derived by segment2box against the output canvas size (the polygon
        # points themselves are stored unclipped).
        boxes[idx] = segment2box(pts, out_w, out_h)
        warped.append(pts)

    # Compare the original boxes (scaled by the sampled zoom) with the new ones.
    keep = box_candidates(box1=targets[:, 1:5].T * zoom, box2=boxes.T, area_thr=0.01)
    targets = targets[keep]
    targets[:, 1:5] = boxes[keep]
    return im, targets, np.array(warped)[keep]
|