mirror of https://github.com/JDAI-CV/fast-reid.git
128 lines
3.3 KiB
Python
128 lines
3.3 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import copy
|
|
import logging
|
|
import os
|
|
|
|
from tabulate import tabulate
|
|
from termcolor import colored
|
|
|
|
# Module-level logger through which the dataset class in this file reports
# split sizes and the attribute listing.
logger = logging.getLogger("fastreid.attr_dataset")
|
|
|
|
|
|
class Dataset(object):
    """Base container for an attribute dataset split into train/val/test.

    The instance stores the three splits plus an attribute dictionary and
    exposes the split chosen by ``mode`` as ``self.data``. ``__getitem__`` is
    left abstract (it raises ``NotImplementedError``), so subclasses are
    expected to provide item access.
    """

    def __init__(
        self,
        train,
        val,
        test,
        attr_dict,
        mode='train',
        verbose=True,
        **kwargs,
    ):
        """Store the splits and select the active one.

        Args:
            train: training split; must support ``len()``.
            val: validation split; must support ``len()``.
            test: test split; must support ``len()``.
            attr_dict: mapping describing the attributes. ``show_train``
                formats its keys with ``{:3d}``, so they are presumably
                integer label indices -- confirm against the subclasses.
            mode (str): ``'train'`` or ``'val'`` select the matching split;
                any other value selects the test split.
            verbose: accepted for interface compatibility; not read here.
            **kwargs: accepted for interface compatibility; not read here.
        """
        self.train = train
        self.val = val
        self.test = test
        self._attr_dict = attr_dict
        self._num_attrs = len(self.attr_dict)

        if mode == 'train':
            self.data = self.train
        elif mode == 'val':
            self.data = self.val
        else:
            # Every other mode value falls through to the test split.
            self.data = self.test

    @property
    def num_attrs(self):
        """Number of entries in the attribute dictionary at construction."""
        return self._num_attrs

    @property
    def attr_dict(self):
        """The attribute dictionary passed to the constructor."""
        return self._attr_dict

    def __len__(self):
        """Return the size of the split selected by ``mode``."""
        return len(self.data)

    def __getitem__(self, index):
        """Item access is not implemented in the base class."""
        raise NotImplementedError

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any of the given paths does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def combine_all(self):
        """Combines train, val and test in a dataset for training.

        ``self.train`` is replaced by a new list holding a deep copy of the
        previous training samples followed by the validation and test
        samples. ``self.val`` and ``self.test`` are left untouched. When
        ``self.data`` was serving the training split it is re-pointed at the
        combined list so that ``len(self)`` stays consistent.
        """
        # Remember whether the active split is the training one *before*
        # ``self.train`` is rebound below.
        serving_train = getattr(self, "data", None) is self.train

        # Deep copy so the caller's original training list is never mutated.
        combined = copy.deepcopy(self.train)
        for subset in (self.val, self.test):
            if subset is not None:
                combined.extend(subset)

        self.train = combined
        if serving_train:
            self.data = combined

    def _log_subset_table(self, rows):
        """Log ``rows`` (``[subset, count]`` pairs) as a pipe-format table."""
        table = tabulate(
            rows,
            tablefmt="pipe",
            headers=['subset', '# images'],
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

    def show_train(self):
        """Log train/val sizes, their total, and the attribute listing."""
        num_train = len(self.train)
        num_val = len(self.val)

        self._log_subset_table([
            ['train', num_train],
            ['val', num_val],
            ['total', num_train + num_val],
        ])
        logger.info("attributes:")
        for label, attr in self.attr_dict.items():
            logger.info('{:3d}: {}'.format(label, attr))
        logger.info("------------------------------")
        logger.info("# attributes: {}".format(len(self.attr_dict)))

    def show_test(self):
        """Log the size of the test split."""
        self._log_subset_table([
            ['test', len(self.test)],
        ])
|