diff --git a/classifier.py b/classifier.py index 857f59fbf..2cbf47469 100644 --- a/classifier.py +++ b/classifier.py @@ -8,9 +8,9 @@ Usage-train: Usage-inference: from classifier import * - model = torch.load('best.pt', map_location=torch.device('cpu'))['model'].float() - files = Path('../datasets/mnist/test/7').glob('*.png') - for f in list(files)[:10]: + model = torch.load('path/to/best.pt', map_location=torch.device('cpu'))['model'].float() + files = Path('../datasets/mnist/test/7').glob('*.png') # images from dir + for f in list(files)[:10]: # first 10 images classify(model, size=128, file=f) """