# Dataset utils and dataloaders

import glob
import hashlib
import json
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool, Pool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
    segment2box, segments2boxes, resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
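
# Usage sketch (file names are hypothetical; the digest folds the total file size and the joined path strings):
#   get_hash(['data/images/im0.jpg', 'data/labels/im0.txt'])  # -> 32-char md5 hexdigest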


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation in (6, 8):  # rotation 270 or 90: both swap width and height
            s = (s[1], s[0])
    except Exception:
        pass
    return s


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset properties will update during training, else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
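
# Minimal usage sketch (the dataset path and hyp dict are placeholders; hyp must supply every augmentation key
# read in LoadImagesAndLabels.__getitem__, e.g. 'mosaic', 'mixup', 'hsv_h', 'degrees', 'flipud', 'fliplr'):
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', 640, 16, 32, hyp=hyp, augment=True)
#   for imgs, targets, paths, shapes in dataloader:
#       pass  # imgs: uint8 tensor (16, 3, 640, 640); targets: (n, 6) rows of [batch_index, class, x, y, w, h]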


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB and HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
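
# Example inference loop (the model call is a placeholder for any callable taking a BCHW float tensor):
#   for path, img, img0, cap in LoadImages('data/images', img_size=640):
#       im = torch.from_numpy(img).unsqueeze(0).float() / 255.0  # CHW uint8 -> BCHW float 0-1
#       pred = model(im)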


class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera (int() rather than eval() on user input)
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB and HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            if 'youtube.com/' in s or 'youtu.be/' in s:  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam (int() rather than eval())
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0  # 30 FPS fallback
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=(i, cap), daemon=True)
            print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]
        while cap.isOpened() and n < f:
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n % 4 == 0:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[i] = im if success else self.imgs[i] * 0
            time.sleep(1 / self.fps[i])  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img0 = self.imgs.copy()
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB and BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years


def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
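
# Example ('images' dir swapped for 'labels', extension swapped for .txt; path is illustrative):
#   img2label_paths(['/data/images/train/im0.jpg'])  # -> ['/data/labels/train/im0.txt']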


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache.get('version') != 0.3 or cache.get('hash') != get_hash(self.label_files + self.img_files):
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
            if cache['msgs']:
                logging.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(num_threads) as pool:
            pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [l, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        pbar.close()
        if msgs:
            logging.info('\n'.join(msgs))
        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, len(self.img_files)
        x['msgs'] = msgs  # warnings
        x['version'] = 0.3  # cache version
        try:
            torch.save(x, path)  # save cache for next time
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB and HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
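
    # Note: the collated label tensor has one row per object: [batch_image_index, class, x, y, w, h],
    # with xywh normalized to 0-1; build_targets() relies on the index written into column 0 above.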

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # ratio
        if r != 1:  # if sizes are not equal
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
                             interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation (modifies img in place)
    if hgain or sgain or vgain:
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
        hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
        dtype = img.dtype  # uint8

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
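
# Example (mutates im in place; the gain values mirror commonly used scratch hyperparameters, quoted
# from memory rather than from this file, and the image path is hypothetical):
#   im = cv2.imread('data/images/bus.jpg')
#   augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)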


def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB


def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
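
# Usage sketch (requires a fully constructed LoadImagesAndLabels instance with self.hyp set):
#   img4, labels4 = load_mosaic(dataset, 0)  # img4 is (img_size, img_size, 3) after the border crop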


def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
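
# Worked example: a 720x1280 frame at new_shape=640 with auto=True scales by 0.5 and pads the short side
# only up to the nearest stride multiple rather than a full square (values traced by hand from the math above):
#   im = np.zeros((720, 1280, 3), dtype=np.uint8)
#   img, ratio, (dw, dh) = letterbox(im, 640)  # img.shape == (384, 640, 3), ratio == (0.5, 0.5)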


def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
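
# Example: a 100x100 box squashed to 40x1 by augmentation is rejected (h2 < wh_thr and extreme aspect ratio):
#   box_candidates(np.array([[0.], [0.], [100.], [100.]]), np.array([[0.], [0.], [40.], [1.]]))  # -> array([False])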


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class

    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
|
|
|
|
|
2021-04-12 00:53:40 +08:00
|
|
|
|
Update autosplit() with annotated_only option (#2466)
* Be able to create dataset from annotated images only
Add the ability to create a dataset/splits only with images that have an annotation file, i.e a .txt file, associated to it. As we talked about this, the absence of a txt file could mean two things:
* either the image wasn't yet labelled by someone,
* either there is no object to detect.
When it's easy to create small datasets, when you have to create datasets with thousands of images (and more coming), it's hard to track where you at and you don't want to wait to have all of them annotated before starting to train. Which means some images would lack txt files and annotations, resulting in label inconsistency as you say in #2313. By adding the annotated_only argument to the function, people could create, if they want to, datasets/splits only with images that were labelled, for sure.
* Cleanup and update print()
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
2021-03-15 08:11:27 +08:00
|
|
|
def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
|
2020-11-24 01:35:25 +08:00
|
|
|
""" Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
|
Update autosplit() with annotated_only option (#2466)
* Be able to create dataset from annotated images only
Add the ability to create a dataset/splits only with images that have an annotation file, i.e a .txt file, associated to it. As we talked about this, the absence of a txt file could mean two things:
* either the image wasn't yet labelled by someone,
* either there is no object to detect.
When it's easy to create small datasets, when you have to create datasets with thousands of images (and more coming), it's hard to track where you at and you don't want to wait to have all of them annotated before starting to train. Which means some images would lack txt files and annotations, resulting in label inconsistency as you say in #2313. By adding the annotated_only argument to the function, people could create, if they want to, datasets/splits only with images that were labelled, for sure.
* Cleanup and update print()
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
2021-03-15 08:11:27 +08:00
|
|
|
Usage: from utils.datasets import *; autosplit('../coco128')
|
|
|
|
Arguments
|
|
|
|
path: Path to images directory
|
|
|
|
weights: Train, val, test weights (list)
|
|
|
|
annotated_only: Only use images with an annotated txt file
|
2020-11-24 00:18:21 +08:00
|
|
|
"""
|
|
|
|
path = Path(path) # images dir
|
Update autosplit() with annotated_only option (#2466)
* Be able to create dataset from annotated images only
Add the ability to create a dataset/splits only with images that have an annotation file, i.e a .txt file, associated to it. As we talked about this, the absence of a txt file could mean two things:
* either the image wasn't yet labelled by someone,
* either there is no object to detect.
When it's easy to create small datasets, when you have to create datasets with thousands of images (and more coming), it's hard to track where you at and you don't want to wait to have all of them annotated before starting to train. Which means some images would lack txt files and annotations, resulting in label inconsistency as you say in #2313. By adding the annotated_only argument to the function, people could create, if they want to, datasets/splits only with images that were labelled, for sure.
* Cleanup and update print()
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
2021-03-15 08:11:27 +08:00
|
|
|
files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], []) # image files only
|
2020-11-24 01:35:25 +08:00
|
|
|
n = len(files) # number of files
|
|
|
|
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file
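
# Usage sketch, not part of the original function. The directory layout is an assumption
# (a coco128-style images dir whose *.txt labels are resolvable via img2label_paths()):
#   autosplit('../coco128/images', weights=(0.8, 0.1, 0.1), annotated_only=True)
# This would write autosplit_train.txt, autosplit_val.txt and autosplit_test.txt next to
# the images dir, listing only images that already have a label file.
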
def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)  # seek to the last two bytes of the file
                assert f.read() == b'\xff\xd9', 'corrupted JPEG'  # a valid JPEG ends with the EOI marker 0xFFD9

        # verify labels
        segments = []  # instance segments
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, 'r') as f:
                l = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any([len(x) > 8 for x in l]):  # is segment
                    classes = np.array([x[0] for x in l], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                    l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
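                    # Label-row sketch (values are illustrative, not from a real dataset):
                    #   box row:     '0 0.5 0.5 0.2 0.1'                       -> cls x y w h, normalized 0-1
                    #   segment row: '0 0.1 0.1 0.9 0.1 0.5 0.9 0.1 0.5'       -> cls + polygon points,
                    #                reduced to an xywh box by segments2boxes() above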
                l = np.array(l, dtype=np.float32)
            if len(l):
                assert l.shape[1] == 5, 'labels require 5 columns each'
                assert (l >= 0).all(), 'negative labels'
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
            else:
                ne = 1  # label empty
                l = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            l = np.zeros((0, 5), dtype=np.float32)
        return im_file, l, shape, segments, nm, nf, ne, nc, ''
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
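
# Driver sketch (an assumption of typical use, mirroring the cache_labels() pattern in
# this file): img_files and label_files are matching path lists, 'scanning: ' a log prefix.
#   with Pool(num_threads) as pool:
#       for im_file, l, shape, segments, nm, nf, ne, nc, msg in pool.imap(
#               verify_image_label, zip(img_files, label_files, repeat('scanning: '))):
#           if msg:
#               print(msg)  # report corrupted image/label pairs; the rest are usable
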
def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
    """ Return dataset statistics dictionary with images and instances counts per split per class
    Usage: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True)
    Arguments
        path:           Path to data.yaml
        autodownload:   Attempt to download dataset if not found locally
        verbose:        Print stats dictionary
    """

    def round_labels(labels):
        # Update labels to integer class and 6 decimal place floats
        return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
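
    # round_labels example (illustrative values):
    #   round_labels([[0.0, 0.1234567, 0.5, 0.25, 0.5]]) -> [[0, 0.123457, 0.5, 0.25, 0.5]]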

    with open(check_file(path)) as f:
        data = yaml.safe_load(f)  # data dict
    check_dataset(data, autodownload)  # download dataset if missing
    nc = data['nc']  # number of classes
    stats = {'nc': nc, 'names': data['names']}  # statistics dictionary
    for split in 'train', 'val', 'test':
        if split not in data:
            stats[split] = None  # i.e. no test set
            continue
        x = []
        dataset = LoadImagesAndLabels(data[split], augment=False, rect=True)  # load dataset
        if split == 'train':
            cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache')  # *.cache path
        for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
            x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
        x = np.array(x)  # shape(128x80)
        stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
                        'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
                                        'per_class': (x > 0).sum(0).tolist()},
                        'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
                                   zip(dataset.img_files, dataset.labels)]}

    # Save, print and return
    with open(cache_path.with_suffix('.json'), 'w') as f:
        json.dump(stats, f)  # save stats *.json
    if verbose:
        print(json.dumps(stats, indent=2, sort_keys=False))
        # print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
    return stats
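
# Usage sketch: the keys mirror the dict built above; the numbers and file names shown
# are placeholders, not real coco128 statistics.
#   stats = dataset_stats('coco128.yaml', autodownload=True, verbose=True)
#   # {'nc': 80, 'names': [...],
#   #  'train': {'instance_stats': {'total': 900, 'per_class': [...]},
#   #            'image_stats': {'total': 128, 'unlabelled': 0, 'per_class': [...]},
#   #            'labels': [{'im1.jpg': [[0, 0.5, 0.5, 0.2, 0.1], ...]}, ...]},
#   #  'val': {...}, 'test': None}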