mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update inference script for new loader style
This commit is contained in:
parent
58571e992e
commit
1e23727f2f
44
inference.py
44
inference.py
@ -10,10 +10,10 @@ import time
|
|||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data as data
|
|
||||||
|
|
||||||
from models import create_model, transforms_imagenet_eval
|
from models import create_model
|
||||||
from dataset import Dataset
|
from data import Dataset, create_loader, get_model_meanstd
|
||||||
|
from utils import AverageMeter
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
||||||
@ -70,14 +70,15 @@ def main():
|
|||||||
else:
|
else:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
dataset = Dataset(
|
data_mean, data_std = get_model_meanstd(args.model)
|
||||||
args.data,
|
loader = create_loader(
|
||||||
transforms_imagenet_eval(args.model, args.img_size))
|
Dataset(args.data),
|
||||||
|
img_size=args.img_size,
|
||||||
loader = data.DataLoader(
|
batch_size=args.batch_size,
|
||||||
dataset,
|
use_prefetcher=True,
|
||||||
batch_size=args.batch_size, shuffle=False,
|
mean=data_mean,
|
||||||
num_workers=args.workers, pin_memory=True)
|
std=data_std,
|
||||||
|
num_workers=args.workers)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -103,31 +104,12 @@ def main():
|
|||||||
top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
|
top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
|
||||||
|
|
||||||
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
|
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
|
||||||
filenames = dataset.filenames()
|
filenames = loader.dataset.filenames()
|
||||||
for filename, label in zip(filenames, top5_ids):
|
for filename, label in zip(filenames, top5_ids):
|
||||||
filename = os.path.basename(filename)
|
filename = os.path.basename(filename)
|
||||||
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
|
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
|
||||||
filename, label[0], label[1], label[2], label[3], label[4]))
|
filename, label[0], label[1], label[2], label[3], label[4]))
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter(object):
|
|
||||||
"""Computes and stores the average and current value"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.val = 0
|
|
||||||
self.avg = 0
|
|
||||||
self.sum = 0
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
def update(self, val, n=1):
|
|
||||||
self.val = val
|
|
||||||
self.sum += val * n
|
|
||||||
self.count += n
|
|
||||||
self.avg = self.sum / self.count
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user