Use gather (fancy indexing) for valid labels instead of bool mask in validate.py

This commit is contained in:
Ross Wightman 2023-03-18 15:08:19 -07:00
parent 9fcfb8bcc1
commit 3448cc689c

View File

@ -255,8 +255,7 @@ def validate(args):
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = {int(line.rstrip()) for line in f}
valid_labels = [i in valid_labels for i in range(args.num_classes)]
valid_labels = [int(line.rstrip()) for line in f]
else:
valid_labels = None