fast-reid/fastreid/data/common.py

88 lines
2.4 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch.utils.data import Dataset
2020-02-10 07:38:56 +08:00
from .data_utils import read_image
class CommDataset(Dataset):
    """Image Person ReID Dataset.

    Wraps a sequence of ``(img_path, pid, camid)`` records. Each item is read
    with ``read_image``, optionally passed through ``transform``, and returned
    as a dict with keys ``'images'``, ``'targets'`` and ``'camid'``.

    Args:
        img_items: sequence of ``(img_path, pid, camid)`` records.
        transform: optional callable applied to the loaded image.
        relabel: if True, every pid is namespaced per dataset via
            ``get_pids`` and mapped to a contiguous integer label in
            ``0..len(self.pids)-1``, which is returned as ``'targets'``.
            If False, ``img_items`` is used unchanged and the raw pid is
            returned as ``'targets'``.

    Attributes (relabel=True only):
        pids: sorted list of the namespaced pids; ``pids[label]`` is the pid
            that ``label`` stands for.
        pid2label: dict mapping each namespaced pid to its integer label.
    """

    def __init__(self, img_items, transform=None, relabel=True):
        self.transform = transform
        self.relabel = relabel
        self.pid2label = None

        if self.relabel:
            self.img_items = []
            pid_set = set()
            for item in img_items:
                img_path, camid = item[0], item[2]
                pid = self.get_pids(img_path, item[1])
                self.img_items.append((img_path, pid, camid))  # replace pid
                pid_set.add(pid)
            # Sort before enumerating: the iteration order of a set of strings
            # depends on per-process hash randomization (PYTHONHASHSEED), so
            # enumerating the raw set would produce a different pid -> label
            # mapping in every process / run (e.g. across distributed ranks,
            # or when resuming a classifier from a checkpoint).
            self.pids = sorted(pid_set)
            self.pid2label = {p: i for i, p in enumerate(self.pids)}
        else:
            self.img_items = img_items

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Load, transform and (optionally) relabel the record at ``index``."""
        img_path, pid, camid = self.img_items[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.relabel:
            pid = self.pid2label[pid]

        return {
            'images': img,
            'targets': pid,
            'camid': camid
        }

    def get_pids(self, file_path, pid):
        """Build a dataset-namespaced pid string; suitable for multi-dataset training.

        Returns ``'cuhk_<pid>'`` when ``file_path`` contains ``'cuhk03'``,
        otherwise ``'<prefix>_<pid>'``, where ``prefix`` is the second
        '/'-separated component of ``file_path``.

        NOTE(review): this assumes paths shaped like ``<root>/<dataset>/...``
        with '/' separators; a path without any '/' raises IndexError.
        """
        if 'cuhk03' in file_path:
            prefix = 'cuhk'
        else:
            prefix = file_path.split('/')[1]
        return prefix + '_' + str(pid)
class data_prefetcher():
    """Iterate over ``loader`` one batch ahead, normalizing images on load.

    Each batch is pulled from the underlying iterator, and its ``"images"``
    tensor is normalized in place as ``(images - mean) / std``. The per-channel
    statistics come from ``cfg.MODEL.PIXEL_MEAN`` / ``cfg.MODEL.PIXEL_STD``.
    Loading is synchronous: ``next()`` returns the already-loaded batch and
    then loads the following one before returning.

    Args:
        cfg: config exposing ``cfg.MODEL.PIXEL_MEAN`` and
            ``cfg.MODEL.PIXEL_STD`` (equal-length sequences of numbers).
        loader: iterable yielding dict batches with an ``"images"`` tensor.
            Mean/std are viewed as ``(1, C, 1, 1)``, so images are presumably
            ``(N, C, H, W)`` — TODO confirm against the loader.

    Raises:
        ValueError: if PIXEL_MEAN and PIXEL_STD differ in length.
    """

    def __init__(self, cfg, loader):
        self.loader = loader
        self.loader_iter = iter(loader)

        # normalize
        pixel_mean = cfg.MODEL.PIXEL_MEAN
        pixel_std = cfg.MODEL.PIXEL_STD
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(pixel_mean) != len(pixel_std):
            raise ValueError(
                "cfg.MODEL.PIXEL_MEAN and cfg.MODEL.PIXEL_STD must have the same "
                "length, got {} and {}".format(len(pixel_mean), len(pixel_std))
            )
        num_channels = len(pixel_mean)
        self.mean = torch.tensor(pixel_mean).view(1, num_channels, 1, 1)
        self.std = torch.tensor(pixel_std).view(1, num_channels, 1, 1)

        self.preload()

    def reset(self):
        """Create a fresh iterator over ``loader`` and preload its first batch."""
        self.loader_iter = iter(self.loader)
        self.preload()

    def preload(self):
        """Fetch the next batch into ``self.next_inputs`` and normalize it in place.

        ``self.next_inputs`` is set to None once the iterator is exhausted.
        """
        try:
            self.next_inputs = next(self.loader_iter)
        except StopIteration:
            self.next_inputs = None
            return

        images = self.next_inputs["images"]
        # mean/std are created on the CPU; an in-place op against a
        # multi-element tensor on a different device fails, so move the
        # statistics to the batch's device (a no-op when they already match).
        mean = self.mean.to(images.device)
        std = self.std.to(images.device)
        images.sub_(mean).div_(std)

    def next(self):
        """Return the preloaded batch (None when exhausted) and preload the following one."""
        inputs = self.next_inputs
        self.preload()
        return inputs