# PaddleClas/ppcls/data/dataloader/face_dataset.py
# (web-view export residue: 246 lines, 8.6 KiB, Python; blame 2022-05-06 12:50:26 +08:00)
import os
import json
import numpy as np
from PIL import Image
import cv2
import paddle
import paddle.vision.datasets as datasets
from paddle.vision import transforms
from paddle.vision.transforms import functional as F
from paddle.io import Dataset
from .common_dataset import create_operators
# blame: 2022-05-16 11:50:35 +08:00
from ppcls.data.preprocess import transform as transform_func
# blame: 2022-05-06 12:50:26 +08:00
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
def _get_image_size(img):
    """Return the spatial size of *img*, width first.

    Supported inputs are a PIL image, a numpy image (first two axes are
    height then width) and a CHW tensor.

    Raises:
        TypeError: if *img* is none of the supported image types.
    """
    if F._is_pil_image(img):
        # PIL already reports (width, height).
        return img.size
    if F._is_numpy_image(img):
        # Take the first two axes (H, W) in reverse order -> (W, H).
        return img.shape[1::-1]
    if F._is_tensor_image(img):
        # CHW: drop the leading channel axis and reverse the rest.
        return img.shape[:0:-1]
    raise TypeError("Unexpected type {}".format(type(img)))
class AdaFaceDataset(Dataset):
    """Face-recognition training dataset with AdaFace-style augmentations.

    Samples are listed in a label file in which every non-blank line holds
    an image path (relative to ``root_dir``) and an integer class id,
    separated by whitespace. Each loaded image is randomly augmented (see
    ``augment``) and then passed through the optional ``transform``
    operators.

    Args:
        root_dir (str): directory the image paths in the label file are
            relative to.
        label_path (str): path of the label file.
        transform: operator config handed to ``create_operators``; ``None``
            disables the post-augmentation transform.
        low_res_augmentation_prob (float): probability of applying
            ``low_res_augmentation``.
        crop_augmentation_prob (float): probability of the zero-padded
            random-resized-crop augmentation.
        photometric_augmentation_prob (float): probability of applying
            ``ColorJitter``.
    """

    def __init__(
            self,
            root_dir,
            label_path,
            transform=None,
            low_res_augmentation_prob=0.0,
            crop_augmentation_prob=0.0,
            photometric_augmentation_prob=0.0, ):
        self.root_dir = root_dir
        self.low_res_augmentation_prob = low_res_augmentation_prob
        self.crop_augmentation_prob = crop_augmentation_prob
        self.photometric_augmentation_prob = photometric_augmentation_prob
        # Only the crop-box sampler (``_get_param``) of this transform is
        # used in ``augment``; the crop is pasted back without resizing.
        self.random_resized_crop = transforms.RandomResizedCrop(
            size=(112, 112),
            scale=(0.2, 1.0),
            ratio=(0.75, 1.3333333333333333))
        self.photometric = transforms.ColorJitter(
            brightness=0.5, contrast=0.5, saturation=0.5, hue=0)
        # ``None`` means "no transform" (``__getitem__`` checks for it), so
        # it is not passed to ``create_operators``.
        self.transform = (create_operators(transform)
                          if transform is not None else None)
        # Counters kept for attribute compatibility; nothing here updates them.
        self.tot_rot_try = 0
        self.rot_success = 0

        self.samples = []
        with open(label_path) as fd:
            for line in fd:
                parts = line.strip().split()
                if not parts:
                    # Tolerate blank lines (e.g. trailing empty lines).
                    continue
                self.samples.append(
                    [os.path.join(root_dir, parts[0]), int(parts[1])])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Decode while the file is still open; Image.open is lazy.
            sample = img.convert('RGB')
        # NOTE: a disabled variant swapped RGB -> BGR here for WebFace images.
        sample, _ = self.augment(sample)
        if self.transform is not None:
            sample = transform_func(sample, self.transform)
        return sample, target

    def augment(self, sample):
        """Randomly apply crop, low-resolution and photometric augmentation.

        Args:
            sample (PIL.Image.Image): RGB image.

        Returns:
            tuple: ``(sample, information_score)`` where ``information_score``
            is ``resize_ratio * crop_ratio``: 1.0 when neither the crop nor
            the low-resolution augmentation fired, at most 1.0 otherwise.
        """
        # Crop with zero padding: sample a RandomResizedCrop box, keep that
        # region in place and black out the rest of the canvas.
        if np.random.random() < self.crop_augmentation_prob:
            new = np.zeros_like(np.array(sample))
            orig_W, orig_H = _get_image_size(sample)
            i, j, h, w = self.random_resized_crop._get_param(sample)
            cropped = F.crop(sample, i, j, h, w)
            new[i:i + h, j:j + w, :] = np.array(cropped)
            sample = Image.fromarray(new.astype(np.uint8))
            crop_ratio = min(h, w) / max(orig_H, orig_W)
        else:
            crop_ratio = 1.0

        # Low-resolution augmentation: downscale then upscale back.
        if np.random.random() < self.low_res_augmentation_prob:
            img_np, resize_ratio = low_res_augmentation(np.array(sample))
            sample = Image.fromarray(img_np.astype(np.uint8))
        else:
            resize_ratio = 1

        # Photometric augmentation.
        if np.random.random() < self.photometric_augmentation_prob:
            sample = self.photometric(sample)

        information_score = resize_ratio * crop_ratio
        return sample, information_score
def low_res_augmentation(img):
    """Degrade *img* by resizing it down to a random size and back up.

    The image is first resized to a square of side ``side_ratio * height``
    with ``side_ratio`` drawn from U(0.2, 1.0), then resized back to its
    original height and width. Each resize uses an independently sampled
    OpenCV interpolation flag.

    Args:
        img (np.ndarray): image array; ``img.shape[0]`` is used as the height
            and ``img.shape[1]`` as the width.

    Returns:
        tuple: ``(aug_img, side_ratio)``, where ``aug_img`` is the degraded
        image with the same height and width as *img* and ``side_ratio`` is
        the sampled ratio.
    """
    interpolations = [
        cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC,
        cv2.INTER_LANCZOS4
    ]
    height, width = img.shape[0], img.shape[1]
    side_ratio = np.random.uniform(0.2, 1.0)
    # Clamp to 1 px: for very small inputs int() would round down to 0 and
    # cv2.resize rejects an empty target size.
    small_side = max(1, int(side_ratio * height))
    # np.random.choice returns a numpy integer; pass a plain int to OpenCV.
    interpolation = int(np.random.choice(interpolations))
    small_img = cv2.resize(
        img, (small_side, small_side), interpolation=interpolation)
    interpolation = int(np.random.choice(interpolations))
    aug_img = cv2.resize(
        small_img, (width, height), interpolation=interpolation)
    return aug_img, side_ratio
class FiveValidationDataset(Dataset):
    """The five face-verification benchmarks concatenated into one dataset.

    Items are ``(image_tensor, issame, dataset_index, index)``, where
    ``dataset_index`` follows ``dataname_to_idx``:
    agedb_30 -> 0, cfp_fp -> 1, lfw -> 2, cplfw -> 3, calfw -> 4.

    Images come in consecutive pairs, so every ``issame`` flag is repeated
    twice so that it lines up with both images of its pair.
    """

    def __init__(self, val_data_path, concat_mem_file_name):
        """Concatenate all validation datasets from emore.

        Args:
            val_data_path (str): directory passed to ``get_val_data``.
            concat_mem_file_name (str): memmap holding all five image sets
                already concatenated; it is opened with ``read_memmap`` when
                the per-benchmark image arrays are memmaps.
        """
        (agedb, cfp_fp, lfw, agedb_issame, cfp_fp_issame, lfw_issame, cplfw,
         cplfw_issame, calfw, calfw_issame) = get_val_data(val_data_path)
        expected_order = ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
        self.dataname_to_idx = {
            name: position
            for position, name in enumerate(expected_order)
        }
        self.val_data_dict = {
            'agedb_30': (agedb, agedb_issame),
            'cfp_fp': (cfp_fp, cfp_fp_issame),
            'lfw': (lfw, lfw_issame),
            'cplfw': (cplfw, cplfw_issame),
            'calfw': (calfw, calfw_issame),
        }

        image_sets, issame_sets, name_sets, seen = [], [], [], []
        for name, (images, issame) in self.val_data_dict.items():
            image_sets.append(images)
            # One flag per image: [a, b, ...] -> [a, a, b, b, ...]
            issame_sets.append([flag for flag in issame for _ in range(2)])
            name_sets.append([self.dataname_to_idx[name]] * len(images))
            seen.append(name)
        assert seen == expected_order

        if isinstance(image_sets[0], np.memmap):
            # Memmap inputs: use the pre-built concatenated memfile instead
            # of materialising all images in RAM.
            self.all_imgs = read_memmap(concat_mem_file_name)
        else:
            self.all_imgs = np.concatenate(image_sets)
        self.all_issame = np.concatenate(issame_sets)
        self.all_dataname = np.concatenate(name_sets)

    def __getitem__(self, index):
        # copy() detaches the row from the (possibly memory-mapped) array.
        image = paddle.to_tensor(self.all_imgs[index].copy())
        return (image, self.all_issame[index], self.all_dataname[index],
                index)

    def __len__(self):
        return len(self.all_imgs)
def read_memmap(mem_file_name, mode='r+'):
    """Open the memory-mapped array stored at *mem_file_name*.

    The array's layout is read from the JSON sidecar
    ``<mem_file_name>.conf``, which must provide ``shape`` (a sequence of
    ints) and ``dtype`` (anything ``np.memmap`` accepts as a dtype).

    Args:
        mem_file_name (str): path of the raw memmap data file.
        mode (str): ``np.memmap`` mode. The default ``'r+'`` (open an existing
            file for reading and writing) keeps the original behaviour; pass
            ``'r'`` to open read-only, e.g. when the data lives on read-only
            storage.

    Returns:
        np.memmap: array view backed by the file.
    """
    with open(mem_file_name + '.conf', 'r') as file:
        memmap_configs = json.load(file)
    return np.memmap(
        mem_file_name,
        mode=mode,
        shape=tuple(memmap_configs['shape']),
        dtype=memmap_configs['dtype'])
def _make_memmap(mem_file_name, np_to_copy):
    """Write *np_to_copy* to a new memmap file plus its ``.conf`` sidecar.

    The sidecar is JSON with ``shape`` and ``dtype`` keys, the layout that
    ``read_memmap`` reads. It is written only after the data has been
    flushed, so its presence marks a complete cache.

    Returns:
        np.memmap: the freshly written array (opened in ``'w+'`` mode).
    """
    shape = tuple(np_to_copy.shape)
    dtype = str(np_to_copy.dtype)
    # w+ mode: create or overwrite the data file for reading and writing.
    mem_array = np.memmap(mem_file_name, mode='w+', shape=shape, dtype=dtype)
    mem_array[:] = np_to_copy[:]
    mem_array.flush()
    with open(mem_file_name + '.conf', 'w') as file:
        json.dump({'shape': shape, 'dtype': dtype}, file)
    return mem_array


def get_val_pair(path, name, use_memfile=True):
    """Load one validation benchmark's images and pair labels.

    Args:
        path (str): root directory containing the bcolz carray directory
            ``<name>`` and the label file ``<name>_list.npy``.
        name (str): benchmark name, e.g. ``'lfw'``.
        use_memfile (bool): if True, serve the images from a numpy memmap
            cached at ``<path>/<name>/memfile/mem_file.dat``. The cache is
            built from the bcolz carray on first use. If False, return the
            bcolz carray itself.

    Returns:
        tuple: ``(images, issame)``, where ``images`` is an ``np.memmap`` (or
        a ``bcolz.carray`` when ``use_memfile`` is False) and ``issame`` is the
        array loaded from ``<name>_list.npy``.
    """
    carray_dir = os.path.join(path, name)
    if use_memfile:
        mem_file_dir = os.path.join(carray_dir, 'memfile')
        mem_file_name = os.path.join(mem_file_dir, 'mem_file.dat')
        # Check the files, not the directory: an empty or partial memfile
        # directory left by an interrupted run must trigger a rebuild.
        if (os.path.isfile(mem_file_name) and
                os.path.isfile(mem_file_name + '.conf')):
            print('loading validation data memfile')
        else:
            # installing bcolz should set proxy to access internet
            import bcolz
            os.makedirs(mem_file_dir, exist_ok=True)
            carray = bcolz.carray(rootdir=carray_dir, mode='r')
            np_array = np.array(carray)
            mem_array = _make_memmap(mem_file_name, np_array)
            # Drop the in-RAM copies; the cache is reopened from disk below.
            del np_array, mem_array
        np_array = read_memmap(mem_file_name)
    else:
        # installing bcolz should set proxy to access internet
        import bcolz
        np_array = bcolz.carray(rootdir=carray_dir, mode='r')
    issame = np.load(os.path.join(path, '{}_list.npy'.format(name)))
    return np_array, issame
def get_val_data(data_path):
    """Load all five face-verification benchmarks from *data_path*.

    Each benchmark is loaded with ``get_val_pair`` (default arguments) in
    the order agedb_30, cfp_fp, lfw, cplfw, calfw.

    Returns:
        tuple: ``(agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame,
        lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame)``. Images and
        labels are interleaved in this fixed order, which callers unpack
        positionally.
    """
    pairs = {}
    for benchmark in ('agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw'):
        images, issame = get_val_pair(data_path, benchmark)
        pairs[benchmark] = (images, issame)

    # First three image sets, then their labels, then (images, labels) for
    # cplfw and calfw.
    return (pairs['agedb_30'][0], pairs['cfp_fp'][0], pairs['lfw'][0],
            pairs['agedb_30'][1], pairs['cfp_fp'][1], pairs['lfw'][1],
            *pairs['cplfw'], *pairs['calfw'])