Copy-Paste augmentation for YOLOv5 (#3845)
* Copy-paste augmentation initial commit * if any segments * Add obscuration rejection * Add copy_paste hyperparameter * Update commentspull/3879/head
parent
25d1f2932c
commit
c6c88dc601
|
@ -36,3 +36,4 @@ flipud: 0.00856
|
|||
fliplr: 0.5
|
||||
mosaic: 1.0
|
||||
mixup: 0.243
|
||||
copy_paste: 0.0
|
||||
|
|
|
@ -26,3 +26,4 @@ flipud: 0.0
|
|||
fliplr: 0.5
|
||||
mosaic: 1.0
|
||||
mixup: 0.0
|
||||
copy_paste: 0.0
|
||||
|
|
|
@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
|
|||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
|
|
|
@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
|
|||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
|
|
5
train.py
5
train.py
|
@ -6,7 +6,6 @@ Usage:
|
|||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
@ -16,6 +15,7 @@ from copy import deepcopy
|
|||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
@ -591,7 +591,8 @@ def main(opt):
|
|||
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
|
||||
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
|
||||
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
|
||||
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
|
||||
'mixup': (1, 0.0, 1.0), # image mixup (probability)
|
||||
'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
|
||||
|
||||
with open(opt.hyp) as f:
|
||||
hyp = yaml.safe_load(f) # load hyps dict
|
||||
|
|
|
@ -25,6 +25,7 @@ from tqdm import tqdm
|
|||
|
||||
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
|
||||
xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
|
||||
from utils.metrics import bbox_ioa
|
||||
from utils.torch_utils import torch_distributed_zero_first
|
||||
|
||||
# Parameters
|
||||
|
@ -683,6 +684,7 @@ def load_mosaic(self, index):
|
|||
# img4, labels4 = replicate(img4, labels4) # replicate
|
||||
|
||||
# Augment
|
||||
img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
|
||||
img4, labels4 = random_perspective(img4, labels4, segments4,
|
||||
degrees=self.hyp['degrees'],
|
||||
translate=self.hyp['translate'],
|
||||
|
@ -907,6 +909,30 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s
|
|||
return img, targets
|
||||
|
||||
|
||||
def copy_paste(img, labels, segments, probability=0.5):
|
||||
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
||||
n = len(segments)
|
||||
if probability and n:
|
||||
h, w, c = img.shape # height, width, channels
|
||||
im_new = np.zeros(img.shape, np.uint8)
|
||||
for j in random.sample(range(n), k=round(probability * n)):
|
||||
l, s = labels[j], segments[j]
|
||||
box = w - l[3], l[2], w - l[1], l[4]
|
||||
ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
|
||||
if (ioa < 0.30).all(): # allow 30% obscuration of existing labels
|
||||
labels = np.concatenate((labels, [[l[0], *box]]), 0)
|
||||
segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
|
||||
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
|
||||
|
||||
result = cv2.bitwise_and(src1=img, src2=im_new)
|
||||
result = cv2.flip(result, 1) # augment segments (flip left-right)
|
||||
i = result > 0 # pixels to replace
|
||||
# i[:, :] = result.max(2).reshape(h, w, 1) # act over ch
|
||||
img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug
|
||||
|
||||
return img, labels, segments
|
||||
|
||||
|
||||
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
||||
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
||||
|
@ -919,24 +945,6 @@ def cutout(image, labels):
|
|||
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552
|
||||
h, w = image.shape[:2]
|
||||
|
||||
def bbox_ioa(box1, box2):
|
||||
# Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
|
||||
box2 = box2.transpose()
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
|
||||
|
||||
# Intersection over box2 area
|
||||
return inter_area / box2_area
|
||||
|
||||
# create random masks
|
||||
scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
|
||||
for s in scales:
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Model validation metrics
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import math
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -253,6 +253,30 @@ def box_iou(box1, box2):
|
|||
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
||||
|
||||
|
||||
def bbox_ioa(box1, box2, eps=1E-7):
|
||||
""" Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
|
||||
box1: np.array of shape(4)
|
||||
box2: np.array of shape(nx4)
|
||||
returns: np.array of shape(n)
|
||||
"""
|
||||
|
||||
box2 = box2.transpose()
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
|
||||
|
||||
# Intersection over box2 area
|
||||
return inter_area / box2_area
|
||||
|
||||
|
||||
def wh_iou(wh1, wh2):
|
||||
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
|
||||
wh1 = wh1[:, None] # [N,1,2]
|
||||
|
|
Loading…
Reference in New Issue