[Fix] Fix test.py (#155)
* modeify test.py * modeify test.py * revise according to commentspull/156/head
parent
a225cb6bdb
commit
a7d4739d2b
|
@ -21,7 +21,13 @@ def parse_args():
|
|||
parser.add_argument('checkpoint', help='checkpoint file')
|
||||
parser.add_argument('--out', help='output result file')
|
||||
parser.add_argument(
|
||||
'--metric', type=str, default='accuracy', help='evaluation metric')
|
||||
'--metrics',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='evaluation metrics, which depends on the dataset, e.g., '
|
||||
'"accuracy", "precision", "recall", "f1_score", "support" for single '
|
||||
'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
|
||||
'multi-label dataset')
|
||||
parser.add_argument(
|
||||
'--gpu_collect',
|
||||
action='store_true',
|
||||
|
@ -94,11 +100,12 @@ def main():
|
|||
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
if args.metric != '':
|
||||
results = dataset.evaluate(outputs, args.metric)
|
||||
for topk, acc in results.items():
|
||||
print(f'\n{topk} accuracy: {acc:.2f}')
|
||||
if args.metrics:
|
||||
results = dataset.evaluate(outputs, args.metrics)
|
||||
for k, v in results.items():
|
||||
print(f'\n{k} : {v:.2f}')
|
||||
else:
|
||||
warnings.warn('Evaluation metrics are not specified.')
|
||||
scores = np.vstack(outputs)
|
||||
pred_score = np.max(scores, axis=1)
|
||||
pred_label = np.argmax(scores, axis=1)
|
||||
|
|
Loading…
Reference in New Issue