# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import itertools
import json
import logging
import os
import pickle
from collections import OrderedDict
import torch
from fvcore.common.file_io import PathManager

import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog
from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.utils.logger import create_small_table

from .coco_evaluation import instances_to_coco_json
from .evaluator import DatasetEvaluator


class LVISEvaluator(DatasetEvaluator):
    """
    Evaluate object proposal and instance detection/segmentation outputs using
    LVIS's metrics and evaluation API.
    """

    def __init__(self, dataset_name, cfg, distributed, output_dir=None):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
                It must have the following corresponding metadata:
                "json_file": the path to the LVIS format annotation
            cfg (CfgNode): config instance
            distributed (bool): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            output_dir (str): optional, an output directory to dump results.
        """
        from lvis import LVIS

        self._tasks = self._tasks_from_config(cfg)
        self._distributed = distributed
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

        self._metadata = MetadataCatalog.get(dataset_name)
        json_file = PathManager.get_local_path(self._metadata.json_file)
        self._lvis_api = LVIS(json_file)
        # Test set json files do not contain annotations (evaluation must be
        # performed using the LVIS evaluation server).
        self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0

    def reset(self):
        self._predictions = []

    def _tasks_from_config(self, cfg):
        """
        Returns:
            tuple[str]: tasks that can be evaluated under the given configuration.
        """
        tasks = ("bbox",)
        if cfg.MODEL.MASK_ON:
            tasks = tasks + ("segm",)
        return tasks

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a LVIS model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
            if "proposals" in output:
                prediction["proposals"] = output["proposals"].to(self._cpu_device)
            self._predictions.append(prediction)

    def evaluate(self):
        if self._distributed:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process():
                return
        else:
            predictions = self._predictions

        if len(predictions) == 0:
            self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        if "proposals" in predictions[0]:
            self._eval_box_proposals(predictions)
        if "instances" in predictions[0]:
            self._eval_predictions(set(self._tasks), predictions)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _eval_predictions(self, tasks, predictions):
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.

        Args:
            predictions (list[dict]): list of outputs from the model
        """
        self._logger.info("Preparing results in the LVIS format ...")
        lvis_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        # LVIS evaluator can be used to evaluate results for COCO dataset categories.
        # In this case `_metadata` variable will have a field with COCO-specific category mapping.
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            for result in lvis_results:
                result["category_id"] = reverse_id_mapping[result["category_id"]]
        else:
            # unmap the category ids for LVIS (from 0-indexed to 1-indexed)
            for result in lvis_results:
                result["category_id"] += 1

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "lvis_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(lvis_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")
        for task in sorted(tasks):
            res = _evaluate_predictions_on_lvis(
                self._lvis_api, lvis_results, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def _eval_box_proposals(self, predictions):
        """
        Evaluate the box proposals in predictions.
        Fill self._results with the metrics for "box_proposals" task.
        """
        if self._output_dir:
            # Saving generated box proposals to file.
            # Predicted box_proposals are in XYXY_ABS mode.
            bbox_mode = BoxMode.XYXY_ABS.value
            ids, boxes, objectness_logits = [], [], []
            for prediction in predictions:
                ids.append(prediction["image_id"])
                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())

            proposal_data = {
                "boxes": boxes,
                "objectness_logits": objectness_logits,
                "ids": ids,
                "bbox_mode": bbox_mode,
            }
            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
                pickle.dump(proposal_data, f)

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating bbox proposals ...")
        res = {}
        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
        for limit in [100, 1000]:
            for area, suffix in areas.items():
                stats = _evaluate_box_proposals(
                    predictions, self._lvis_api, area=area, limit=limit
                )
                key = "AR{}@{:d}".format(suffix, limit)
                res[key] = float(stats["ar"].item() * 100)
        self._logger.info("Proposal metrics: \n" + create_small_table(res))
        self._results["box_proposals"] = res


# inspired from Detectron:
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255  # noqa
def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None):
    """
    Evaluate detection proposal recall metrics. This function is a much
    faster alternative to the official LVIS API recall evaluation code. However,
    it produces slightly different results.
    """
    # Record max overlap value for each gt box
    # Return vector of overlap values
    areas = {
        "all": 0,
        "small": 1,
        "medium": 2,
        "large": 3,
        "96-128": 4,
        "128-256": 5,
        "256-512": 6,
        "512-inf": 7,
    }
    area_ranges = [
        [0 ** 2, 1e5 ** 2],  # all
        [0 ** 2, 32 ** 2],  # small
        [32 ** 2, 96 ** 2],  # medium
        [96 ** 2, 1e5 ** 2],  # large
        [96 ** 2, 128 ** 2],  # 96-128
        [128 ** 2, 256 ** 2],  # 128-256
        [256 ** 2, 512 ** 2],  # 256-512
        [512 ** 2, 1e5 ** 2],
    ]  # 512-inf
    assert area in areas, "Unknown area range: {}".format(area)
    area_range = area_ranges[areas[area]]
    gt_overlaps = []
    num_pos = 0

    for prediction_dict in dataset_predictions:
        predictions = prediction_dict["proposals"]

        # sort predictions in descending order
        # TODO maybe remove this and make it explicit in the documentation
        inds = predictions.objectness_logits.sort(descending=True)[1]
        predictions = predictions[inds]

        ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]])
        anno = lvis_api.load_anns(ann_ids)
        gt_boxes = [
            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno
        ]
        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
        gt_boxes = Boxes(gt_boxes)
        gt_areas = torch.as_tensor([obj["area"] for obj in anno])

        if len(gt_boxes) == 0 or len(predictions) == 0:
            continue

        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
        gt_boxes = gt_boxes[valid_gt_inds]

        num_pos += len(gt_boxes)

        if len(gt_boxes) == 0:
            continue

        if limit is not None and len(predictions) > limit:
            predictions = predictions[:limit]

        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)

        _gt_overlaps = torch.zeros(len(gt_boxes))
        for j in range(min(len(predictions), len(gt_boxes))):
            # find which proposal box maximally covers each gt box
            # and get the iou amount of coverage for each gt box
            max_overlaps, argmax_overlaps = overlaps.max(dim=0)

            # find which gt box is 'best' covered (i.e. 'best' = most iou)
            gt_ovr, gt_ind = max_overlaps.max(dim=0)
            assert gt_ovr >= 0
            # find the proposal box that covers the best covered gt box
            box_ind = argmax_overlaps[gt_ind]
            # record the iou coverage of this gt box
            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
            assert _gt_overlaps[j] == gt_ovr
            # mark the proposal box and the gt box as used
            overlaps[box_ind, :] = -1
            overlaps[:, gt_ind] = -1

        # append recorded iou coverage level
        gt_overlaps.append(_gt_overlaps)
    gt_overlaps = (
        torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
    )
    gt_overlaps, _ = torch.sort(gt_overlaps)

    if thresholds is None:
        step = 0.05
        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
    recalls = torch.zeros_like(thresholds)
    # compute recall for each iou threshold
    for i, t in enumerate(thresholds):
        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
    # ar = 2 * np.trapz(recalls, thresholds)
    ar = recalls.mean()
    return {
        "ar": ar,
        "recalls": recalls,
        "thresholds": thresholds,
        "gt_overlaps": gt_overlaps,
        "num_pos": num_pos,
    }
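

# Rough usage sketch of the recall helper above (illustrative, not part of the
# original module). `dataset_predictions` must follow the format produced by
# LVISEvaluator.process, i.e. dicts holding "image_id" and a "proposals"
# Instances field, and `lvis_api` is a loaded `lvis.LVIS` instance:
#
#   stats = _evaluate_box_proposals(dataset_predictions, lvis_api, area="all", limit=1000)
#   ar_1000 = float(stats["ar"].item() * 100)  # average recall over IoU thresholds 0.5:0.95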


def _evaluate_predictions_on_lvis(lvis_gt, lvis_results, iou_type, class_names=None):
    """
    Args:
        iou_type (str): "bbox" or "segm"
        class_names (None or list[str]): if provided, will use it to predict
            per-category AP.

    Returns:
        a dict of {metric name: score}
    """
    metrics = {
        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
    }[iou_type]

    logger = logging.getLogger(__name__)

    if len(lvis_results) == 0:  # TODO: check if needed
        logger.warning("No predictions from the model!")
        return {metric: float("nan") for metric in metrics}

    if iou_type == "segm":
        lvis_results = copy.deepcopy(lvis_results)
        # When evaluating mask AP, if the results contain bbox, LVIS API will
        # use the box area as the area of the instance, instead of the mask area.
        # This leads to a different definition of small/medium/large.
        # We remove the bbox field to let mask AP use mask area.
        for c in lvis_results:
            c.pop("bbox", None)

    from lvis import LVISEval, LVISResults

    lvis_results = LVISResults(lvis_gt, lvis_results)
    lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
    lvis_eval.run()
    lvis_eval.print_results()

    # Pull the standard metrics from the LVIS results
    results = lvis_eval.get_results()
    results = {metric: float(results[metric] * 100) for metric in metrics}
    logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
    return results
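

# Example usage of LVISEvaluator (a minimal sketch, not part of the original
# module). The dataset name "lvis_v0.5_val", `cfg`, `model`, and the output
# directory are placeholders to adapt to your setup; it assumes the `lvis`
# package is installed and the dataset is registered with detectron2:
#
#   from detectron2.data import build_detection_test_loader
#   from detectron2.evaluation import inference_on_dataset
#
#   evaluator = LVISEvaluator("lvis_v0.5_val", cfg, distributed=True, output_dir="./output")
#   data_loader = build_detection_test_loader(cfg, "lvis_v0.5_val")
#   results = inference_on_dataset(model, data_loader, evaluator)
#   # `results` is an OrderedDict, e.g. {"bbox": {...}, "segm": {...}, "box_proposals": {...}}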