# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Dataloaders
"""

import os
import random

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, distributed

from ..augmentations import augment_hsv, copy_paste, letterbox
from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker
from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
from ..torch_utils import torch_distributed_zero_first
from .augmentations import mixup, random_perspective


def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False,
                      mask_downsample_ratio=1,
                      overlap_mask=False):
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabelsAndMasks(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix,
            downsample_ratio=mask_downsample_ratio,
            overlap=overlap_mask)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    # generator = torch.Generator()
    # generator.manual_seed(0)
    return loader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,
        num_workers=nw,
        sampler=sampler,
        pin_memory=True,
        collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
        worker_init_fn=seed_worker,
        # generator=generator,
    ), dataset
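

# Usage sketch (illustrative only; the path and hyp values below are hypothetical,
# and augment=True requires the full hyperparameter set expected by LoadImagesAndLabels):
# hyp = {'mosaic': 1.0, 'mixup': 0.0, 'copy_paste': 0.0, 'degrees': 0.0, 'translate': 0.1,
#        'scale': 0.5, 'shear': 0.0, 'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7,
#        'hsv_v': 0.4, 'flipud': 0.0, 'fliplr': 0.5}
# loader, dataset = create_dataloader('../datasets/coco128-seg/images/train2017', imgsz=640,
#                                     batch_size=16, stride=32, hyp=hyp, augment=True,
#                                     overlap_mask=True)
# NOTE: quad=True selects the inherited detection collate_fn4, which does not batch masks.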


class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels):  # for training/testing

    def __init__(
        self,
        path,
        img_size=640,
        batch_size=16,
        augment=False,
        hyp=None,
        rect=False,
        image_weights=False,
        cache_images=False,
        single_cls=False,
        stride=32,
        pad=0,
        prefix="",
        downsample_ratio=1,
        overlap=False,
    ):
        super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
                         stride, pad, prefix)
        self.downsample_ratio = downsample_ratio
        self.overlap = overlap

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        masks = []
        if mosaic:
            # Load mosaic
            img, labels, segments = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp["mixup"]:
                img, labels, segments = mixup(img, labels, segments, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            # [array, array, ....], array.shape=(num_points, 2), xyxyxyxy
            segments = self.segments[index].copy()
            if len(segments):
                for i_s in range(len(segments)):
                    segments[i_s] = xyn2xy(
                        segments[i_s],
                        ratio[0] * w,
                        ratio[1] * h,
                        padw=pad[0],
                        padh=pad[1],
                    )
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels, segments = random_perspective(
                    img,
                    labels,
                    segments=segments,
                    degrees=hyp["degrees"],
                    translate=hyp["translate"],
                    scale=hyp["scale"],
                    shear=hyp["shear"],
                    perspective=hyp["perspective"],
                    return_seg=True,
                )

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
            if self.overlap:
                masks, sorted_idx = polygons2masks_overlap(img.shape[:2],
                                                           segments,
                                                           downsample_ratio=self.downsample_ratio)
                masks = masks[None]  # (640, 640) -> (1, 640, 640)
                labels = labels[sorted_idx]
            else:
                masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio)

        masks = (torch.from_numpy(masks) if len(masks) else torch.zeros(
            1 if self.overlap else nl, img.shape[0] // self.downsample_ratio, img.shape[1] // self.downsample_ratio))
        # TODO: albumentations support
        if self.augment:
            # Albumentations
            # Some augmentations do not alter boxes or masks, so only those are applied for now
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])

            # Flip up-down
            if random.random() < hyp["flipud"]:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]
                    masks = torch.flip(masks, dims=[1])

            # Flip left-right
            if random.random() < hyp["fliplr"]:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]
                    masks = torch.flip(masks, dims=[2])

            # Cutouts  # labels = cutout(img, labels, p=0.5)

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks)
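
    # Each sample returned above is (derived from the code in __getitem__):
    #   img: uint8 tensor (3, H, W), RGB
    #   labels_out: float tensor (nl, 6) of (image_index, class, x, y, w, h), normalized;
    #     the image_index column is filled in by collate_fn
    #   path: source image file path
    #   shapes: ((h0, w0), ((h / h0, w / w0), pad)) for rect/eval, else None
    #   masks: (1, H / r, W / r) if overlap else (nl, H / r, W / r), r = downsample_ratio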

    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y

        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            labels, segments = self.labels[index].copy(), self.segments[index].copy()

            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp["copy_paste"])
        img4, labels4, segments4 = random_perspective(img4,
                                                      labels4,
                                                      segments4,
                                                      degrees=self.hyp["degrees"],
                                                      translate=self.hyp["translate"],
                                                      scale=self.hyp["scale"],
                                                      shear=self.hyp["shear"],
                                                      perspective=self.hyp["perspective"],
                                                      border=self.mosaic_border)  # border to remove
        return img4, labels4, segments4

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes, masks = zip(*batch)  # transposed
        batched_masks = torch.cat(masks, 0)
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks
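
    # Batched output from collate_fn (B = batch size, r = mask downsample ratio):
    #   imgs (B, 3, H, W), labels (sum(nl), 6), paths, shapes,
    #   masks (B, H / r, W / r) if overlap=True else (sum(nl), H / r, W / r)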


def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
    """
    Args:
        img_size (tuple): The image size.
        polygons (np.ndarray): [N, M], N is the number of polygons,
            M is the number of coordinate values (2 * number of points).
    """
    mask = np.zeros(img_size, dtype=np.uint8)
    polygons = np.asarray(polygons)
    polygons = polygons.astype(np.int32)
    shape = polygons.shape
    polygons = polygons.reshape(shape[0], -1, 2)
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
    # NOTE: fillPoly first, then resize, to keep the loss calculation consistent
    # with mask-ratio=1
    mask = cv2.resize(mask, (nw, nh))
    return mask
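

# Example (hypothetical coordinates): one triangle rasterized into a 640x640 image,
# returned as a (160, 160) mask at mask-ratio 4:
# tri = np.array([[100.0, 100.0, 500.0, 100.0, 300.0, 400.0]])  # flat x, y pairs in pixels
# m = polygon2mask((640, 640), tri, color=1, downsample_ratio=4)  # uint8, values {0, 1}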


def polygons2masks(img_size, polygons, color, downsample_ratio=1):
    """
    Args:
        img_size (tuple): The image size.
        polygons (list[np.ndarray]): each polygon is [N, M],
            N is the number of polygons,
            M is the number of coordinate values (2 * number of points).
    """
    masks = []
    for si in range(len(polygons)):
        mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
        masks.append(mask)
    return np.array(masks)


def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
    """Return a single (img_size / downsample_ratio) mask where each instance is encoded
    as a distinct pixel value 1..N, drawn in descending-area order so smaller instances stay on top."""
    masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
                     dtype=np.int32 if len(segments) > 255 else np.uint8)  # uint8 would overflow past 255 instances
    areas = []
    ms = []
    for si in range(len(segments)):
        mask = polygon2mask(
            img_size,
            [segments[si].reshape(-1)],
            downsample_ratio=downsample_ratio,
            color=1,
        )
        ms.append(mask)
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        mask = ms[i] * (i + 1)
        masks = masks + mask
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index
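

if __name__ == '__main__':
    # Minimal self-check sketch with synthetic polygons (not part of the training
    # pipeline): shows how the overlap encoding assigns pixel values 1..N largest-first,
    # with later (smaller) instances drawn on top.
    segs = [
        np.array([[50, 50], [300, 50], [300, 300], [50, 300]], dtype=np.float32),  # large square
        np.array([[100, 100], [200, 100], [150, 200]], dtype=np.float32),  # small triangle inside it
    ]
    masks, index = polygons2masks_overlap((320, 320), segs, downsample_ratio=1)
    print(masks.shape, masks.dtype, np.unique(masks))  # (320, 320) uint8 [0 1 2]
    print(index)  # [0 1]: the square has the larger area, so it is drawn first (value 1)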
|