mmfewshot/demo/demo_metric_classifier_1sho...

51 lines
1.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
"""Run N-way 1-shot inference on a query image with a metric classifier.

Each image in the support directory is used as the single support shot of
its class, and the class name is taken from the image file name.

Example:
    python demo/demo_metric_classifier_1shot_inference.py \
        demo/demo_classification_images/query_images/Least_Auklet.jpg \
        configs/classification/proto_net/cub/proto-net_conv4_1xb105_cub_5way-1shot.py \
        ./work_dirs/proto-net_conv4_1xb105_cub_5way-1shot/best_accuracy_mean.pth
"""  # nowq
import os
from argparse import ArgumentParser
from mmfewshot.classification.apis import (inference_classifier,
init_classifier,
process_support_images,
show_result_pyplot)
def _collect_support_set(support_dir):
    """Return ``(image_paths, labels)`` for a one-shot support set.

    Every non-hidden regular file directly inside ``support_dir`` is treated
    as one support image. Its label is the file name with only the final
    extension removed. Entries are sorted because ``os.listdir`` returns them
    in arbitrary order, and the support order should be deterministic.

    Args:
        support_dir (str): Directory containing one image per class.

    Returns:
        tuple[list[str], list[str]]: Image paths and their class labels,
        aligned index by index.
    """
    image_paths, labels = [], []
    for file_name in sorted(os.listdir(support_dir)):
        path = os.path.join(support_dir, file_name)
        # Hidden files (e.g. ``.DS_Store``) and sub-directories are not
        # support images and would fail to load.
        if file_name.startswith('.') or not os.path.isfile(path):
            continue
        image_paths.append(path)
        # ``splitext`` strips only the extension, so class names containing
        # dots (e.g. ``001.Black_footed_Albatross``) stay intact, whereas
        # ``file_name.split('.')[0]`` would truncate them to ``001``.
        labels.append(os.path.splitext(file_name)[0])
    return image_paths, labels


def main():
    """Classify a query image against a one-shot support set.

    Parses the command line and collects the support images from
    ``--support-images-dir``. Each image becomes a single shot whose class
    label is its file name without the extension. The function then builds
    the classifier from ``config`` and ``checkpoint``, registers the support
    set, runs inference on ``image`` and shows the result.
    """
    # The first positional argument of ``ArgumentParser`` is ``prog`` (the
    # program name shown in usage), so the summary belongs in ``description``.
    parser = ArgumentParser(description='N way 1 shot inference.')
    parser.add_argument('image', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--support-images-dir',
        default='demo/demo_classification_images/support_images',
        help='path to support images directory, each image will use '
        'its file name (without extension) as class')
    args = parser.parse_args()

    # Validate the support set before paying the cost of loading the model.
    support_images, support_labels = _collect_support_set(
        args.support_images_dir)
    if not support_images:
        parser.error('no support images found in {!r}'.format(
            args.support_images_dir))

    # build the model from a config file and a checkpoint file
    model = init_classifier(args.config, args.checkpoint, device=args.device)
    # register the support set; each support class contains exactly one shot
    process_support_images(model, support_images, support_labels)
    # test a single image
    result = inference_classifier(model, args.image)
    # show the results
    show_result_pyplot(args.image, result)
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not when it
    # is imported.
    main()