Add files via upload

RE-OWOD 2022-06-19 20:06:04 +08:00 committed by GitHub
parent d0a04a5b90
commit fa6d74ed4c
1 changed file with 42 additions and 42 deletions


@@ -5,9 +5,14 @@ import time
from collections import OrderedDict
from contextlib import contextmanager
import torch
import numpy as np
import pickle
import sys
import pdb
from detectron2.utils.comm import get_world_size, is_main_process
from detectron2.utils.logger import log_every_n_seconds
from detectron2.structures import pairwise_iou
class DatasetEvaluator:
@@ -98,7 +103,10 @@ class DatasetEvaluators(DatasetEvaluator):
return results
def inference_on_dataset(model, data_loader, evaluator,GENERATE_CALI):
class_score = [[] for _ in range(21)]
def inference_on_dataset(cfg, model, data_loader, evaluator):
"""
Run model on the data_loader and evaluate the metrics with evaluator.
Also benchmark the inference speed of `model.forward` accurately.
@@ -118,7 +126,6 @@ def inference_on_dataset(model, data_loader, evaluator,GENERATE_CALI):
Returns:
The return value of `evaluator.evaluate()`
"""
global class_score
num_devices = get_world_size()
@@ -141,44 +148,41 @@ def inference_on_dataset(model, data_loader, evaluator,GENERATE_CALI):
total_compute_time = 0
start_compute_time = time.perf_counter()
outputs = model(inputs)
if GENERATE_CALI is True:
single_input = inputs[0]['instances']
single_res = outputs[0]['instances']
if len(single_res) > 0:
match_quality_matrix = pairwise_iou(
single_input.gt_boxes.to(single_res.pred_boxes.device), single_res.pred_boxes
)
# match_quality_matrix is M (gt) x N (predicted)
# Max over gt elements (dim 0) to find best gt candidate for each prediction
matched_vals, matches = match_quality_matrix.max(dim=0)
# TODO: add some check to analyze =======================
for i in range(len(matched_vals)): # N, prediction
pre_iou = matched_vals[i] # max iou
pre_gt_class = single_input.gt_classes[matches[i]].to(single_res.pred_boxes.device)
pre_res_class_score = single_res.scores[i]
pre_res_class = single_res.pred_classes[i]
if pre_iou < 0.7:
# prediction's iou is not enough
# if pre_gt_class == 5 or pre_gt_class == 18:
# # 5 and 14 confuse
# print("in iou, < 0.5, gt class:{}, pred class:{}, ".format(pre_gt_class, pre_res_class))
# print(pre_res_class_score)
continue
# has intersection with gt box
if pre_gt_class == pre_res_class:
class_score[pre_gt_class].append(pre_res_class_score)
# logger.info(
# "class_score:{}, iou_below_but_right:{}, unknown_to_bg:{}, all_bg:{}, true_bg:{}".format(
# "None", iou_below_but_right, unknown_to_bg, all_bg, true_bg
# )
# )
# ================================================================
# if cfg.OWOD.GENERATE_CALI:
# # for testing class average scores
# # inputs[0]: {'file_name', 'image_id', 'height', 'width', 'image'}
# # outputs[0]: {'instances'}
# # ================================================================
# single_input = inputs[0]['instances']
# single_res = outputs[0]['instances']
# if len(single_res) > 0:
# match_quality_matrix = pairwise_iou(
# single_input.gt_boxes.to(single_res.pred_boxes.device), single_res.pred_boxes
# )
# # match_quality_matrix is M (gt) x N (predicted)
# # Max over gt elements (dim 0) to find best gt candidate for each prediction
# matched_vals, matches = match_quality_matrix.max(dim=0)
# # TODO: add some check to analyze =======================
# for i in range(len(matched_vals)): # N, prediction
# pre_iou = matched_vals[i] # max iou
# pre_gt_class = single_input.gt_classes[matches[i]].to(single_res.pred_boxes.device)
# pre_res_class_score = single_res.scores[i]
# pre_res_class = single_res.pred_classes[i]
# # prediction's iou is not enough
# if pre_iou < cfg.OWOD.CALIBRATION:
# continue
# # submit gt boxes
# if pre_gt_class == pre_res_class:
# class_score[pre_gt_class].append(pre_res_class_score)
#
# cali_path = cfg.OWOD.CALI_PATH + '.pickle'
# score_file = open(cali_path, 'wb')
# pickle.dump(class_score, score_file)
# score_file.close()
# # ================================================================
if torch.cuda.is_available():
torch.cuda.synchronize()
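
For context, the calibration logic that this commit moves into the commented-out block can be summarized in a small standalone sketch. The code below is illustrative only and is not the repository's code: NUM_CLASSES, IOU_THRESHOLD, collect_scores and the output filename are placeholders standing in for the configurable cfg.OWOD.CALIBRATION threshold and cfg.OWOD.CALI_PATH, and it assumes detectron2's Boxes, Instances and pairwise_iou.

# Minimal sketch (assumed names) of per-class confidence collection for calibration:
# match each prediction to its best ground-truth box by IoU, and if the IoU clears a
# threshold and the predicted class equals the ground-truth class, record the score.
import pickle

import torch
from detectron2.structures import Boxes, Instances, pairwise_iou

NUM_CLASSES = 21        # placeholder class count (the diff uses range(21))
IOU_THRESHOLD = 0.7     # placeholder for cfg.OWOD.CALIBRATION

class_score = [[] for _ in range(NUM_CLASSES)]

def collect_scores(gt: Instances, pred: Instances) -> None:
    """Append the confidence of every correct, well-localized prediction."""
    if len(pred) == 0:
        return
    # M (gt) x N (predicted) IoU matrix.
    iou = pairwise_iou(gt.gt_boxes.to(pred.pred_boxes.device), pred.pred_boxes)
    # Best ground-truth candidate for each prediction.
    matched_vals, matches = iou.max(dim=0)
    for i in range(len(matched_vals)):
        if matched_vals[i] < IOU_THRESHOLD:
            continue  # localization too poor to trust
        gt_class = int(gt.gt_classes[matches[i]])
        if gt_class == int(pred.pred_classes[i]):
            class_score[gt_class].append(float(pred.scores[i]))

# Toy example: one ground-truth box and one overlapping, correctly classified prediction.
gt = Instances((100, 100))
gt.gt_boxes = Boxes(torch.tensor([[10.0, 10.0, 50.0, 50.0]]))
gt.gt_classes = torch.tensor([3])

pred = Instances((100, 100))
pred.pred_boxes = Boxes(torch.tensor([[12.0, 11.0, 49.0, 52.0]]))
pred.pred_classes = torch.tensor([3])
pred.scores = torch.tensor([0.91])

collect_scores(gt, pred)

# Persist the collected scores, mirroring the pickle dump in the diff (placeholder path).
with open("class_scores_cali.pickle", "wb") as f:
    pickle.dump(class_score, f)
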
@ -213,10 +217,6 @@ def inference_on_dataset(model, data_loader, evaluator,GENERATE_CALI):
total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
)
)
if GENERATE_CALI is True:
score_file = open('analyze/1122/t3_ori_set_train_scores_cali_07.pickle', 'wb')
pickle.dump(class_score, score_file)
score_file.close()
results = evaluator.evaluate()
# An evaluator may return None when not in main process.
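
The signature change from inference_on_dataset(model, data_loader, evaluator, GENERATE_CALI) to inference_on_dataset(cfg, model, data_loader, evaluator) means callers now pass the whole config object instead of a single flag. A hypothetical driver is sketched below; it assumes the fork keeps detectron2's upstream module layout, and the config path, checkpoint path and dataset name are placeholders rather than values taken from this repository.

# Hypothetical usage of the revised inference_on_dataset(cfg, model, data_loader, evaluator).
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import PascalVOCDetectionEvaluator
from detectron2.evaluation.evaluator import inference_on_dataset  # module path assumed
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/owod_task1.yaml")   # placeholder config path
cfg.MODEL.WEIGHTS = "output/model_final.pth"     # placeholder checkpoint path

model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

data_loader = build_detection_test_loader(cfg, "voc_2007_test")   # placeholder dataset name
evaluator = PascalVOCDetectionEvaluator("voc_2007_test")

results = inference_on_dataset(cfg, model, data_loader, evaluator)
print(results)
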