mmfewshot/demo/demo_attention_rpn_detector_inference.py
Linyiqi 73e8c8938c
Add demo and demo images. (#50)
* add demo and demo images

* fix demo comments
2021-11-12 23:27:47 +08:00

58 lines
2.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
r"""Run inference with an Attention RPN detector using support instances.

Support images are read from ``--support-imgs-dir``. Each image should
contain a single instance, and its file name determines the class label.

Example:
    python demo/demo_attention_rpn_detector_inference.py \
        ./demo/demo_detection_images/query_images/demo_query.jpg \
        configs/detection/attention_rpn/coco/attention-rpn_r50_c4_4xb2_coco_base-training.py \
        ./work_dirs/attention-rpn_r50_c4_4xb2_coco_base-training/latest.pth
"""
import os
from argparse import ArgumentParser
from mmdet.apis import show_result_pyplot
from mmfewshot.detection.apis import (inference_detector, init_detector,
process_support_images)
def parse_args(argv=None):
    """Parse command-line arguments for the Attention RPN inference demo.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``img``,
        ``config``, ``checkpoint``, ``device``, ``score_thr`` and
        ``support_imgs_dir``.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the
    # program name shown in the usage line), so the text must be passed as
    # ``description`` instead.
    parser = ArgumentParser(description='Attention RPN inference demo.')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument(
        '--support-imgs-dir',
        default='demo/demo_detection_images/support_images',
        help='Directory of support images; each image should contain a '
        'single instance and its file name is used as the class name')
    args = parser.parse_args(argv)
    return args
def _collect_support_images(support_imgs_dir):
    """Return sorted support image paths and the class name of each image.

    Hidden entries (e.g. ``.DS_Store``) and sub-directories are skipped.
    The class name of an image is its file name with the extension removed.

    Args:
        support_imgs_dir (str): Directory holding one image per class.

    Returns:
        tuple[list[str], list[str]]: Image paths and matching class names.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no
            usable file.
    """
    # Sort so that the class order (and therefore the label indices seen by
    # the model) does not depend on the arbitrary order of os.listdir.
    files = sorted(
        name for name in os.listdir(support_imgs_dir)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(support_imgs_dir, name)))
    if not files:
        raise FileNotFoundError(
            f'No support images found in {support_imgs_dir!r}')
    paths = [os.path.join(support_imgs_dir, name) for name in files]
    classes = [os.path.splitext(name)[0] for name in files]
    return paths, classes


def main(args):
    """Register support images with the model, then detect on ``args.img``.

    Args:
        args (argparse.Namespace): Parsed CLI arguments (see ``parse_args``).
    """
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # prepare support images; each demo image contains only one instance, so
    # every image gets a single label equal to its class name
    support_imgs, classes = _collect_support_images(args.support_imgs_dir)
    support_labels = [[cls] for cls in classes]
    process_support_images(
        model, support_imgs, support_labels, classes=classes)
    # test a single image
    result = inference_detector(model, args.img)
    # show the results
    show_result_pyplot(model, args.img, result, score_thr=args.score_thr)
if __name__ == '__main__':
    # Script entry point: parse the CLI and hand the namespace to main().
    main(parse_args())