fast-reid/projects/FastFace/fastface/face_data.py

81 lines
2.3 KiB
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import io
import logging
import numbers
import os

import torch
from PIL import Image
from torch.utils.data import Dataset

from fastreid.data.common import CommDataset
# Module-level logger shared by the dataset classes in this file.
logger = logging.getLogger("fastreid.face_data")

# mxnet is optional: it is only imported so that MXFaceDataset can read
# RecordIO (.rec) files. When it is missing, only an informational message is
# emitted and the name ``mx`` stays undefined.
try:
    import mxnet as mx
except ImportError:
    logger.info("Please install mxnet if you want to use .rec file")
class MXFaceDataset(Dataset):
    """Face dataset read from an MXNet indexed RecordIO pair (``.rec`` + ``.idx``).

    Each item is a dict with:
        ``"images"``: the decoded image converted to RGB, passed through
            ``transforms`` when it is not ``None``.
        ``"targets"``: the label as a ``torch.long`` tensor.
        ``"camids"``: always ``0``.

    Args:
        path_imgrec: path to the ``.rec`` file. The index file is expected to
            share its stem, with an ``.idx`` extension.
        transforms: optional callable applied to each PIL image.
    """

    def __init__(self, path_imgrec, transforms):
        super().__init__()
        self.transforms = transforms
        logger.info(f"loading recordio {path_imgrec}...")
        # Swap the record file's extension for ".idx". Unlike slicing off the
        # last four characters, this is correct for any (or no) suffix.
        path_imgidx = os.path.splitext(path_imgrec)[0] + ".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        # Record 0 is inspected as a header record.
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            # header.label[0] is the exclusive end of the sample index range;
            # header.label[1] - header.label[0] is reported as the class count.
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
        else:
            # No header metadata: every key in the index is a sample, and the
            # class count cannot be derived from record 0. Previously header0
            # was left unset here, so the log call below crashed.
            self.header0 = None
            self.imgidx = list(self.imgrec.keys)
        num_classes = "unknown" if self.header0 is None else self.num_classes
        logger.info(f"Number of Samples: {len(self.imgidx)}, "
                    f"Number of Classes: {num_classes}")

    def __getitem__(self, index):
        """Decode the ``index``-th sample and return it as a dict (see class doc)."""
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        # A non-scalar label (e.g. an array) contributes only its first element.
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        # Image.open keeps the encoded mode (e.g. "L" for grayscale JPEGs);
        # convert so every sample is 3-channel RGB.
        sample = Image.open(io.BytesIO(img)).convert("RGB")
        if self.transforms is not None:
            sample = self.transforms(sample)
        return {
            "images": sample,
            "targets": label,
            "camids": 0,
        }

    def __len__(self):
        """Return the number of samples listed in ``self.imgidx``."""
        return len(self.imgidx)

    @property
    def num_classes(self):
        """Number of classes, taken from the record-0 header range.

        Raises:
            AttributeError: if the record file had no header metadata
                (``header.flag <= 0``), so no class range is known.
        """
        if self.header0 is None:
            raise AttributeError(
                "num_classes is unavailable: recordio header (record 0) "
                "carries no class range (header.flag <= 0)")
        return int(self.header0[1] - self.header0[0])
class TestFaceDataset(CommDataset):
def __init__(self, img_items, labels):
    """Keep the evaluation samples and their labels as given.

    Args:
        img_items: per-sample data, indexed by position when items are fetched.
        labels: labels associated with ``img_items``.

    NOTE(review): ``CommDataset.__init__`` is not called here, so none of the
    base-class setup runs. Presumably intentional for pre-loaded test data;
    confirm.
    """
    self.img_items, self.labels = img_items, labels
def __getitem__(self, index):
img = torch.tensor(self.img_items[index]) * 127.5 + 127.5
return {
"images": img,
}