mirror of https://github.com/JDAI-CV/fast-reid.git
81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
from PIL import Image
|
|
import io
|
|
import logging
|
|
import numbers
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from fastreid.data.common import CommDataset
|
|
|
|
# Module logger.
logger = logging.getLogger("fastreid.face_data")

# mxnet is optional: it is only needed to read MXNet RecordIO (.rec) files,
# so a missing install is reported here instead of failing at import time.
try:
    import mxnet as mx
except ImportError:
    logger.info("Please install mxnet if you want to use .rec file")
|
|
|
|
|
|
class MXFaceDataset(Dataset):
    """Face dataset read from an MXNet indexed RecordIO (``.rec``) file.

    Each item is a dict with the decoded (optionally transformed) image under
    ``"images"``, the class id as a ``torch.long`` tensor under ``"targets"``
    and a constant ``"camids"`` of 0.

    Args:
        path_imgrec: path to the ``.rec`` file. The index file is expected
            next to it with the same stem and an ``.idx`` extension.
        transforms: optional callable applied to the decoded PIL image.

    Raises:
        ImportError: if mxnet is not installed.
    """

    def __init__(self, path_imgrec, transforms):
        super().__init__()
        # The module-level mxnet import is optional; fail here with a clear
        # message rather than a NameError on ``mx`` further down. If this
        # succeeds, the module-level ``mx`` used by __getitem__ exists too.
        try:
            import mxnet as mx
        except ImportError as err:
            raise ImportError(
                "MXFaceDataset requires mxnet to read .rec files; please install mxnet"
            ) from err

        self.transforms = transforms

        logger.info(f"loading recordio {path_imgrec}...")
        path_imgidx = self._index_path(path_imgrec)
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')

        header, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if header.flag > 0:
            # Record 0 is a meta header: image records are keys
            # 1 .. label[0]-1 and label[1] - label[0] is the class count.
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
        else:
            # No meta header: every key is a sample and the class range is
            # not recorded, so header0 stays None instead of being missing.
            self.header0 = None
            self.imgidx = list(self.imgrec.keys)

        num_classes = "unknown" if self.header0 is None else self.num_classes
        logger.info(f"Number of Samples: {len(self.imgidx)}, "
                    f"Number of Classes: {num_classes}")

    @staticmethod
    def _index_path(path_imgrec):
        """Return the ``.idx`` path that sits next to ``path_imgrec``."""
        import os

        # splitext only strips a real extension; slicing off 4 characters
        # would corrupt paths that do not end in ".rec".
        return os.path.splitext(path_imgrec)[0] + ".idx"

    @staticmethod
    def _scalar_label(label):
        """Return the class id from a record label (a number, or a sequence whose first entry is used)."""
        if isinstance(label, numbers.Number):
            return label
        return label[0]

    def __getitem__(self, index):
        idx = self.imgidx[index]
        header, img = mx.recordio.unpack(self.imgrec.read_idx(idx))
        label = torch.tensor(self._scalar_label(header.label), dtype=torch.long)

        # PIL keeps the stored colour mode (e.g. "L" for grayscale JPEGs);
        # force RGB so every sample has three channels.
        sample = Image.open(io.BytesIO(img)).convert("RGB")
        if self.transforms is not None:
            sample = self.transforms(sample)

        return {
            "images": sample,
            "targets": label,
            "camids": 0,
        }

    def __len__(self):
        return len(self.imgidx)

    @property
    def num_classes(self):
        """Number of classes declared by the RecordIO meta header.

        Raises:
            RuntimeError: if the file has no meta header (``flag == 0``), so
                the class count was never recorded.
        """
        if self.header0 is None:
            raise RuntimeError(
                "number of classes is unknown: the RecordIO file has no meta header (flag == 0)"
            )
        return int(self.header0[1] - self.header0[0])
|
|
|
|
|
|
class TestFaceDataset(CommDataset):
    """Face dataset over images that are already loaded in memory.

    Args:
        img_items: indexable collection; each entry is converted with
            ``torch.tensor``. NOTE(review): entries appear to be normalised
            to [-1, 1], since ``x * 127.5 + 127.5`` maps [-1, 1] onto
            [0, 255] -- confirm against the loader that builds them.
        labels: stored on the instance as-is; ``__getitem__`` does not use it.
    """

    def __init__(self, img_items, labels):
        # CommDataset.__init__ is not called; only these two attributes are set.
        self.img_items, self.labels = img_items, labels

    def __getitem__(self, index):
        pixels = torch.tensor(self.img_items[index])
        # Affine rescale x -> x * 127.5 + 127.5 (see class docstring).
        pixels = pixels * 127.5 + 127.5
        return {"images": pixels}
|