mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
dataset not passed through PrefetchLoader for inference script. Fix #10
* also, make top5 configurable for lower class count cases
This commit is contained in:
parent
2060e433c0
commit
c1a84ecb22
@ -61,6 +61,10 @@ class PrefetchLoader:
|
||||
def sampler(self):
|
||||
return self.loader.sampler
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self.loader.dataset
|
||||
|
||||
@property
|
||||
def mixup_enabled(self):
|
||||
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
||||
|
15
inference.py
15
inference.py
@ -48,6 +48,8 @@ parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
||||
help='disable test time pool')
|
||||
parser.add_argument('--topk', default=5, type=int,
|
||||
metavar='N', help='Top-k to output to CSV')
|
||||
|
||||
|
||||
def main():
|
||||
@ -85,15 +87,16 @@ def main():
|
||||
|
||||
model.eval()
|
||||
|
||||
k = min(args.topk, args.num_classes)
|
||||
batch_time = AverageMeter()
|
||||
end = time.time()
|
||||
top5_ids = []
|
||||
topk_ids = []
|
||||
with torch.no_grad():
|
||||
for batch_idx, (input, _) in enumerate(loader):
|
||||
input = input.cuda()
|
||||
labels = model(input)
|
||||
top5 = labels.topk(5)[1]
|
||||
top5_ids.append(top5.cpu().numpy())
|
||||
topk = labels.topk(k)[1]
|
||||
topk_ids.append(topk.cpu().numpy())
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
@ -104,11 +107,11 @@ def main():
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
|
||||
batch_idx, len(loader), batch_time=batch_time))
|
||||
|
||||
top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
|
||||
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
|
||||
|
||||
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
|
||||
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
|
||||
filenames = loader.dataset.filenames()
|
||||
for filename, label in zip(filenames, top5_ids):
|
||||
for filename, label in zip(filenames, topk_ids):
|
||||
filename = os.path.basename(filename)
|
||||
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
|
||||
filename, label[0], label[1], label[2], label[3], label[4]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user