# Mirror of https://github.com/open-mmlab/mmyolo.git (Python, 135 lines, 4.2 KiB)
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
from argparse import ArgumentParser
|
|
|
|
import mmcv
|
|
from mmdet.apis import inference_detector, init_detector
|
|
from mmengine.logging import print_log
|
|
from mmengine.utils import ProgressBar, path
|
|
|
|
from mmyolo.registry import VISUALIZERS
|
|
from mmyolo.utils import switch_to_deploy
|
|
from mmyolo.utils.labelme_utils import LabelmeFormat
|
|
from mmyolo.utils.misc import get_file_list, show_data_classes
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the image demo.

    Args:
        argv (list[str], optional): Argument strings to parse instead of the
            process command line. When ``None`` (the default) the arguments
            are taken from ``sys.argv[1:]``, which is what argparse does on
            its own, so existing ``parse_args()`` callers are unaffected.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``img``,
        ``config``, ``checkpoint``, ``out_dir``, ``device``, ``show``,
        ``deploy``, ``score_thr``, ``class_name`` and ``to_labelme``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--class-name',
        nargs='+',
        type=str,
        help='Only Save those classes if set')
    parser.add_argument(
        '--to-labelme',
        action='store_true',
        help='Output labelme style label file')
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
|
def main():
    """Run detector inference on the given images and export the results.

    Parses the command line, builds the detector from the config and
    checkpoint, optionally switches it to deploy mode, and then for every
    input file either shows the visualised detections (``--show``), writes
    the visualised image into ``--out-dir``, or writes a labelme-style
    ``.json`` label file into ``--out-dir`` (``--to-labelme``).

    Raises:
        RuntimeError: If ``--to-labelme`` and ``--show`` are both set, or if
            a name given via ``--class-name`` is not one of the model's
            dataset classes.
    """
    args = parse_args()

    if args.to_labelme and args.show:
        raise RuntimeError('`--to-labelme` or `--show` only '
                           'can choose one at the same time.')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.deploy:
        switch_to_deploy(model)

    # An output directory is only needed when results are written to disk.
    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # get model class name
    dataset_classes = model.dataset_meta.get('classes')

    # ready for labelme format if it is needed
    to_label_format = LabelmeFormat(classes=dataset_classes)

    # check class name: fail on the first requested class the model does
    # not know, after listing the valid ones for the user
    if args.class_name is not None:
        for class_name in args.class_name:
            if class_name in dataset_classes:
                continue
            show_data_classes(dataset_classes)
            raise RuntimeError(
                'Expected args.class_name to be one of the list, '
                f'but got "{class_name}"')

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for file in files:
        result = inference_detector(model, file)

        if source_type['is_dir']:
            # Flatten the path relative to the input dir into one file name.
            # ``os.sep`` (rather than a literal '/') keeps this working on
            # platforms whose separator is not '/'.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        progress_bar.update()

        # Get candidate predict info with score threshold
        pred_instances = result.pred_instances[
            result.pred_instances.scores > args.score_thr]

        if args.to_labelme:
            # save result to labelme files
            # Swap only the trailing extension. A plain ``str.replace`` of
            # the extension would also rewrite matching text elsewhere in
            # the path, and for an extension-less name it would replace the
            # empty string, inserting '.json' between every character.
            out_file = os.path.splitext(out_file)[0] + '.json'
            to_label_format(pred_instances, result.metainfo, out_file,
                            args.class_name)
            continue

        # The image is only used for drawing, so it is loaded after the
        # labelme branch has had the chance to skip this file.
        img = mmcv.imread(file)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        visualizer.add_datasample(
            filename,
            img,
            data_sample=result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr)

    if not args.show and not args.to_labelme:
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')

    elif args.to_labelme:
        print_log('\nLabelme format label files '
                  f'had all been saved in {args.out_dir}')
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, so importing the
# module does not start inference.
if __name__ == '__main__':
    main()
|