Add get_cat_ids in BaseDataset (#72)
* visualize results on image demo * add get_cat_ids in BaseDatasetpull/89/head
parent
784987fe9f
commit
909a6b9c3f
|
@ -38,12 +38,36 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def class_to_idx(self):
|
||||
"""Map mapping class name to class index.
|
||||
|
||||
Returns:
|
||||
dict: mapping from class name to class index.
|
||||
"""
|
||||
|
||||
return {_class: i for i, _class in enumerate(self.CLASSES)}
|
||||
|
||||
def get_gt_labels(self):
|
||||
"""Get all ground-truth labels (categories).
|
||||
|
||||
Returns:
|
||||
list[int]: categories for all images.
|
||||
"""
|
||||
|
||||
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
|
||||
return gt_labels
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
"""Get category id by index.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
int: Image category of specified index.
|
||||
"""
|
||||
|
||||
return self.data_infos[idx]['gt_label'].astype(np.int)
|
||||
|
||||
def prepare_data(self, idx):
|
||||
results = copy.deepcopy(self.data_infos[idx])
|
||||
return self.pipeline(results)
|
||||
|
|
Loading…
Reference in New Issue