mmyolo/demo/image_demo.py
Haian Huang(深度眸) bb3aa48689 [Feature] Support RTMDet (#85)
* add cspnext

* add neck


* update

* align s,tiny test

* update

* update

* fix bug of UT

* fix typehint

* fix resize bug

* update s and tiny

* update s and tiny link

* update convert

* add cls config with rtmdet

* update config


* fix ci error

Co-authored-by: wanghonglie <wanghonglie@pjlab.org.cn>
2022-09-29 19:09:56 +08:00

62 lines
1.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules
def parse_args():
    """Build the image-demo command line and parse ``sys.argv``.

    Positional arguments (in this order): the input image path, the model
    config file and the checkpoint file. Optional flags select the output
    path for the rendered result, the inference device, a colour palette
    name and the score threshold used when drawing boxes.

    Returns:
        argparse.Namespace: parsed options with the attributes ``img``,
        ``config``, ``checkpoint``, ``out_file``, ``device``, ``palette``
        and ``score_thr``.
    """
    cli = ArgumentParser()

    # Positional inputs, registered in the order they must appear.
    for dest, text in (('img', 'Image file'),
                       ('config', 'Config file'),
                       ('checkpoint', 'Checkpoint file')):
        cli.add_argument(dest, help=text)

    # Optional settings.
    cli.add_argument('--out-file', default=None, help='Path to output file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    # NOTE(review): ``main`` in this file never reads ``args.palette``, so
    # this flag is accepted and validated but has no visible effect yet.
    cli.add_argument(
        '--palette',
        default='coco',
        choices=['coco', 'voc', 'citys', 'random'],
        help='Color palette used for visualization')
    cli.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')

    return cli.parse_args()
def main(args):
    """Run the detector on a single image and visualize its predictions.

    Args:
        args (argparse.Namespace): options as returned by ``parse_args``.
            Reads ``config``, ``checkpoint``, ``device``, ``img``,
            ``out_file`` and ``score_thr``.

    The drawn result is displayed (``show=True``, ``wait_time=0``), and
    ``args.out_file`` is passed to the visualizer as the output path.
    ``args.score_thr`` is forwarded as ``pred_score_thr``.
    """
    # Register modules into the registries (``mmyolo.utils`` helper) before
    # anything is built from the config.
    register_all_modules()

    # TODO: Support inference of image directory.
    detector = init_detector(args.config, args.checkpoint, device=args.device)

    # The visualizer is described by the model's own config and reuses the
    # model's dataset metadata.
    vis = VISUALIZERS.build(detector.cfg.visualizer)
    vis.dataset_meta = detector.dataset_meta

    prediction = inference_detector(detector, args.img)

    # Load the image again for drawing; it is converted from BGR to RGB
    # before being handed to the visualizer.
    canvas = mmcv.imconvert(mmcv.imread(args.img), 'bgr', 'rgb')

    draw_opts = dict(
        data_sample=prediction,
        draw_gt=False,
        show=True,
        wait_time=0,
        out_file=args.out_file,
        pred_score_thr=args.score_thr)
    vis.add_datasample('result', canvas, **draw_opts)
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the demo.
    cli_args = parse_args()
    main(cli_args)