Add get_cat_ids in BaseDataset (#72)

* visualize results on image demo

* add get_cat_ids in BaseDataset
pull/89/head
Lei Yang 2020-10-26 14:04:10 +08:00 committed by GitHub
parent 784987fe9f
commit 909a6b9c3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 24 additions and 0 deletions

View File

@ -38,12 +38,36 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
@property
def class_to_idx(self):
"""Map mapping class name to class index.
Returns:
dict: mapping from class name to class index.
"""
return {_class: i for i, _class in enumerate(self.CLASSES)}
def get_gt_labels(self):
"""Get all ground-truth labels (categories).
Returns:
list[int]: categories for all images.
"""
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
return gt_labels
def get_cat_ids(self, idx):
"""Get category id by index.
Args:
idx (int): Index of data.
Returns:
int: Image category of specified index.
"""
return self.data_infos[idx]['gt_label'].astype(np.int)
def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)