[Fix] Fix test.py (#155)

* modify test.py

* revise according to comments
LXXXXR 2021-01-31 15:53:43 +08:00 committed by GitHub
parent a225cb6bdb
commit a7d4739d2b
1 changed file with 12 additions and 5 deletions

@@ -21,7 +21,13 @@ def parse_args():
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('--out', help='output result file')
     parser.add_argument(
-        '--metric', type=str, default='accuracy', help='evaluation metric')
+        '--metrics',
+        type=str,
+        nargs='+',
+        help='evaluation metrics, which depends on the dataset, e.g., '
+        '"accuracy", "precision", "recall", "f1_score", "support" for single '
+        'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
+        'multi-label dataset')
     parser.add_argument(
         '--gpu_collect',
         action='store_true',
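
For reference, an argument declared with nargs='+' gathers one or more space-separated values into a Python list, and the attribute defaults to None when the flag is omitted. A minimal standalone argparse sketch, not tied to this repo ("demo.py" is a made-up name):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--metrics', type=str, nargs='+', help='evaluation metrics')

    # equivalent to: python demo.py --metrics accuracy precision recall
    args = parser.parse_args(['--metrics', 'accuracy', 'precision', 'recall'])
    print(args.metrics)  # ['accuracy', 'precision', 'recall']

    # flag omitted -> args.metrics is None (falsy),
    # which the truthiness check in the second hunk relies on
    args = parser.parse_args([])
    print(args.metrics)  # None
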
@@ -94,11 +100,12 @@ def main():
     rank, _ = get_dist_info()
     if rank == 0:
-        if args.metric != '':
-            results = dataset.evaluate(outputs, args.metric)
-            for topk, acc in results.items():
-                print(f'\n{topk} accuracy: {acc:.2f}')
+        if args.metrics:
+            results = dataset.evaluate(outputs, args.metrics)
+            for k, v in results.items():
+                print(f'\n{k} : {v:.2f}')
         else:
+            warnings.warn('Evaluation metrics are not specified.')
             scores = np.vstack(outputs)
             pred_score = np.max(scores, axis=1)
             pred_label = np.argmax(scores, axis=1)
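
With this change, several metrics can be requested in a single run. An illustrative invocation (the config and checkpoint paths are placeholders, and the warnings module is assumed to already be imported in test.py, since the diff adds no import):

    python test.py configs/example_config.py example_ckpt.pth --metrics accuracy precision recall

If --metrics is omitted, args.metrics is None, so the run takes the new else branch: it warns that no metrics were specified and falls back to computing the raw prediction scores and labels.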