mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
dataset not passed through PrefetchLoader for inference script. Fix #10
* also, make top5 configurable for lower class count cases
This commit is contained in:
parent
2060e433c0
commit
c1a84ecb22
@@ -61,6 +61,10 @@ class PrefetchLoader:

     def sampler(self):
         return self.loader.sampler

+    @property
+    def dataset(self):
+        return self.loader.dataset
+
     @property
     def mixup_enabled(self):
         if isinstance(self.loader.collate_fn, FastCollateMixup):
|
inference.py (15 lines changed)

@@ -48,6 +48,8 @@ parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool')
+parser.add_argument('--topk', default=5, type=int,
+                    metavar='N', help='Top-k to output to CSV')


 def main():
@@ -85,15 +87,16 @@ def main():

     model.eval()

+    k = min(args.topk, args.num_classes)
     batch_time = AverageMeter()
     end = time.time()
-    top5_ids = []
+    topk_ids = []
     with torch.no_grad():
         for batch_idx, (input, _) in enumerate(loader):
             input = input.cuda()
             labels = model(input)
-            top5 = labels.topk(5)[1]
-            top5_ids.append(top5.cpu().numpy())
+            topk = labels.topk(k)[1]
+            topk_ids.append(topk.cpu().numpy())

             # measure elapsed time
             batch_time.update(time.time() - end)
@@ -104,11 +107,11 @@ def main():
             'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                 batch_idx, len(loader), batch_time=batch_time))

-    top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
+    topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

-    with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
+    with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
         filenames = loader.dataset.filenames()
-        for filename, label in zip(filenames, top5_ids):
+        for filename, label in zip(filenames, topk_ids):
             filename = os.path.basename(filename)
             out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
                 filename, label[0], label[1], label[2], label[3], label[4]))
|
Loading…
x
Reference in New Issue
Block a user