add data module

parent 46a0d5c27d
commit 77c47b4e53

.gitignore
@@ -111,10 +111,10 @@ venv.bak/
 .mypy_cache/
 
 # Custom
-data
+reid-data
 log
 saved-models
-debug.py
+debug
 .idea
 
 # Cython eval code

@@ -0,0 +1,18 @@
from torchreid.data import Dataset
from torchreid.data.datasets import *

dataset1 = Market1501(root='data', combineall=False, transform='dummy', verbose=False)
dataset2 = Market1501(root='data', combineall=False, transform='dummy', verbose=False)
dataset3 = dataset1 + dataset2
print(type(dataset3))
print(dataset3)

print('** After addition **')
print(type(dataset1))
print(dataset1)

print(type(dataset2))
print(dataset2)

print(type(dataset3))
print(dataset3)

main.py
@@ -1,7 +1,7 @@
 import torchreid
 
 datamanager = torchreid.data.ImageDataManager(
-    root='data',
+    root='reid-data',
     sources=['market1501', 'cuhk03'],
     targets='market1501',
     height=128,

torchreid/data/__init__.py
@@ -0,0 +1,5 @@
from __future__ import absolute_import
from __future__ import print_function

from .datasets import Dataset, ImageDataset, VideoDataset
from .datamanager import ImageDataManager, VideoDataManager

torchreid/data/datamanager.py
@@ -0,0 +1,266 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from torchreid.data.sampler import build_train_sampler
from torchreid.data.transforms import build_transforms
from torchreid.data.datasets import init_image_dataset, init_video_dataset


class DataManager(object):

    def __init__(self, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False):
        self.sources = sources
        self.targets = targets

        if isinstance(self.sources, str):
            self.sources = [self.sources]

        if isinstance(self.targets, str):
            self.targets = [self.targets]

        if self.targets is None:
            self.targets = self.sources

        self.transform_tr, self.transform_te = build_transforms(
            height, width,
            random_erase=random_erase,
            color_jitter=color_jitter,
            color_aug=color_aug
        )

        self.use_gpu = (torch.cuda.is_available() and not use_cpu)

    @property
    def num_train_pids(self):
        return self._num_train_pids

    @property
    def num_train_cams(self):
        return self._num_train_cams

    def return_dataloaders(self):
        return self.trainloader, self.testloader

    def return_testdataset_by_name(self, name):
        return self.testdataset[name]['query'], self.testdataset[name]['gallery']


class ImageDataManager(DataManager):

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False):

        super(ImageDataManager, self).__init__(sources, targets=targets, height=height, width=width,
                                               random_erase=random_erase, color_jitter=color_jitter,
                                               color_aug=color_aug, use_cpu=use_cpu)

        print('=> Loading train (source) dataset')
        trainset = []
        for name in self.sources:
            trainset_ = init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            trainset.append(trainset_)
        trainset = sum(trainset)

        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams

        train_sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            sampler=train_sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                queryset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # build gallery loader
            galleryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                verbose=False,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                galleryset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train            : {}'.format(self.sources))
        print('  # train datasets : {}'.format(len(self.sources)))
        print('  # train ids      : {}'.format(self.num_train_pids))
        print('  # train images   : {}'.format(len(trainset)))
        print('  # train cameras  : {}'.format(self.num_train_cams))
        print('  test             : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')


class VideoDataManager(DataManager):

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 seq_len=15, sample_method='evenly'):

        super(VideoDataManager, self).__init__(sources, targets=targets, height=height, width=width,
                                               random_erase=random_erase, color_jitter=color_jitter,
                                               color_aug=color_aug, use_cpu=use_cpu)

        print('=> Loading train (source) dataset')
        trainset = []
        for name in self.sources:
            trainset_ = init_video_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                root=root,
                split_id=split_id,
                seq_len=seq_len,
                sample_method=sample_method
            )
            trainset.append(trainset_)
        trainset = sum(trainset)

        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams

        train_sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            sampler=train_sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_video_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                root=root,
                split_id=split_id,
                seq_len=seq_len,
                sample_method=sample_method
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                queryset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # build gallery loader
            galleryset = init_video_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                verbose=False,
                root=root,
                split_id=split_id,
                seq_len=seq_len,
                sample_method=sample_method
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                galleryset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train             : {}'.format(self.sources))
        print('  # train datasets  : {}'.format(len(self.sources)))
        print('  # train ids       : {}'.format(self.num_train_pids))
        print('  # train tracklets : {}'.format(len(trainset)))
        print('  # train cameras   : {}'.format(self.num_train_cams))
        print('  test              : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')
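
Note: return_dataloaders() hands back plain torch.utils.data.DataLoader
objects, so training code can iterate over them directly. A minimal sketch
(the datamanager arguments follow main.py above; the loop body is a
placeholder, not part of this commit):

    import torchreid

    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources=['market1501'],
        height=128,
    )
    trainloader, testloader = datamanager.return_dataloaders()

    # each batch matches ImageDataset.__getitem__ in dataset.py below:
    # (imgs, pids, camids, img_paths)
    for imgs, pids, camids, img_paths in trainloader:
        print(imgs.shape, pids.shape)
        break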

torchreid/data/datasets/__init__.py
@@ -0,0 +1,43 @@
from __future__ import absolute_import
from __future__ import print_function

from .dataset import Dataset, ImageDataset, VideoDataset
from .image import *
from .video import *


__image_datasets = {
    'market1501': Market1501,
    'cuhk03': CUHK03,
    'dukemtmcreid': DukeMTMCreID,
    'msmt17': MSMT17,
    'viper': VIPeR,
    'grid': GRID,
    'cuhk01': CUHK01,
    'prid450s': PRID450S,
    'ilids': iLIDS,
    'sensereid': SenseReID,
    'prid': PRID
}


__video_datasets = {
    'mars': Mars,
    'ilidsvid': iLIDSVID,
    'prid2011': PRID2011,
    'dukemtmcvidreid': DukeMTMCVidReID
}


def init_image_dataset(name, **kwargs):
    avai_datasets = list(__image_datasets.keys())
    if name not in avai_datasets:
        raise KeyError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
    return __image_datasets[name](**kwargs)


def init_video_dataset(name, **kwargs):
    avai_datasets = list(__video_datasets.keys())
    if name not in avai_datasets:
        raise KeyError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
    return __video_datasets[name](**kwargs)
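
Note: __image_datasets and __video_datasets form a simple string-to-class
registry, which is how the data managers turn the names in sources/targets
into dataset instances. A minimal sketch of calling the factory directly
(the root path is illustrative):

    from torchreid.data.datasets import init_image_dataset

    # equivalent to Market1501(root='reid-data', mode='query')
    queryset = init_image_dataset('market1501', root='reid-data', mode='query')

    # an unknown name raises KeyError and lists the registered datasets
    try:
        init_image_dataset('market1502')
    except KeyError as e:
        print(e)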

torchreid/data/datasets/dataset.py
@@ -0,0 +1,290 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import numpy as np
import tarfile
import zipfile
import copy

import torch

from torchreid.utils import read_image, mkdir_if_missing, download_url


class Dataset(object):

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

        if self.verbose:
            self.show_summary()

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __add__(self, other):
        """Adds two datasets together (only the train set)."""
        train = copy.deepcopy(self.train)

        for img_path, pid, camid in other.train:
            pid += self.num_train_pids
            camid += self.num_train_cams
            train.append((img_path, pid, camid))

        return Dataset(train, self.query, self.gallery, transform=self.transform,
                       mode=self.mode, combineall=self.combineall, verbose=self.verbose)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3])."""
        if other == 0:
            return self
        else:
            return self.__add__(other)

    def parse_data(self, data):
        """Parses data.

        :param data: data list containing tuples of (img_path(s), pid, camid)
        :type data: list

        :return num_pids: number of person identities
        :rtype num_pids: int
        :return num_cams: number of cameras
        :rtype num_cams: int
        """
        pids = set()
        cams = set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(cams)

    def get_num_pids(self, data):
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        return self.parse_data(data)[1]

    def show_summary(self):
        pass

    def combine_all(self):
        """Combines train, query and gallery."""
        combined = copy.deepcopy(self.train)

        # relabel pids in gallery
        g_pids = set()
        for _, pid, _ in self.gallery:
            if pid == 0 or pid == -1:
                continue
            g_pids.add(pid)
        pid2label = {pid: i for i, pid in enumerate(g_pids)}

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid == 0 or pid == -1:
                    continue
                pid = pid2label[pid] + self.num_train_pids
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def download_dataset(self, dataset_dir, dataset_url):
        """Downloads and extracts dataset.

        :param dataset_dir: dataset directory
        :type dataset_dir: str
        :param dataset_url: url to download dataset
        :type dataset_url: str
        """
        if osp.exists(dataset_dir):
            return

        if dataset_url is None:
            raise RuntimeError('{} dataset needs to be manually '
                               'prepared, please follow the '
                               'document to prepare this dataset'.format(self.__class__.__name__))

        print('Creating directory "{}"'.format(dataset_dir))
        mkdir_if_missing(dataset_dir)
        fpath = osp.join(dataset_dir, osp.basename(dataset_url))

        print('Downloading {} dataset to "{}"'.format(self.__class__.__name__, dataset_dir))
        download_url(dataset_url, fpath)

        print('Extracting "{}"'.format(fpath))
        extension = osp.basename(fpath).split('.')[-1]
        try:
            tar = tarfile.open(fpath)
            tar.extractall(path=dataset_dir)
            tar.close()
        except:
            zip_ref = zipfile.ZipFile(fpath, 'r')
            zip_ref.extractall(dataset_dir)
            zip_ref.close()

        print('{} dataset is ready'.format(self.__class__.__name__))

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        :param required_files: string name(s) of file(s)
        :type required_files: str or list
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not osp.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def __repr__(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        msg = '  ----------------------------------------\n' \
              '  subset   | # ids | # items | # cameras\n' \
              '  ----------------------------------------\n' \
              '  train    | {:5d} | {:7d} | {:9d}\n' \
              '  query    | {:5d} | {:7d} | {:9d}\n' \
              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
              '  ----------------------------------------\n' \
              '  items: images/tracklets for image/video dataset\n'.format(
                  num_train_pids, len(self.train), num_train_cams,
                  num_query_pids, len(self.query), num_query_cams,
                  num_gallery_pids, len(self.gallery), num_gallery_cams
              )

        return msg


class ImageDataset(Dataset):

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def __getitem__(self, index):
        img_path, pid, camid = self.data[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid, img_path

    def show_summary(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print('  ----------------------------------------')
        print('  subset   | # ids | # images | # cameras')
        print('  ----------------------------------------')
        print('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
        print('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
        print('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
        print('  ----------------------------------------')


class VideoDataset(Dataset):

    def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
        super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
        self.seq_len = seq_len
        self.sample_method = sample_method

        if self.transform is None:
            raise RuntimeError('transform must not be None')

    def __getitem__(self, index):
        img_paths, pid, camid = self.data[index]
        num_imgs = len(img_paths)

        if self.sample_method == 'random':
            # Randomly samples seq_len images from a tracklet of length num_imgs;
            # if num_imgs is smaller than seq_len, then replicates images
            indices = np.arange(num_imgs)
            replace = False if num_imgs >= self.seq_len else True
            indices = np.random.choice(indices, size=self.seq_len, replace=replace)
            # sort indices to keep temporal order (comment it out to be order-agnostic)
            indices = np.sort(indices)

        elif self.sample_method == 'evenly':
            # Evenly samples seq_len images from a tracklet
            if num_imgs >= self.seq_len:
                num_imgs -= num_imgs % self.seq_len
                indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
            else:
                # if num_imgs is smaller than seq_len, simply replicate the last image
                # until the seq_len requirement is satisfied
                indices = np.arange(0, num_imgs)
                num_pads = self.seq_len - num_imgs
                indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
            assert len(indices) == self.seq_len

        elif self.sample_method == 'all':
            # Samples all images in a tracklet; batch_size must be set to 1
            indices = np.arange(num_imgs)

        else:
            raise ValueError('Unknown sample method: {}'.format(self.sample_method))

        imgs = []
        for index in indices:
            img_path = img_paths[int(index)]
            img = read_image(img_path)
            if self.transform is not None:
                img = self.transform(img)
            img = img.unsqueeze(0)  # img must be torch.Tensor
            imgs.append(img)
        imgs = torch.cat(imgs, dim=0)

        return imgs, pid, camid

    def show_summary(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print('  -------------------------------------------')
        print('  subset   | # ids | # tracklets | # cameras')
        print('  -------------------------------------------')
        print('  train    | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
        print('  query    | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
        print('  gallery  | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
        print('  -------------------------------------------')
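
Note: __radd__ returning self when other == 0 is what makes
trainset = sum(trainset) work in the data managers, since sum() starts its
accumulation from 0. __add__ offsets the incoming pids and camids by the
accumulated num_train_pids/num_train_cams, so identities from different
source datasets never collide (each dataset relabels its train pids to
0..N-1). A minimal sketch, reusing the datasets from the debug script above:

    merged = sum([dataset1, dataset2, dataset3])
    # evaluates as ((0 + dataset1) + dataset2) + dataset3
    assert merged.num_train_pids == (dataset1.num_train_pids
                                     + dataset2.num_train_pids
                                     + dataset3.num_train_pids)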

torchreid/data/datasets/image/__init__.py
@@ -0,0 +1,14 @@
from __future__ import absolute_import
from __future__ import print_function

from .market1501 import Market1501
from .dukemtmcreid import DukeMTMCreID
from .cuhk03 import CUHK03
from .msmt17 import MSMT17
from .viper import VIPeR
from .grid import GRID
from .cuhk01 import CUHK01
from .prid450s import PRID450S
from .ilids import iLIDS
from .sensereid import SenseReID
from .prid import PRID

torchreid/data/datasets/image/cuhk01.py
@@ -0,0 +1,135 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import zipfile
import numpy as np

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class CUHK01(ImageDataset):
    """CUHK01

    Reference:
        Li et al. Human Reidentification with Transferred Metric Learning. ACCV 2012.

    URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html

    Dataset statistics:
        identities: 971
        images: 3884
        cameras: 4
    """
    dataset_dir = 'cuhk01'
    dataset_url = None

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.zip_path = osp.join(self.dataset_dir, 'CUHK01.zip')
        self.campus_dir = osp.join(self.dataset_dir, 'campus')
        self.split_path = osp.join(self.dataset_dir, 'splits.json')

        self.extract_file()

        required_files = [
            self.dataset_dir,
            self.campus_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits) - 1))
        split = splits[split_id]

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        train = [tuple(item) for item in train]
        query = [tuple(item) for item in query]
        gallery = [tuple(item) for item in gallery]

        super(CUHK01, self).__init__(train, query, gallery, **kwargs)

    def extract_file(self):
        if not osp.exists(self.campus_dir):
            print('Extracting files')
            zip_ref = zipfile.ZipFile(self.zip_path, 'r')
            zip_ref.extractall(self.dataset_dir)
            zip_ref.close()

    def prepare_split(self):
        """
        Image name format: 0001001.png, where the first four digits represent
        the identity and the last three digits represent the camera.
        Cameras 1&2 are considered the same view and cameras 3&4 are
        considered the same view.
        """
        if not osp.exists(self.split_path):
            print('Creating 10 random splits of train ids and test ids')
            img_paths = sorted(glob.glob(osp.join(self.campus_dir, '*.png')))
            img_list = []
            pid_container = set()
            for img_path in img_paths:
                img_name = osp.basename(img_path)
                pid = int(img_name[:4]) - 1
                camid = (int(img_name[4:7]) - 1) // 2  # result is either 0 or 1
                img_list.append((img_path, pid, camid))
                pid_container.add(pid)

            num_pids = len(pid_container)
            num_train_pids = num_pids // 2

            splits = []
            for _ in range(10):
                order = np.arange(num_pids)
                np.random.shuffle(order)
                train_idxs = order[:num_train_pids]
                train_idxs = np.sort(train_idxs)
                idx2label = {idx: label for label, idx in enumerate(train_idxs)}

                train, test_a, test_b = [], [], []
                for img_path, pid, camid in img_list:
                    if pid in train_idxs:
                        train.append((img_path, idx2label[pid], camid))
                    else:
                        if camid == 0:
                            test_a.append((img_path, pid, camid))
                        else:
                            test_b.append((img_path, pid, camid))

                # use cameraA as query and cameraB as gallery
                split = {
                    'train': train,
                    'query': test_a,
                    'gallery': test_b,
                    'num_train_pids': num_train_pids,
                    'num_query_pids': num_pids - num_train_pids,
                    'num_gallery_pids': num_pids - num_train_pids
                }
                splits.append(split)

                # use cameraB as query and cameraA as gallery
                split = {
                    'train': train,
                    'query': test_b,
                    'gallery': test_a,
                    'num_train_pids': num_train_pids,
                    'num_query_pids': num_pids - num_train_pids,
                    'num_gallery_pids': num_pids - num_train_pids
                }
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file saved to {}'.format(self.split_path))

torchreid/data/datasets/image/cuhk03.py
@@ -0,0 +1,259 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp

from torchreid.data.datasets import ImageDataset
from torchreid.utils import mkdir_if_missing, read_json, write_json


class CUHK03(ImageDataset):
    """CUHK03

    Reference:
        Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.

    URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!

    Dataset statistics:
        identities: 1360
        images: 13164
        cameras: 6
        splits: 20 (classic)
    """
    dataset_dir = 'cuhk03'
    dataset_url = None

    def __init__(self, root='', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
        self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')

        self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
        self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')

        self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
        self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')

        self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
        self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')

        self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
        self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')

        required_files = [
            self.dataset_dir,
            self.data_dir,
            self.raw_mat_path,
            self.split_new_det_mat_path,
            self.split_new_lab_mat_path
        ]
        self.check_before_run(required_files)

        self.preprocess_split()

        if cuhk03_labeled:
            split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
        else:
            split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path

        splits = read_json(split_path)
        assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id, len(splits))
        split = splits[split_id]

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        super(CUHK03, self).__init__(train, query, gallery, **kwargs)

    def preprocess_split(self):
        """
        This function is a bit complex and ugly. What it does is:
        1. extract data from cuhk-03.mat and save them as png images
        2. create 20 classic splits (Li et al. CVPR'14)
        3. create the new split (Zhong et al. CVPR'17)
        """
        if osp.exists(self.imgs_labeled_dir) \
                and osp.exists(self.imgs_detected_dir) \
                and osp.exists(self.split_classic_det_json_path) \
                and osp.exists(self.split_classic_lab_json_path) \
                and osp.exists(self.split_new_det_json_path) \
                and osp.exists(self.split_new_lab_json_path):
            return

        import h5py
        from scipy.misc import imsave
        from scipy.io import loadmat

        mkdir_if_missing(self.imgs_detected_dir)
        mkdir_if_missing(self.imgs_labeled_dir)

        print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
        mat = h5py.File(self.raw_mat_path, 'r')

        def _deref(ref):
            return mat[ref][:].T

        def _process_images(img_refs, campid, pid, save_dir):
            img_paths = []  # Note: some persons only have images for one view
            for imgid, img_ref in enumerate(img_refs):
                img = _deref(img_ref)
                if img.size == 0 or img.ndim < 3:
                    continue  # skip empty cell
                # images are saved with the following format, index-1 (ensure uniqueness)
                # campid: index of camera pair (1-5)
                # pid: index of person in 'campid'-th camera pair
                # viewid: index of view, {1, 2}
                # imgid: index of image, (1-10)
                viewid = 1 if imgid < 5 else 2
                img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
                img_path = osp.join(save_dir, img_name)
                if not osp.isfile(img_path):
                    imsave(img_path, img)
                img_paths.append(img_path)
            return img_paths

        def _extract_img(image_type):
            print('Processing {} images ...'.format(image_type))
            meta_data = []
            imgs_dir = self.imgs_detected_dir if image_type == 'detected' else self.imgs_labeled_dir
            for campid, camp_ref in enumerate(mat[image_type][0]):
                camp = _deref(camp_ref)
                num_pids = camp.shape[0]
                for pid in range(num_pids):
                    img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
                    assert len(img_paths) > 0, 'campid{}-pid{} has no images'.format(campid, pid)
                    meta_data.append((campid + 1, pid + 1, img_paths))
                print('- done camera pair {} with {} identities'.format(campid + 1, num_pids))
            return meta_data

        meta_detected = _extract_img('detected')
        meta_labeled = _extract_img('labeled')

        def _extract_classic_split(meta_data, test_split):
            train, test = [], []
            num_train_pids, num_test_pids = 0, 0
            num_train_imgs, num_test_imgs = 0, 0
            for i, (campid, pid, img_paths) in enumerate(meta_data):
                if [campid, pid] in test_split:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        test.append((img_path, num_test_pids, camid))
                    num_test_pids += 1
                    num_test_imgs += len(img_paths)
                else:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        train.append((img_path, num_train_pids, camid))
                    num_train_pids += 1
                    num_train_imgs += len(img_paths)
            return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs

        print('Creating classic splits (# = 20) ...')
        splits_classic_det, splits_classic_lab = [], []
        for split_ref in mat['testsets'][0]:
            test_split = _deref(split_ref).tolist()

            # create split for detected images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_detected, test_split)
            splits_classic_det.append({
                'train': train,
                'query': test,
                'gallery': test,
                'num_train_pids': num_train_pids,
                'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids,
                'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids,
                'num_gallery_imgs': num_test_imgs
            })

            # create split for labeled images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_labeled, test_split)
            splits_classic_lab.append({
                'train': train,
                'query': test,
                'gallery': test,
                'num_train_pids': num_train_pids,
                'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids,
                'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids,
                'num_gallery_imgs': num_test_imgs
            })

        write_json(splits_classic_det, self.split_classic_det_json_path)
        write_json(splits_classic_lab, self.split_classic_lab_json_path)

        def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
            tmp_set = []
            unique_pids = set()
            for idx in idxs:
                img_name = filelist[idx][0]
                camid = int(img_name.split('_')[2]) - 1  # make it 0-based
                pid = pids[idx]
                if relabel:
                    pid = pid2label[pid]
                img_path = osp.join(img_dir, img_name)
                tmp_set.append((img_path, int(pid), camid))
                unique_pids.add(pid)
            return tmp_set, len(unique_pids), len(idxs)

        def _extract_new_split(split_dict, img_dir):
            train_idxs = split_dict['train_idx'].flatten() - 1  # index-0
            pids = split_dict['labels'].flatten()
            train_pids = set(pids[train_idxs])
            pid2label = {pid: label for label, pid in enumerate(train_pids)}
            query_idxs = split_dict['query_idx'].flatten() - 1
            gallery_idxs = split_dict['gallery_idx'].flatten() - 1
            filelist = split_dict['filelist'].flatten()
            train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
            query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
            gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
            return train_info, query_info, gallery_info

        print('Creating new split for detected images (767/700) ...')
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_det_mat_path),
            self.imgs_detected_dir
        )
        split = [{
            'train': train_info[0],
            'query': query_info[0],
            'gallery': gallery_info[0],
            'num_train_pids': train_info[1],
            'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1],
            'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1],
            'num_gallery_imgs': gallery_info[2]
        }]
        write_json(split, self.split_new_det_json_path)

        print('Creating new split for labeled images (767/700) ...')
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_lab_mat_path),
            self.imgs_labeled_dir
        )
        split = [{
            'train': train_info[0],
            'query': query_info[0],
            'gallery': gallery_info[0],
            'num_train_pids': train_info[1],
            'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1],
            'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1],
            'num_gallery_imgs': gallery_info[2]
        }]
        write_json(split, self.split_new_lab_json_path)

torchreid/data/datasets/image/dukemtmcreid.py
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import re

from torchreid.data.datasets import ImageDataset


class DukeMTMCreID(ImageDataset):
    """DukeMTMC-reID

    Reference:
        1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
        2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.

    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
        identities: 1404 (train + query)
        images: 16522 (train) + 2228 (query) + 17661 (gallery)
        cameras: 8
    """
    dataset_dir = 'dukemtmc-reid'
    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'

    def __init__(self, root='', **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)

        super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data

torchreid/data/datasets/image/grid.py
@@ -0,0 +1,114 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
from scipy.io import loadmat

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class GRID(ImageDataset):
    """GRID

    Reference:
        Loy et al. Multi-camera activity correlation analysis. CVPR 2009.

    URL: http://personal.ie.cuhk.edu.hk/~ccloy/downloads_qmul_underground_reid.html

    Dataset statistics:
        identities: 250
        images: 1275
        cameras: 8
    """
    dataset_dir = 'grid'
    dataset_url = 'http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/underground_reid.zip'

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.probe_path = osp.join(self.dataset_dir, 'underground_reid', 'probe')
        self.gallery_path = osp.join(self.dataset_dir, 'underground_reid', 'gallery')
        self.split_mat_path = osp.join(self.dataset_dir, 'underground_reid', 'features_and_partitions.mat')
        self.split_path = osp.join(self.dataset_dir, 'splits.json')

        required_files = [
            self.dataset_dir,
            self.probe_path,
            self.gallery_path,
            self.split_mat_path
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, '
                             'but expected between 0 and {}'.format(split_id, len(splits) - 1))
        split = splits[split_id]

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        train = [tuple(item) for item in train]
        query = [tuple(item) for item in query]
        gallery = [tuple(item) for item in gallery]

        super(GRID, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            print('Creating 10 random splits')
            split_mat = loadmat(self.split_mat_path)
            trainIdxAll = split_mat['trainIdxAll'][0]  # length = 10
            probe_img_paths = sorted(glob.glob(osp.join(self.probe_path, '*.jpeg')))
            gallery_img_paths = sorted(glob.glob(osp.join(self.gallery_path, '*.jpeg')))

            splits = []
            for split_idx in range(10):
                train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist()
                assert len(train_idxs) == 125
                idx2label = {idx: label for label, idx in enumerate(train_idxs)}

                train, query, gallery = [], [], []

                # process probe folder
                for img_path in probe_img_paths:
                    img_name = osp.basename(img_path)
                    img_idx = int(img_name.split('_')[0])
                    camid = int(img_name.split('_')[1]) - 1  # index starts from 0
                    if img_idx in train_idxs:
                        train.append((img_path, idx2label[img_idx], camid))
                    else:
                        query.append((img_path, img_idx, camid))

                # process gallery folder
                for img_path in gallery_img_paths:
                    img_name = osp.basename(img_path)
                    img_idx = int(img_name.split('_')[0])
                    camid = int(img_name.split('_')[1]) - 1  # index starts from 0
                    if img_idx in train_idxs:
                        train.append((img_path, idx2label[img_idx], camid))
                    else:
                        gallery.append((img_path, img_idx, camid))

                split = {
                    'train': train,
                    'query': query,
                    'gallery': gallery,
                    'num_train_pids': 125,
                    'num_query_pids': 125,
                    'num_gallery_pids': 900
                }
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file saved to {}'.format(self.split_path))

torchreid/data/datasets/image/ilids.py
@@ -0,0 +1,142 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import numpy as np
import copy
import random
from collections import defaultdict

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class iLIDS(ImageDataset):
    """QMUL-iLIDS

    Reference:
        Zheng et al. Associating Groups of People. BMVC 2009.

    Dataset statistics:
        identities: 119
        images: 476
        cameras: 8 (not explicitly provided)
    """
    dataset_dir = 'ilids'
    dataset_url = 'http://www.eecs.qmul.ac.uk/~jason/data/i-LIDS_Pedestrian.tgz'

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.data_dir = osp.join(self.dataset_dir, 'i-LIDS_Pedestrian/Persons')
        self.split_path = osp.join(self.dataset_dir, 'splits.json')

        required_files = [
            self.dataset_dir,
            self.data_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but '
                             'expected between 0 and {}'.format(split_id, len(splits) - 1))
        split = splits[split_id]

        train, query, gallery = self.process_split(split)

        super(iLIDS, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            print('Creating splits ...')

            paths = glob.glob(osp.join(self.data_dir, '*.jpg'))
            img_names = [osp.basename(path) for path in paths]
            num_imgs = len(img_names)
            assert num_imgs == 476, 'There should be 476 images, but ' \
                                    'got {}, please check the data'.format(num_imgs)

            # store image names
            # image naming format:
            #   the first four digits denote the person ID
            #   the last four digits denote the sequence index
            pid_dict = defaultdict(list)
            for img_name in img_names:
                pid = int(img_name[:4])
                pid_dict[pid].append(img_name)
            pids = list(pid_dict.keys())
            num_pids = len(pids)
            assert num_pids == 119, 'There should be 119 identities, ' \
                                    'but got {}, please check the data'.format(num_pids)

            num_train_pids = int(num_pids * 0.5)
            num_test_pids = num_pids - num_train_pids  # supposed to be 60

            splits = []
            for _ in range(10):
                # randomly choose num_train_pids train IDs and num_test_pids test IDs
                pids_copy = copy.deepcopy(pids)
                random.shuffle(pids_copy)
                train_pids = pids_copy[:num_train_pids]
                test_pids = pids_copy[num_train_pids:]

                train = []
                query = []
                gallery = []

                # for train IDs, all images are used in the train set
                for pid in train_pids:
                    img_names = pid_dict[pid]
                    train.extend(img_names)

                # for each test ID, randomly choose two images, one for
                # query and the other one for gallery
                for pid in test_pids:
                    img_names = pid_dict[pid]
                    samples = random.sample(img_names, 2)
                    query.append(samples[0])
                    gallery.append(samples[1])

                split = {'train': train, 'query': query, 'gallery': gallery}
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file is saved to {}'.format(self.split_path))

    def get_pid2label(self, img_names):
        pid_container = set()
        for img_name in img_names:
            pid = int(img_name[:4])
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        return pid2label

    def parse_img_names(self, img_names, pid2label=None):
        data = []

        for img_name in img_names:
            pid = int(img_name[:4])
            if pid2label is not None:
                pid = pid2label[pid]
            camid = int(img_name[4:7]) - 1  # 0-based
            img_path = osp.join(self.data_dir, img_name)
            data.append((img_path, pid, camid))

        return data

    def process_split(self, split):
        train, query, gallery = [], [], []
        train_pid2label = self.get_pid2label(split['train'])
        train = self.parse_img_names(split['train'], train_pid2label)
        query = self.parse_img_names(split['query'])
        gallery = self.parse_img_names(split['gallery'])
        return train, query, gallery

torchreid/data/datasets/image/market1501.py
@@ -0,0 +1,88 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import re

from torchreid.data.datasets import ImageDataset


class Market1501(ImageDataset):
    """Market1501

    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
        identities: 1501 (+1 for background)
        images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'market1501'
    dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'

    def __init__(self, root='', market1501_500k=False, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        # allow alternative directory structure
        self.data_dir = self.dataset_dir
        data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
        if osp.isdir(data_dir):
            self.data_dir = data_dir

        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
        self.extra_gallery_dir = osp.join(self.data_dir, 'images')
        self.market1501_500k = market1501_500k

        required_files = [
            self.data_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        if self.market1501_500k:
            required_files.append(self.extra_gallery_dir)
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)
        if self.market1501_500k:
            gallery += self.process_dir(self.extra_gallery_dir, relabel=False)

        super(Market1501, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data
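
Note: process_dir parses (pid, camid) out of Market-1501 file names with the
regex ([-\d]+)_c(\d); pid -1 marks junk images, which are skipped, and pid 0
is the background class. A quick check of the parsing (the file name below
is illustrative):

    import re

    pattern = re.compile(r'([-\d]+)_c(\d)')
    pid, camid = map(int, pattern.search('0002_c1s1_000451_03.jpg').groups())
    print(pid, camid)  # -> 2 1; stored as (img_path, 2, 0) after camid -= 1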

torchreid/data/datasets/image/msmt17.py
@@ -0,0 +1,98 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp

from torchreid.data.datasets import ImageDataset


# To adapt to different versions
# Log:
# 22.01.2019: v1 and v2 only differ in dir names
TRAIN_DIR_KEY = 'train_dir'
TEST_DIR_KEY = 'test_dir'
VERSION_DICT = {
    'MSMT17_V1': {
        TRAIN_DIR_KEY: 'train',
        TEST_DIR_KEY: 'test',
    },
    'MSMT17_V2': {
        TRAIN_DIR_KEY: 'mask_train_v2',
        TEST_DIR_KEY: 'mask_test_v2',
    }
}


class MSMT17(ImageDataset):
    """MSMT17

    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.

    URL: http://www.pkuvmc.com/publications/msmt17.html

    Dataset statistics:
        identities: 4101
        images: 32621 (train) + 11659 (query) + 82161 (gallery)
        cameras: 15
    """
    dataset_dir = 'msmt17'
    dataset_url = None

    def __init__(self, root='', **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        has_main_dir = False
        for main_dir in VERSION_DICT:
            if osp.exists(osp.join(self.dataset_dir, main_dir)):
                train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
                test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
                has_main_dir = True
                break
        assert has_main_dir, 'Dataset folder not found'

        self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
        self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
        self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.test_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, self.list_train_path)
        val = self.process_dir(self.train_dir, self.list_val_path)
        query = self.process_dir(self.test_dir, self.list_query_path)
        gallery = self.process_dir(self.test_dir, self.list_gallery_path)

        # Note: to fairly compare with published methods on the conventional ReID setting,
        # do not add val images to the training set
        if 'combineall' in kwargs and kwargs['combineall']:
            train += val

        super(MSMT17, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, list_path):
        with open(list_path, 'r') as txt:
            lines = txt.readlines()

        data = []

        for img_idx, img_info in enumerate(lines):
            img_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2]) - 1  # index starts from 0
            img_path = osp.join(dir_path, img_path)
            data.append((img_path, pid, camid))

        return data

torchreid/data/datasets/image/prid.py
@@ -0,0 +1,106 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import random

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class PRID(ImageDataset):
    """PRID (single-shot version of prid-2011)

    Reference:
        Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011.

    URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/

    Dataset statistics:
        Two views
        View A captures 385 identities
        View B captures 749 identities
        200 identities appear in both views
    """
    dataset_dir = 'prid2011'
    dataset_url = None

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_a')
        self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_b')
        self.split_path = osp.join(self.dataset_dir, 'splits_single_shot.json')

        required_files = [
            self.dataset_dir,
            self.cam_a_dir,
            self.cam_b_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits) - 1))
        split = splits[split_id]

        train, query, gallery = self.process_split(split)

        super(PRID, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            print('Creating splits ...')

            splits = []
            for _ in range(10):
                # randomly sample 100 IDs for train and use the remaining 100 IDs for test
                # (note: there are only 200 IDs appearing in both views)
                pids = [i for i in range(1, 201)]
                train_pids = random.sample(pids, 100)
                train_pids.sort()
                test_pids = [i for i in pids if i not in train_pids]
                split = {'train': train_pids, 'test': test_pids}
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file is saved to {}'.format(self.split_path))

    def process_split(self, split):
        train, query, gallery = [], [], []
        train_pids = split['train']
        test_pids = split['test']

        train_pid2label = {pid: label for label, pid in enumerate(train_pids)}

        # train
        train = []
        for pid in train_pids:
            img_name = 'person_' + str(pid).zfill(4) + '.png'
            pid = train_pid2label[pid]
            img_a_path = osp.join(self.cam_a_dir, img_name)
            train.append((img_a_path, pid, 0))
            img_b_path = osp.join(self.cam_b_dir, img_name)
            train.append((img_b_path, pid, 1))

        # query and gallery
        query, gallery = [], []
        for pid in test_pids:
            img_name = 'person_' + str(pid).zfill(4) + '.png'
            img_a_path = osp.join(self.cam_a_dir, img_name)
            query.append((img_a_path, pid, 0))
            img_b_path = osp.join(self.cam_b_dir, img_name)
            gallery.append((img_b_path, pid, 1))
        for pid in range(201, 750):
            img_name = 'person_' + str(pid).zfill(4) + '.png'
            img_b_path = osp.join(self.cam_b_dir, img_name)
            gallery.append((img_b_path, pid, 1))

        return train, query, gallery
|
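A standalone sketch of the split sampling done in prepare_split above: 100 of the 200 shared IDs go to train, the remainder to test, with no overlap.

# Sketch mirroring prepare_split's sampling logic (not part of the commit).
import random

pids = list(range(1, 201))
train_pids = sorted(random.sample(pids, 100))
test_pids = [i for i in pids if i not in train_pids]
assert len(train_pids) == len(test_pids) == 100
assert not set(train_pids) & set(test_pids)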
@@ -0,0 +1,111 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import numpy as np

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class PRID450S(ImageDataset):
    """PRID450S.

    Reference:
        Roth et al. Mahalanobis Distance Learning for Person Re-Identification. PR 2014.

    URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/prid450s/

    Dataset statistics:
        identities: 450
        images: 900
        cameras: 2
    """
    dataset_dir = 'prid450s'
    dataset_url = 'https://files.icg.tugraz.at/f/8c709245bb/?raw=1'

    def __init__(self, root='', split_id=0, min_seq_len=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.split_path = osp.join(self.dataset_dir, 'splits.json')
        self.cam_a_dir = osp.join(self.dataset_dir, 'cam_a')
        self.cam_b_dir = osp.join(self.dataset_dir, 'cam_b')

        required_files = [
            self.dataset_dir,
            self.cam_a_dir,
            self.cam_b_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1))
        split = splits[split_id]

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        train = [tuple(item) for item in train]
        query = [tuple(item) for item in query]
        gallery = [tuple(item) for item in gallery]

        super(PRID450S, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, 'img_*.png')))
            cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, 'img_*.png')))
            assert len(cam_a_imgs) == len(cam_b_imgs)

            num_pids = len(cam_a_imgs)
            num_train_pids = num_pids // 2

            splits = []
            for _ in range(10):
                order = np.arange(num_pids)
                np.random.shuffle(order)
                train_idxs = np.sort(order[:num_train_pids])
                idx2label = {idx: label for label, idx in enumerate(train_idxs)}

                train, test = [], []

                # process camera a
                for img_path in cam_a_imgs:
                    img_name = osp.basename(img_path)
                    img_idx = int(img_name.split('_')[1].split('.')[0])
                    if img_idx in train_idxs:
                        train.append((img_path, idx2label[img_idx], 0))
                    else:
                        test.append((img_path, img_idx, 0))

                # process camera b
                for img_path in cam_b_imgs:
                    img_name = osp.basename(img_path)
                    img_idx = int(img_name.split('_')[1].split('.')[0])
                    if img_idx in train_idxs:
                        train.append((img_path, idx2label[img_idx], 1))
                    else:
                        test.append((img_path, img_idx, 1))

                split = {
                    'train': train,
                    'query': test,
                    'gallery': test,
                    'num_train_pids': num_train_pids,
                    'num_query_pids': num_pids - num_train_pids,
                    'num_gallery_pids': num_pids - num_train_pids
                }
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file saved to {}'.format(self.split_path))
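Note that split entries are serialized to splits.json as JSON lists, which is why __init__ above restores them to tuples on load. A tiny sketch with illustrative values:

# Sketch (not part of the commit): JSON stores tuples as lists, so they are
# converted back on load. The values below are illustrative only.
item = ['cam_a/img_0001.png', 0, 0]   # as stored in splits.json
img_path, pid, camid = tuple(item)
print(img_path, pid, camid)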
@@ -0,0 +1,73 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import copy

from torchreid.data.datasets import ImageDataset


class SenseReID(ImageDataset):
    """SenseReID.

    This dataset is used for test purposes only.

    Reference:
        Zhao et al. Spindle Net: Person Re-identification with Human Body
        Region Guided Feature Decomposition and Fusion. CVPR 2017.

    URL: https://drive.google.com/file/d/0B56OfSrVI8hubVJLTzkwV2VaOWM/view

    Dataset statistics:
        train: 0 ids, 0 images
        query: 522 ids, 1040 images
        gallery: 1717 ids, 3388 images
    """
    dataset_dir = 'sensereid'
    dataset_url = None

    def __init__(self, root='', **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.query_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_probe')
        self.gallery_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_gallery')

        required_files = [
            self.dataset_dir,
            self.query_dir,
            self.gallery_dir
        ]
        self.check_before_run(required_files)

        query = self.process_dir(self.query_dir)
        gallery = self.process_dir(self.gallery_dir)

        # relabel
        g_pids = set()
        for _, pid, _ in gallery:
            g_pids.add(pid)
        pid2label = {pid: i for i, pid in enumerate(g_pids)}

        query = [(img_path, pid2label[pid], camid) for img_path, pid, camid in query]
        gallery = [(img_path, pid2label[pid], camid) for img_path, pid, camid in gallery]
        train = copy.deepcopy(query) + copy.deepcopy(gallery) # dummy train set; this dataset is test-only

        super(SenseReID, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        data = []

        for img_path in img_paths:
            img_name = osp.splitext(osp.basename(img_path))[0]
            pid, camid = img_name.split('_')
            pid, camid = int(pid), int(camid)
            data.append((img_path, pid, camid))

        return data
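A standalone sketch of the relabeling step above: gallery person IDs are mapped to contiguous labels starting from 0, and the same mapping is applied to the query set.

# Sketch mirroring the relabel block (not part of the commit); values are illustrative.
gallery = [('a.jpg', 7, 0), ('b.jpg', 12, 1), ('c.jpg', 7, 1)]
g_pids = set(pid for _, pid, _ in gallery)
pid2label = {pid: i for i, pid in enumerate(g_pids)}
gallery = [(p, pid2label[pid], cam) for p, pid, cam in gallery]
print(gallery)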
@@ -0,0 +1,131 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import numpy as np

from torchreid.data.datasets import ImageDataset
from torchreid.utils import read_json, write_json


class VIPeR(ImageDataset):
    """VIPeR.

    Reference:
        Gray et al. Evaluating appearance models for recognition, reacquisition, and tracking. PETS 2007.

    URL: https://vision.soe.ucsc.edu/node/178

    Dataset statistics:
        identities: 632
        images: 632 x 2 = 1264
        cameras: 2
    """
    dataset_dir = 'viper'
    dataset_url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip'

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.cam_a_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_a')
        self.cam_b_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_b')
        self.split_path = osp.join(self.dataset_dir, 'splits.json')

        required_files = [
            self.dataset_dir,
            self.cam_a_dir,
            self.cam_b_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, '
                             'but expected between 0 and {}'.format(split_id, len(splits)-1))
        split = splits[split_id]

        train = split['train']
        query = split['query'] # query and gallery share the same images
        gallery = split['gallery']

        train = [tuple(item) for item in train]
        query = [tuple(item) for item in query]
        gallery = [tuple(item) for item in gallery]

        super(VIPeR, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            print('Creating 10 random splits of train ids and test ids')

            cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, '*.bmp')))
            cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, '*.bmp')))
            assert len(cam_a_imgs) == len(cam_b_imgs)
            num_pids = len(cam_a_imgs)
            print('Number of identities: {}'.format(num_pids))
            num_train_pids = num_pids // 2

            """
            In total, there will be 20 splits because each random split creates two
            sub-splits: one uses cameraA as query and cameraB as gallery,
            while the other uses cameraB as query and cameraA as gallery.
            Therefore, results should be averaged over 20 splits (split_id=0~19).

            In practice, a model trained on split_id=0 can be applied to split_id=0&1
            because split_id=0&1 share the same training data (and so on).
            """
            splits = []
            for _ in range(10):
                order = np.arange(num_pids)
                np.random.shuffle(order)
                train_idxs = order[:num_train_pids]
                test_idxs = order[num_train_pids:]
                assert not bool(set(train_idxs) & set(test_idxs)), 'Error: train and test overlap'

                train = []
                for pid, idx in enumerate(train_idxs):
                    cam_a_img = cam_a_imgs[idx]
                    cam_b_img = cam_b_imgs[idx]
                    train.append((cam_a_img, pid, 0))
                    train.append((cam_b_img, pid, 1))

                test_a = []
                test_b = []
                for pid, idx in enumerate(test_idxs):
                    cam_a_img = cam_a_imgs[idx]
                    cam_b_img = cam_b_imgs[idx]
                    test_a.append((cam_a_img, pid, 0))
                    test_b.append((cam_b_img, pid, 1))

                # use cameraA as query and cameraB as gallery
                split = {
                    'train': train,
                    'query': test_a,
                    'gallery': test_b,
                    'num_train_pids': num_train_pids,
                    'num_query_pids': num_pids - num_train_pids,
                    'num_gallery_pids': num_pids - num_train_pids
                }
                splits.append(split)

                # use cameraB as query and cameraA as gallery
                split = {
                    'train': train,
                    'query': test_b,
                    'gallery': test_a,
                    'num_train_pids': num_train_pids,
                    'num_query_pids': num_pids - num_train_pids,
                    'num_gallery_pids': num_pids - num_train_pids
                }
                splits.append(split)

            print('Totally {} splits are created'.format(len(splits)))
            write_json(splits, self.split_path)
            print('Split file saved to {}'.format(self.split_path))
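Since prepare_split appends the two sub-splits of each random split consecutively, split_id=2k and 2k+1 always share training data while swapping query/gallery cameras. A small sketch of that layout:

# Sketch of the 20-split layout described above (not part of the commit).
for split_id in range(20):
    pair = split_id // 2
    direction = 'A->B' if split_id % 2 == 0 else 'B->A'
    print('split_id={} uses training pair {} with query/gallery {}'.format(split_id, pair, direction))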
@@ -0,0 +1,7 @@
from __future__ import absolute_import
from __future__ import print_function

from .mars import Mars
from .ilidsvid import iLIDSVID
from .prid2011 import PRID2011
from .dukemtmcvidreid import DukeMTMCVidReID
@@ -0,0 +1,109 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
import warnings

from torchreid.data.datasets import VideoDataset
from torchreid.utils import read_json, write_json


class DukeMTMCVidReID(VideoDataset):
    """DukeMTMCVidReID.

    Reference:
        Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person
        Re-Identification by Stepwise Learning. CVPR 2018.

    URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID

    Dataset statistics:
        identities: 702 (train) + 702 (test)
        tracklets: 2196 (train) + 2636 (test)
    """
    dataset_dir = 'dukemtmc-vidreid'
    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip'

    def __init__(self, root='', min_seq_len=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train')
        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query')
        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery')
        self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json')
        self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json')
        self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json')
        self.min_seq_len = min_seq_len

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, self.split_train_json_path, relabel=True)
        query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
        gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)

        super(DukeMTMCVidReID, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, json_path, relabel):
        if osp.exists(json_path):
            split = read_json(json_path)
            return split['tracklets']

        print('=> Generating split json file (** this might take a while **)')
        pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store
        print('Processing "{}" with {} person identities'.format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        tracklets = []
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if relabel:
                pid = pid2label[pid]
            tdirs = glob.glob(osp.join(pdir, '*'))
            for tdir in tdirs:
                raw_img_paths = glob.glob(osp.join(tdir, '*.jpg'))
                num_imgs = len(raw_img_paths)

                if num_imgs < self.min_seq_len:
                    continue

                img_paths = []
                for img_idx in range(num_imgs):
                    # some tracklets start from 0002 instead of 0001
                    img_idx_name = 'F' + str(img_idx+1).zfill(4)
                    res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg'))
                    if len(res) == 0:
                        warnings.warn('Index name {} in {} is missing, skipping'.format(img_idx_name, tdir))
                        continue
                    img_paths.append(res[0])
                img_name = osp.basename(img_paths[0])
                if img_name.find('_') == -1:
                    # old naming format: 0001C6F0099X30823.jpg
                    camid = int(img_name[5]) - 1
                else:
                    # new naming format: 0001_C6_F0099_X30823.jpg
                    camid = int(img_name[6]) - 1
                img_paths = tuple(img_paths)
                tracklets.append((img_paths, pid, camid))

        print('Saving split to {}'.format(json_path))
        split_dict = {'tracklets': tracklets}
        write_json(split_dict, json_path)

        return tracklets
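A standalone sketch of the camera-id parsing above, covering both naming formats mentioned in the comments:

# Sketch (not part of the commit) of camid extraction for both formats.
for img_name in ('0001C6F0099X30823.jpg', '0001_C6_F0099_X30823.jpg'):
    if img_name.find('_') == -1:
        camid = int(img_name[5]) - 1   # old format: pid(4 chars) + 'C' + camid
    else:
        camid = int(img_name[6]) - 1   # new format: pid(4 chars) + '_C' + camid
    print(img_name, '->', camid)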
@@ -0,0 +1,126 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob
from scipy.io import loadmat

from torchreid.data.datasets import VideoDataset
from torchreid.utils import read_json, write_json


class iLIDSVID(VideoDataset):
    """iLIDS-VID.

    Reference:
        Wang et al. Person Re-Identification by Video Ranking. ECCV 2014.

    URL: http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html

    Dataset statistics:
        identities: 300
        tracklets: 600
        cameras: 2
    """
    dataset_dir = 'ilids-vid'
    dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID')
        self.split_dir = osp.join(self.dataset_dir, 'train-test people splits')
        self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat')
        self.split_path = osp.join(self.dataset_dir, 'splits.json')
        self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1')
        self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2')

        required_files = [
            self.dataset_dir,
            self.data_dir,
            self.split_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1))
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']

        train = self.process_data(train_dirs, cam1=True, cam2=True)
        query = self.process_data(test_dirs, cam1=True, cam2=False)
        gallery = self.process_data(test_dirs, cam1=False, cam2=True)

        super(iLIDSVID, self).__init__(train, query, gallery, **kwargs)

    def prepare_split(self):
        if not osp.exists(self.split_path):
            print('Creating splits ...')
            mat_split_data = loadmat(self.split_mat_path)['ls_set']

            num_splits = mat_split_data.shape[0]
            num_total_ids = mat_split_data.shape[1]
            assert num_splits == 10
            assert num_total_ids == 300
            num_ids_each = num_total_ids // 2

            # pids in mat_split_data are indices, so we need to transform them
            # to real pids
            person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*')))
            person_cam2_dirs = sorted(glob.glob(osp.join(self.cam_2_path, '*')))

            person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs]
            person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs]

            # make sure persons in one camera view can be found in the other camera view
            assert set(person_cam1_dirs) == set(person_cam2_dirs)

            splits = []
            for i_split in range(num_splits):
                # first 50% for testing and the remaining for training, following Wang et al. ECCV'14.
                train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:]))
                test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each]))

                train_idxs = [int(i)-1 for i in train_idxs]
                test_idxs = [int(i)-1 for i in test_idxs]

                # transform pids to person dir names
                train_dirs = [person_cam1_dirs[i] for i in train_idxs]
                test_dirs = [person_cam1_dirs[i] for i in test_idxs]

                split = {'train': train_dirs, 'test': test_dirs}
                splits.append(split)

            print('Totally {} splits are created, following Wang et al. ECCV\'14'.format(len(splits)))
            print('Split file is saved to {}'.format(self.split_path))
            write_json(splits, self.split_path)

    def process_data(self, dirnames, cam1=True, cam2=True):
        tracklets = []
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        for dirname in dirnames:
            if cam1:
                person_dir = osp.join(self.cam_1_path, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 0))

            if cam2:
                person_dir = osp.join(self.cam_2_path, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 1))

        return tracklets
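A sketch of the index handling in prepare_split: the MATLAB split file stores 1-based indices, so they are shifted by one before indexing the sorted person directories.

# Sketch (not part of the commit); the row and directory names are illustrative.
mat_row = [3, 1, 2]                      # 1-based MATLAB-style indices
person_dirs = ['person001', 'person002', 'person003']
idxs = [int(i) - 1 for i in sorted(mat_row)]
print([person_dirs[i] for i in idxs])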
@@ -0,0 +1,112 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
from scipy.io import loadmat
import warnings

from torchreid.data.datasets import VideoDataset


class Mars(VideoDataset):
    """MARS.

    Reference:
        Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.

    URL: http://www.liangzheng.com.cn/Project/project_mars.html

    Dataset statistics:
        identities: 1261
        tracklets: 8298 (train) + 1980 (query) + 9330 (gallery)
        cameras: 6
    """
    dataset_dir = 'mars'
    dataset_url = None

    def __init__(self, root='', **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt')
        self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt')
        self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat')
        self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat')
        self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat')

        required_files = [
            self.dataset_dir,
            self.train_name_path,
            self.test_name_path,
            self.track_train_info_path,
            self.track_test_info_path,
            self.query_IDX_path
        ]
        self.check_before_run(required_files)

        train_names = self.get_names(self.train_name_path)
        test_names = self.get_names(self.test_name_path)
        track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4)
        track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4)
        query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,)
        query_IDX -= 1 # index from 0
        track_query = track_test[query_IDX, :]
        gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX]
        track_gallery = track_test[gallery_IDX, :]

        train = self.process_data(train_names, track_train, home_dir='bbox_train', relabel=True)
        query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False)
        gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False)

        super(Mars, self).__init__(train, query, gallery, **kwargs)

    def get_names(self, fpath):
        names = []
        with open(fpath, 'r') as f:
            for line in f:
                new_line = line.rstrip()
                names.append(new_line)
        return names

    def process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
        assert home_dir in ['bbox_train', 'bbox_test']
        num_tracklets = meta_data.shape[0]
        pid_list = list(set(meta_data[:, 2].tolist()))
        num_pids = len(pid_list)

        if relabel:
            pid2label = {pid: label for label, pid in enumerate(pid_list)}
        tracklets = []

        for tracklet_idx in range(num_tracklets):
            data = meta_data[tracklet_idx, ...]
            start_index, end_index, pid, camid = data
            if pid == -1:
                continue # junk images are just ignored
            assert 1 <= camid <= 6
            if relabel:
                pid = pid2label[pid]
            camid -= 1 # index starts from 0
            img_names = names[start_index - 1:end_index]

            # make sure image names correspond to the same person
            pnames = [img_name[:4] for img_name in img_names]
            assert len(set(pnames)) == 1, 'Error: a single tracklet contains different person images'

            # make sure all images are captured under the same camera
            camnames = [img_name[5] for img_name in img_names]
            assert len(set(camnames)) == 1, 'Error: images are captured under different cameras!'

            # append image names with directory information
            img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names]
            if len(img_paths) >= min_seq_len:
                img_paths = tuple(img_paths)
                tracklets.append((img_paths, pid, camid))

        return tracklets

    def combine_all(self):
        warnings.warn('Some query IDs do not appear in gallery. Therefore, combineall '
                      'does not make any difference to Mars')
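A sketch of the metadata layout consumed by process_data above: each row of the tracks_*.mat arrays is (start_index, end_index, pid, camid), with pid == -1 marking junk tracklets.

# Sketch (not part of the commit); the rows below are illustrative values.
import numpy as np

meta_data = np.array([[1, 16, 5, 1], [17, 24, -1, 3]])
for start_index, end_index, pid, camid in meta_data:
    if pid == -1:
        continue  # junk tracklet, ignored
    print(start_index, end_index, pid, camid - 1)  # camid shifted to 0-based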
@@ -0,0 +1,79 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import os
import os.path as osp
import glob

from torchreid.data.datasets import VideoDataset
from torchreid.utils import read_json, write_json


class PRID2011(VideoDataset):
    """PRID2011.

    Reference:
        Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011.

    URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/

    Dataset statistics:
        identities: 200
        tracklets: 400
        cameras: 2
    """
    dataset_dir = 'prid2011'
    dataset_url = None

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json')
        self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a')
        self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b')

        required_files = [
            self.dataset_dir,
            self.cam_a_dir,
            self.cam_b_dir
        ]
        self.check_before_run(required_files)

        splits = read_json(self.split_path)
        if split_id >= len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1))
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']

        train = self.process_dir(train_dirs, cam1=True, cam2=True)
        query = self.process_dir(test_dirs, cam1=True, cam2=False)
        gallery = self.process_dir(test_dirs, cam1=False, cam2=True)

        super(PRID2011, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dirnames, cam1=True, cam2=True):
        tracklets = []
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        for dirname in dirnames:
            if cam1:
                person_dir = osp.join(self.cam_a_dir, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 0))

            if cam2:
                person_dir = osp.join(self.cam_b_dir, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 1))

        return tracklets
@@ -0,0 +1,87 @@
from __future__ import absolute_import
from __future__ import division

from collections import defaultdict
import numpy as np
import copy
import random

import torch
from torch.utils.data.sampler import Sampler, RandomSampler


class RandomIdentitySampler(Sampler):
    """Randomly samples N identities, each with K instances.

    Args:
        data_source (list): contains a list of (img_path, pid, camid).
        batch_size (int): number of examples in a batch.
        num_instances (int): number of instances per identity in a batch.
    """
    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch
        self.length = 0
        for pid in self.pids:
            idxs = self.index_dic[pid]
            num = len(idxs)
            if num < self.num_instances:
                num = self.num_instances
            self.length += num - num % self.num_instances

    def __iter__(self):
        batch_idxs_dict = defaultdict(list)

        for pid in self.pids:
            idxs = copy.deepcopy(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.num_instances:
                    batch_idxs_dict[pid].append(batch_idxs)
                    batch_idxs = []

        avai_pids = copy.deepcopy(self.pids)
        final_idxs = []

        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
            for pid in selected_pids:
                batch_idxs = batch_idxs_dict[pid].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[pid]) == 0:
                    avai_pids.remove(pid)

        return iter(final_idxs)

    def __len__(self):
        return self.length


def build_train_sampler(data_source, train_sampler, batch_size, num_instances, **kwargs):
    """Builds a training sampler.

    Args:
        data_source (list): contains a list of (img_path, pid, camid).
        train_sampler (str): sampler name (default: RandomSampler).
        batch_size (int): training batch size.
        num_instances (int): number of instances per identity in a batch (for RandomIdentitySampler).
    """
    if train_sampler == 'RandomIdentitySampler':
        sampler = RandomIdentitySampler(data_source, batch_size, num_instances)
    else:
        sampler = RandomSampler(data_source)

    return sampler
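A minimal usage sketch, assuming the module above is importable and using a toy data_source standing in for (img_path, pid, camid) records: 8 identities with 8 instances each yield 64 indices arranged so every 32-index batch contains 8 identities x 4 instances.

# Sketch (not part of the commit): exercising the sampler on toy data.
data_source = [('img_{}.jpg'.format(i), i % 8, 0) for i in range(64)]
sampler = build_train_sampler(data_source, 'RandomIdentitySampler',
                              batch_size=32, num_instances=4)
indices = list(iter(sampler))
print(len(indices), indices[:8])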
@@ -0,0 +1,157 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from PIL import Image
import random
import numpy as np
import math

import torch
from torchvision.transforms import *


class Random2DTranslation(object):
    """Randomly translates the input image with a probability.

    Specifically, given a predefined shape (height, width), the input is first
    resized with a factor of 1.125, leading to (height*1.125, width*1.125), then
    a random crop is performed. This operation is applied with a probability.

    Args:
        height (int): target image height.
        width (int): target image width.
        p (float): probability of performing this transformation. Default: 0.5.
    """

    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        if random.uniform(0, 1) > self.p:
            return img.resize((self.width, self.height), self.interpolation)

        new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)
        x_maxrange = new_width - self.width
        y_maxrange = new_height - self.height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))
        cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
        return cropped_img


class RandomErasing(object):
    """Randomly erases an image patch.

    Performs Random Erasing, introduced in Random Erasing Data Augmentation
    by Zhong et al.

    Args:
        probability: probability that the operation is performed.
        sl: min erasing area.
        sh: max erasing area.
        r1: min aspect ratio.
        mean: erasing value.

    Imported from https://github.com/zhunzhong07/Random-Erasing
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        if random.uniform(0, 1) > self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                return img

        return img


class ColorAugmentation(object):
    """Randomly alters the intensities of RGB channels.

    Reference:
        Krizhevsky et al. ImageNet Classification with Deep Convolutional Neural Networks. NIPS 2012.
    """

    def __init__(self, p=0.5):
        self.p = p
        self.eig_vec = torch.Tensor([
            [0.4009, 0.7192, -0.5675],
            [-0.8140, -0.0045, -0.5808],
            [0.4203, -0.6948, -0.5836],
        ])
        self.eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])

    def _check_input(self, tensor):
        assert tensor.dim() == 3 and tensor.size(0) == 3

    def __call__(self, tensor):
        if random.uniform(0, 1) > self.p:
            return tensor
        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
        quantity = torch.mm(self.eig_val * alpha, self.eig_vec)
        tensor = tensor + quantity.view(3, 1, 1)
        return tensor


def build_transforms(
    height,
    width,
    random_erase=False, # use random erasing for data augmentation
    color_jitter=False, # randomly change the brightness, contrast and saturation
    color_aug=False, # randomly alter the intensities of RGB channels
    norm_mean=[0.485, 0.456, 0.406], # default is imagenet mean
    norm_std=[0.229, 0.224, 0.225], # default is imagenet std
    **kwargs
):
    normalize = Normalize(mean=norm_mean, std=norm_std)

    # build train transformations
    transform_train = []
    transform_train += [Random2DTranslation(height, width)]
    transform_train += [RandomHorizontalFlip()]
    if color_jitter:
        transform_train += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
    transform_train += [ToTensor()]
    if color_aug:
        transform_train += [ColorAugmentation()]
    transform_train += [normalize]
    if random_erase:
        transform_train += [RandomErasing()]
    transform_train = Compose(transform_train)

    # build test transformations
    transform_test = Compose([
        Resize((height, width)),
        ToTensor(),
        normalize,
    ])

    return transform_train, transform_test
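A minimal usage sketch, assuming the module above is importable: build both pipelines and apply the train transform to a dummy PIL image.

# Sketch (not part of the commit): the output shape follows (3, height, width).
from PIL import Image

transform_train, transform_test = build_transforms(256, 128, random_erase=True)
img = Image.new('RGB', (64, 128))
tensor = transform_train(img)
print(tensor.shape)   # torch.Size([3, 256, 128])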