fast-reid/fastreid/data/common.py

59 lines
1.5 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch.utils.data import Dataset
2020-02-10 07:38:56 +08:00
from .data_utils import read_image
class CommDataset(Dataset):
    """Image Person ReID Dataset.

    Wraps a sequence of image records and yields one sample dict per record.

    Args:
        img_items: indexable sequence of records. Each record is read as
            ``record[0]`` = image path (passed to ``read_image``),
            ``record[1]`` = person id, ``record[2]`` = camera id. Ids must be
            hashable and mutually comparable (they are put in a set and sorted).
        transform: optional callable applied to the loaded image.
        relabel: when True, raw person / camera ids are remapped to contiguous
            indices ``0..K-1`` following their sorted order.
    """

    def __init__(self, img_items, transform=None, relabel=True):
        self.img_items = img_items
        self.transform = transform
        self.relabel = relabel

        # One pass over the records collects the distinct person and camera ids.
        pid_set = set()
        cam_set = set()
        for item in img_items:
            pid_set.add(item[1])
            cam_set.add(item[2])

        self.pids = sorted(pid_set)
        self.cams = sorted(cam_set)
        if relabel:
            # Each raw id maps to its position in the sorted id list.
            # NOTE: these attributes exist only when relabel is True.
            self.pid_dict = {p: i for i, p in enumerate(self.pids)}
            self.cam_dict = {c: i for i, c in enumerate(self.cams)}

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Load record ``index`` and return it as a sample dict.

        Returns:
            dict with keys ``"images"`` (loaded image, transformed if a
            transform was given), ``"targets"`` (person id, relabeled when
            ``relabel`` is True), ``"camids"`` (camera id, relabeled when
            ``relabel`` is True) and ``"img_paths"`` (the source image path).
        """
        img_item = self.img_items[index]
        img_path = img_item[0]
        pid = img_item[1]
        camid = img_item[2]

        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)

        if self.relabel:
            pid = self.pid_dict[pid]
            camid = self.cam_dict[camid]

        return {
            "images": img,
            "targets": pid,
            "camids": camid,
            "img_paths": img_path,
        }

    @property
    def num_classes(self):
        """Number of distinct person ids found in ``img_items``."""
        return len(self.pids)

    @property
    def num_cameras(self):
        """Number of distinct camera ids found in ``img_items``."""
        return len(self.cams)