add random patch data augmentation

pull/249/head
KaiyangZhou 2019-09-01 15:03:42 +01:00
parent 9131c537d5
commit 0ec14d9afc
1 changed file with 80 additions and 0 deletions

View File

@ -6,6 +6,7 @@ from PIL import Image
import random
import numpy as np
import math
from collections import deque
import torch
from torchvision.transforms import *
@ -131,6 +132,82 @@ class ColorAugmentation(object):
return tensor
class RandomPatch(object):
    """Random patch data augmentation.

    There is a patch pool that stores randomly extracted patches from person images.

    For each input image,
        1) we extract a random patch and store the patch in the patch pool;
        2) randomly select a patch from the patch pool and paste it on the
           input to simulate occlusion.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.

    Args:
        prob_happen (float): probability of pasting a patch on an input image.
        pool_capacity (int): maximum number of patches kept in the pool; once
            full, appending a new patch discards the oldest one.
        min_sample_size (int): no patch is pasted while the pool holds fewer
            patches than this.
        patch_min_area (float): minimum patch area, as a fraction of the image area.
        patch_max_area (float): maximum patch area, as a fraction of the image area.
        patch_min_ratio (float): minimum height/width ratio of a patch; the
            maximum ratio is its reciprocal.
        prob_rotate (float): probability of rotating the selected patch by a
            random angle in [-10, 10] degrees.
        prob_flip_leftright (float): probability of flipping the selected
            patch horizontally.
    """

    def __init__(self, prob_happen=0.5, pool_capacity=50000, min_sample_size=100,
                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
                 prob_rotate=0.5, prob_flip_leftright=0.5,
                 ):
        self.prob_happen = prob_happen

        self.patch_min_area = patch_min_area
        self.patch_max_area = patch_max_area
        self.patch_min_ratio = patch_min_ratio

        self.prob_rotate = prob_rotate
        self.prob_flip_leftright = prob_flip_leftright

        # Bounded FIFO: the deque silently drops the oldest patch when full.
        self.patchpool = deque(maxlen=pool_capacity)
        self.min_sample_size = min_sample_size

    def generate_wh(self, W, H):
        """Sample a patch size that fits strictly inside a ``W`` x ``H`` image.

        Returns:
            tuple: ``(w, h)`` in pixels, or ``(None, None)`` if none of the
            100 random attempts produced a size with ``w < W`` and ``h < H``.
        """
        area = W * H
        for _ in range(100):
            target_area = random.uniform(self.patch_min_area, self.patch_max_area) * area
            aspect_ratio = random.uniform(self.patch_min_ratio, 1. / self.patch_min_ratio)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < W and h < H:
                return w, h
        return None, None

    def transform_patch(self, patch):
        """Randomly flip and/or rotate ``patch`` and return the result.

        Each transform is applied with its configured probability. (The
        comparison used to be inverted, so a transform fired with probability
        ``1 - prob``; the two are identical only at the default of 0.5.)
        """
        if random.random() < self.prob_flip_leftright:
            patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
        if random.random() < self.prob_rotate:
            patch = patch.rotate(random.randint(-10, 10))
        return patch

    def __call__(self, img):
        """Add a patch from ``img`` to the pool, then maybe occlude ``img``.

        Args:
            img: image exposing ``size``, ``crop`` and ``paste`` (a PIL image).

        Returns:
            The same ``img`` object. When a patch is pasted, ``img`` is
            modified in place.
        """
        W, H = img.size  # original image size

        # collect new patch
        w, h = self.generate_wh(W, H)
        if w is not None and h is not None:
            x1 = random.randint(0, W - w)
            y1 = random.randint(0, H - h)
            self.patchpool.append(img.crop((x1, y1, x1 + w, y1 + h)))

        # Nothing to paste until the pool is warmed up. The emptiness check
        # covers min_sample_size <= 0 combined with a failed size sampling.
        if not self.patchpool or len(self.patchpool) < self.min_sample_size:
            return img

        if random.random() >= self.prob_happen:
            return img

        # paste a randomly selected patch on a random position
        patch = random.choice(self.patchpool)
        patchW, patchH = patch.size
        if patchW > W or patchH > H:
            # The patch was cut from a larger image and cannot fit here;
            # skip the occlusion rather than fail in randint().
            return img
        x1 = random.randint(0, W - patchW)
        y1 = random.randint(0, H - patchH)
        patch = self.transform_patch(patch)
        img.paste(patch, (x1, y1))

        return img
def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485, 0.456, 0.406],
norm_std=[0.229, 0.224, 0.225], **kwargs):
"""Builds train and test transform functions.
@ -172,6 +249,9 @@ def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485,
print('+ random crop (enlarge to {}x{} and ' \
'crop {}x{})'.format(int(round(height*1.125)), int(round(width*1.125)), height, width))
transform_tr += [Random2DTranslation(height, width)]
if 'random_patch' in transforms:
print('+ random patch')
transform_tr += [RandomPatch()]
if 'color_jitter' in transforms:
print('+ color jitter')
transform_tr += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]