From 5e0599c8257f5fe84dc7186fdcfb06afa016e1e1 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Sun, 20 Nov 2022 15:37:58 +0800 Subject: [PATCH] [Feature] Add flag for output labelme label file in `image_demo` (#288) * Add flag for labelme * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Not save `imageData` * preprocess score * preprocess score * Fix lint --- demo/image_demo.py | 34 ++++++++++++-- mmyolo/utils/labelme_utils.py | 87 +++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 mmyolo/utils/labelme_utils.py diff --git a/demo/image_demo.py b/demo/image_demo.py index 5ccc7aef..ebe255b6 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -9,6 +9,7 @@ from mmengine.utils import ProgressBar from mmyolo.registry import VISUALIZERS from mmyolo.utils import register_all_modules, switch_to_deploy +from mmyolo.utils.labelme_utils import LabelmeFormat from mmyolo.utils.misc import get_file_list @@ -30,6 +31,10 @@ def parse_args(): help='Switch model to deployment mode') parser.add_argument( '--score-thr', type=float, default=0.3, help='Bbox score threshold') + parser.add_argument( + '--to-labelme', + action='store_true', + help='Output labelme style label file') args = parser.parse_args() return args @@ -37,6 +42,10 @@ def parse_args(): def main(): args = parse_args() + if args.to_labelme and args.show: + raise RuntimeError('`--to-labelme` or `--show` only ' + 'can choose one at the same time.') + # register all modules in mmdet into the registries register_all_modules() @@ -56,6 +65,9 @@ def main(): # get file list files, source_type = get_file_list(args.img) + # ready for labelme format if it is needed + to_label_format = LabelmeFormat(classes=model.dataset_meta.get('CLASSES')) + # start detector inference progress_bar = ProgressBar(len(files)) for file in files: @@ -70,8 +82,21 @@ def main(): filename = os.path.basename(file) out_file = None if args.show else os.path.join(args.out_dir, filename) + progress_bar.update() + + # Get candidate predict info with score threshold + pred_instances = result.pred_instances[ + result.pred_instances.scores > args.score_thr] + + if args.to_labelme: + # save result to labelme files + out_file = out_file.replace( + os.path.splitext(out_file)[-1], '.json') + to_label_format(result, out_file, pred_instances) + continue + visualizer.add_datasample( - os.path.basename(out_file), + filename, img, data_sample=result, draw_gt=False, @@ -79,12 +104,15 @@ def main(): wait_time=0, out_file=out_file, pred_score_thr=args.score_thr) - progress_bar.update() - if not args.show: + if not args.show and not args.to_labelme: print_log( f'\nResults have been saved at {os.path.abspath(args.out_dir)}') + elif args.to_labelme: + print_log('\nLabelme format label files ' + f'had all been saved in {args.out_dir}') + if __name__ == '__main__': main() diff --git a/mmyolo/utils/labelme_utils.py b/mmyolo/utils/labelme_utils.py new file mode 100644 index 00000000..b926a1f6 --- /dev/null +++ b/mmyolo/utils/labelme_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + +from mmdet.structures import DetDataSample +from mmengine.structures import InstanceData + + +class LabelmeFormat: + """Predict results save into labelme file. + + Base on https://github.com/wkentaro/labelme/blob/main/labelme/label_file.py + + Args: + classes (tuple): Model classes name. + score_threshold (float): Predict score threshold. + """ + + def __init__(self, classes: tuple): + super().__init__() + self.classes = classes + + def __call__(self, results: DetDataSample, output_path: str, + pred_instances: InstanceData): + """Get image data field for labelme. + + Args: + results (DetDataSample): Predict info. + output_path (str): Image file path. + pred_instances (InstanceData): Candidate prediction info. + + Labelme file eg. + { + "version": "5.0.5", + "flags": {}, + "imagePath": "/data/cat/1.jpg", + "imageData": null, + "imageHeight": 3000, + "imageWidth": 4000, + "shapes": [ + { + "label": "cat", + "points": [ + [ + 1148.076923076923, + 1188.4615384615383 + ], + [ + 2471.1538461538457, + 2176.923076923077 + ] + ], + "group_id": null, + "shape_type": "rectangle", + "flags": {} + }, + {...} + ] + } + """ + + image_path = results.metainfo['img_path'] + + json_info = { + 'version': '5.0.5', + 'flags': {}, + 'imagePath': image_path, + 'imageData': None, + 'imageHeight': results.ori_shape[0], + 'imageWidth': results.ori_shape[1], + 'shapes': [] + } + + for pred_info in pred_instances: + pred_bbox = pred_info.bboxes.cpu().numpy().tolist()[0] + pred_label = self.classes[pred_info.labels] + + sub_dict = { + 'label': pred_label, + 'points': [pred_bbox[:2], pred_bbox[2:]], + 'group_id': None, + 'shape_type': 'rectangle', + 'flags': {} + } + json_info['shapes'].append(sub_dict) + + with open(output_path, 'w', encoding='utf-8') as f_json: + json.dump(json_info, f_json, ensure_ascii=False, indent=2)