[Fix] Fix test.py (#155)
* modeify test.py * modeify test.py * revise according to commentspull/156/head
parent
a225cb6bdb
commit
a7d4739d2b
|
@ -21,7 +21,13 @@ def parse_args():
|
||||||
parser.add_argument('checkpoint', help='checkpoint file')
|
parser.add_argument('checkpoint', help='checkpoint file')
|
||||||
parser.add_argument('--out', help='output result file')
|
parser.add_argument('--out', help='output result file')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--metric', type=str, default='accuracy', help='evaluation metric')
|
'--metrics',
|
||||||
|
type=str,
|
||||||
|
nargs='+',
|
||||||
|
help='evaluation metrics, which depends on the dataset, e.g., '
|
||||||
|
'"accuracy", "precision", "recall", "f1_score", "support" for single '
|
||||||
|
'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
|
||||||
|
'multi-label dataset')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--gpu_collect',
|
'--gpu_collect',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
@ -94,11 +100,12 @@ def main():
|
||||||
|
|
||||||
rank, _ = get_dist_info()
|
rank, _ = get_dist_info()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if args.metric != '':
|
if args.metrics:
|
||||||
results = dataset.evaluate(outputs, args.metric)
|
results = dataset.evaluate(outputs, args.metrics)
|
||||||
for topk, acc in results.items():
|
for k, v in results.items():
|
||||||
print(f'\n{topk} accuracy: {acc:.2f}')
|
print(f'\n{k} : {v:.2f}')
|
||||||
else:
|
else:
|
||||||
|
warnings.warn('Evaluation metrics are not specified.')
|
||||||
scores = np.vstack(outputs)
|
scores = np.vstack(outputs)
|
||||||
pred_score = np.max(scores, axis=1)
|
pred_score = np.max(scores, axis=1)
|
||||||
pred_label = np.argmax(scores, axis=1)
|
pred_label = np.argmax(scores, axis=1)
|
||||||
|
|
Loading…
Reference in New Issue