EasyCV/easycv/datasets/detection/raw.py

145 lines
6.1 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from easycv.core.visualization.image import imshow_bboxes
2022-04-02 20:01:06 +08:00
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
from easycv.utils.bbox_util import batched_xyxy2cxcywh_with_shape
2022-04-02 20:01:06 +08:00
@DATASETS.register_module
class DetDataset(BaseDataset):
    """Dataset for detection.

    Wraps a ``data_source`` (which provides samples, annotations and its
    length) and a ``pipeline`` (applied to every sample), and adds
    detection-specific ``evaluate`` and ``visualize`` helpers.
    """

    def __init__(self, data_source, pipeline, profiling=False, classes=None):
        """
        Args:
            data_source: Data_source config dict
            pipeline: Pipeline config list
            profiling: If set True, will print pipeline time
            classes: A list of class names, used in evaluation for result and
                groundtruth visualization. ``visualize`` falls back to these
                names when the data source exposes none of its own.
        """
        # Expose both spellings, mirroring the CLASSES/classes lookup that
        # ``visualize`` performs on the data source.
        self.classes = classes
        self.CLASSES = classes

        super(DetDataset, self).__init__(
            data_source, pipeline, profiling=profiling)
        self.num_samples = self.data_source.get_length()

    def __len__(self):
        """Number of samples, as reported by ``data_source.get_length()``."""
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` from the data source and run the pipeline."""
        data_dict = self.data_source.get_sample(idx)
        data_dict = self.pipeline(data_dict)
        return data_dict

    def evaluate(self, results, evaluators=None, logger=None):
        """Evaluate detection results against the data source groundtruth.

        Args:
            results: A dictionary containing
                detection_boxes: List of length number of test images.
                    Float32 numpy array of shape [num_boxes, 4] in absolute
                    image coordinates. NOTE(review): earlier docs stated the
                    order [ymin, xmin, ymax, xmax]; confirm against the
                    evaluators in use.
                detection_scores: List of length number of test images,
                    detection scores for the boxes, float32 numpy array of
                    shape [num_boxes].
                detection_classes: List of length number of test images,
                    integer numpy array of shape [num_boxes] holding the
                    class ids of the boxes.
                img_metas: List of length number of test images,
                    dict of image meta info, containing filename, img_shape,
                    ori_img_shape, scale_factor and so on. ``ori_img_shape``
                    is required: it is passed to the groundtruth box
                    conversion.
            evaluators: Iterable of evaluators. Each one's
                ``evaluate(results, groundtruth_dict)`` result is merged into
                the returned dict. ``None`` means no evaluators.
            logger: Accepted for interface compatibility; unused.

        Returns:
            dict: Metrics merged from all evaluators (empty if there are
            none). Later evaluators overwrite keys of earlier ones.
        """
        if evaluators is None:
            evaluators = []

        img_metas = results['img_metas']
        # Fetch each annotation once rather than once per groundtruth field;
        # get_ann_info may do non-trivial work per call.
        ann_infos = [
            self.data_source.get_ann_info(idx)
            for idx in range(len(img_metas))
        ]

        groundtruth_dict = {
            'groundtruth_boxes': [
                batched_xyxy2cxcywh_with_shape(ann_info['bboxes'],
                                               img_meta['ori_img_shape'])
                for ann_info, img_meta in zip(ann_infos, img_metas)
            ],
            'groundtruth_classes': [
                ann_info['labels'] for ann_info in ann_infos
            ],
            'groundtruth_is_crowd': [
                ann_info['groundtruth_is_crowd'] for ann_info in ann_infos
            ],
        }

        eval_result = dict()
        for evaluator in evaluators:
            eval_result.update(evaluator.evaluate(results, groundtruth_dict))
        return eval_result

    def _get_class_names(self):
        """Return the class-name list used to label boxes, or None.

        Checks ``data_source.CLASSES``, then ``data_source.classes``, then the
        ``classes`` given to this dataset's constructor; the first one that is
        present and not None wins.
        """
        candidates = ((self.data_source, 'CLASSES'),
                      (self.data_source, 'classes'),
                      (self, 'CLASSES'))
        for owner, attr in candidates:
            class_names = getattr(owner, attr, None)
            if class_names is not None:
                return class_names
        return None

    @staticmethod
    def _ids_to_names(class_ids, class_names):
        """Map one image's class ids to a numpy array of names.

        ``None`` (an image without class predictions) is passed through.
        Each id is used directly as an index into ``class_names``.
        """
        if class_ids is None:
            return None
        return np.array(
            [class_names[int(class_id)] for class_id in class_ids])

    def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
        """Visualize the model output on validation data.

        Args:
            results: A dictionary containing
                detection_boxes: List of length number of test images.
                    Float32 numpy array of shape [num_boxes, 4] in absolute
                    image coordinates, or None for an image without boxes.
                detection_scores: List of length number of test images,
                    detection scores for the boxes, float32 numpy array of
                    shape [num_boxes], or None.
                detection_classes: List of length number of test images,
                    integer numpy array of shape [num_boxes], or None. Ids
                    are used directly as indices into the class-name list.
                    NOTE(review): earlier docs called these ids 1-indexed,
                    which does not match that lookup; confirm with the
                    model head.
                img_metas: List of length number of test images,
                    dict of image meta info, containing filename, img_shape,
                    ori_img_shape, scale_factor and so on.
            vis_num: Number of images visualized.
            score_thr: The threshold to filter boxes; boxes with scores
                greater than score_thr are kept. Values <= 0 disable
                filtering.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A dictionary containing
                images: Visualized images, one per entry of img_metas.
                img_metas: The first ``vis_num`` entries of
                    ``results['img_metas']``.

        ``results`` itself is not modified.
        """
        detection_classes = results.get('detection_classes', [])
        class_names = self._get_class_names()
        if class_names is not None:
            # Build a new list instead of writing back into ``results``: the
            # caller keeps its integer ids (e.g. for a later evaluate call)
            # and repeated visualize calls do not re-map names.
            detection_classes = [
                self._ids_to_names(class_ids, class_names)
                for class_ids in detection_classes
            ]

        img_metas = results['img_metas'][:vis_num]
        detection_boxes = results.get('detection_boxes', [])
        detection_scores = results.get('detection_scores', [])

        vis_imgs = []
        for i, img_meta in enumerate(img_metas):
            bboxes = detection_boxes[i]
            if bboxes is None:
                bboxes = np.array([])
            scores = detection_scores[i]
            classes = detection_classes[i]
            if scores is not None and score_thr > 0:
                keep = scores > score_thr
                bboxes = bboxes[keep]
                # Class entries can be None for images without predictions.
                if classes is not None:
                    classes = classes[keep]
            vis_img = imshow_bboxes(
                img=img_meta['filename'],
                bboxes=bboxes,
                labels=classes,
                show=False)
            vis_imgs.append(vis_img)

        return {'images': vis_imgs, 'img_metas': img_metas}