2020-02-10 07:38:56 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: liaoxingyu
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
2020-02-13 00:19:15 +08:00
|
|
|
from torch.utils.data import Dataset
|
2020-02-10 07:38:56 +08:00
|
|
|
|
|
|
|
from .data_utils import read_image
|
|
|
|
|
|
|
|
|
2020-02-18 21:01:23 +08:00
|
|
|
class CommDataset(Dataset):
    """Image Person ReID Dataset.

    Wraps a sequence of ``(img_path, pid, camid)`` records and returns one
    sample dict per record.

    Args:
        img_items: sequence of ``(img_path, pid, camid)`` triples. It is
            iterated once in ``__init__`` and indexed in ``__getitem__``, so a
            list/tuple (not a one-shot iterator) is expected.
        transform: optional callable applied to the image returned by
            ``read_image``; skipped when ``None``.
        relabel: when True, each pid is mapped to its position in the sorted
            list of distinct pids, so returned targets are contiguous
            integers ``0..num_classes - 1``.
    """

    def __init__(self, img_items, transform=None, relabel=True):
        self.img_items = img_items
        self.transform = transform
        self.relabel = relabel

        # Distinct identities, sorted so the relabel mapping is deterministic.
        self.pids = sorted({item[1] for item in img_items})
        if relabel:
            # pid -> contiguous class index. Only built when relabelling is
            # requested; __getitem__ reads it only in that case.
            self.pid_dict = {pid: idx for idx, pid in enumerate(self.pids)}

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Load record ``index`` and return it as a sample dict.

        Returns:
            dict with keys ``"images"`` (the loaded, optionally transformed
            image), ``"targets"`` (the pid, relabelled if ``relabel``),
            ``"camid"`` and ``"img_path"``.
        """
        img_path, pid, camid = self.img_items[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.relabel:
            pid = self.pid_dict[pid]
        return {
            "images": img,
            "targets": pid,
            "camid": camid,
            "img_path": img_path,
        }

    @property
    def num_classes(self):
        """Number of distinct pids found in ``img_items``."""
        return len(self.pids)
|