# mmpretrain/tools/analyze_results.py
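"""Analyze prediction results and save the top-k lowest-confidence successful
and failed predictions as annotated images.

Example usage (file names are illustrative; the result file is the json/pkl
typically produced by ``tools/test.py --out result.json``):

    python tools/analyze_results.py \
        configs/resnet/resnet50_b32x8_imagenet.py \
        result.json \
        --out-dir work_dirs/analysis \
        --topk 20
"""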

import argparse
import os.path as osp

import mmcv
from mmcv import DictAction

from mmcls.datasets import build_dataset
from mmcls.models import build_classifier


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMCls evaluate prediction success/fail')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('result', help='test result json/pkl file')
    parser.add_argument('--out-dir', help='dir to store output files')
    parser.add_argument(
        '--topk',
        default=20,
        type=int,
        help='Number of images to select for success/fail')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    args = parser.parse_args()
    return args


def save_imgs(result_dir, folder_name, results, model):
    """Dump ``results`` to a json file and save each image with its
    prediction score and predicted/ground-truth classes drawn on it."""
    full_dir = osp.join(result_dir, folder_name)
    mmcv.mkdir_or_exist(full_dir)
    mmcv.dump(results, osp.join(full_dir, folder_name + '.json'))

    # save imgs
    show_keys = ['pred_score', 'pred_class', 'gt_class']
    for result in results:
        result_show = dict((k, v) for k, v in result.items()
                           if k in show_keys)
        outfile = osp.join(full_dir, osp.basename(result['filename']))
        model.show_result(result['filename'], result_show, out_file=outfile)


def main():
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)

    model = build_classifier(cfg.model)

    # build the dataset
    dataset = build_dataset(cfg.data.test)

    # collect the full path of every test image
    filenames = list()
    for info in dataset.data_infos:
        if info['img_prefix'] is not None:
            filename = osp.join(info['img_prefix'],
                                info['img_info']['filename'])
        else:
            filename = info['img_info']['filename']
        filenames.append(filename)
    gt_labels = list(dataset.get_gt_labels())
    gt_classes = [dataset.CLASSES[x] for x in gt_labels]

    # load test results: a dict of per-sample lists, e.g. 'pred_score',
    # 'pred_label' and 'pred_class'
    outputs = mmcv.load(args.result)
    outputs['filename'] = filenames
    outputs['gt_label'] = gt_labels
    outputs['gt_class'] = gt_classes

    # convert the dict of lists into a list of per-sample dicts
    outputs_list = list()
    for i in range(len(gt_labels)):
        output = dict()
        for k in outputs.keys():
            output[k] = outputs[k][i]
        outputs_list.append(output)

    # sort results by prediction score in ascending order, so the top-k
    # selected below are the lowest-confidence cases
    outputs_list = sorted(outputs_list, key=lambda x: x['pred_score'])

    success = list()
    fail = list()
    for output in outputs_list:
        if output['pred_label'] == output['gt_label']:
            success.append(output)
        else:
            fail.append(output)

    success = success[:args.topk]
    fail = fail[:args.topk]

    save_imgs(args.out_dir, 'success', success, model)
    save_imgs(args.out_dir, 'fail', fail, model)


if __name__ == '__main__':
    main()