2022-09-18 10:11:55 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-10-12 21:05:34 +08:00
|
|
|
import logging
|
|
|
|
import os
|
|
|
|
import urllib
|
2022-09-18 10:11:55 +08:00
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
|
|
import mmcv
|
2022-10-12 21:05:34 +08:00
|
|
|
import torch
|
2022-09-18 17:04:14 +08:00
|
|
|
from mmdet.apis import inference_detector, init_detector
|
2022-10-12 21:05:34 +08:00
|
|
|
from mmengine.logging import print_log
|
|
|
|
from mmengine.utils import ProgressBar, scandir
|
2022-09-20 10:57:33 +08:00
|
|
|
|
2022-09-18 10:11:55 +08:00
|
|
|
from mmyolo.registry import VISUALIZERS
|
|
|
|
from mmyolo.utils import register_all_modules
|
|
|
|
|
2022-10-12 21:05:34 +08:00
|
|
|
# File suffixes (with leading dot) that the demo accepts as image inputs.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
|
|
|
|
|
2022-09-18 10:11:55 +08:00
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the image demo.

    Positional arguments are ``img``, ``config`` and ``checkpoint``.
    Optional flags are ``--out-dir``, ``--device``, ``--show`` and
    ``--score-thr``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    cli = ArgumentParser()

    # Required positionals, in the order they must appear on the command line.
    cli.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    cli.add_argument('config', help='Config file')
    cli.add_argument('checkpoint', help='Checkpoint file')

    # Optional flags, registered in a fixed order so ``--help`` output is
    # stable.
    optional_flags = (
        ('--out-dir', dict(default='./output', help='Path to output file')),
        ('--device', dict(
            default='cuda:0', help='Device used for inference')),
        ('--show', dict(
            action='store_true', help='Show the detection results')),
        ('--score-thr', dict(
            type=float, default=0.3, help='bbox score threshold')),
    )
    for flag, options in optional_flags:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_image_files(source):
    """Resolve ``source`` into a list of image file paths.

    The checks run in this order: directory, URL, single image file.

    Args:
        source (str): A directory, an ``http``/``https`` URL, or the path
            of a single image.

    Returns:
        list[str]: Image paths to run inference on. Empty when ``source``
        matches none of the supported kinds or the directory holds no
        file with a suffix from ``IMG_EXTENSIONS``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import unquote

    if os.path.isdir(source):
        # Match suffixes case-insensitively so ``photo.JPG`` is picked up.
        return [
            os.path.join(source, name) for name in scandir(
                source, IMG_EXTENSIONS, recursive=True, case_sensitive=False)
        ]

    if source.startswith(('http:/', 'https:/')):
        # Drop the query string and decode percent-escapes to get a file
        # name, then download into the current working directory.
        filename = os.path.basename(unquote(source).split('?')[0])
        save_path = os.path.join(os.getcwd(), filename)
        torch.hub.download_url_to_file(source, save_path)
        return [save_path]

    # Lower-case the suffix so upper-case extensions are accepted too.
    if os.path.splitext(source)[-1].lower() in IMG_EXTENSIONS:
        return [source]

    return []


def main(args):
    """Run detection on the input image(s) and show or save the results.

    Args:
        args (argparse.Namespace): Parsed arguments. The attributes read
            are ``config``, ``checkpoint``, ``device``, ``img``, ``show``,
            ``out_dir`` and ``score_thr``.
    """
    # register all modules in mmdet into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    is_dir = os.path.isdir(args.img)
    files = _collect_image_files(args.img)
    if not files:
        # Nothing to process: warn and stop, instead of falling through and
        # reporting that results were saved.
        print_log(
            'Cannot find image file.', logger='current', level=logging.WARNING)
        return

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for file in files:
        result = inference_detector(model, file)

        img = mmcv.imread(file)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        if is_dir:
            # Flatten the path relative to the input directory into a single
            # file name. ``os.sep`` keeps this correct on Windows as well.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)

        # When showing results interactively nothing is written to disk.
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        visualizer.add_datasample(
            filename,
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr)
        progress_bar.update()

    if not args.show:
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
|
2022-09-18 10:11:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the demo.
    main(parse_args())
|