mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Add flag for output labelme label file in `image_demo` (#288)
* Add flag for labelme * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Add labelme format code * Not save `imageData` * preprocess score * preprocess score * Fix lintpull/307/head
parent
12a1e8a984
commit
5e0599c825
|
@ -9,6 +9,7 @@ from mmengine.utils import ProgressBar
|
||||||
|
|
||||||
from mmyolo.registry import VISUALIZERS
|
from mmyolo.registry import VISUALIZERS
|
||||||
from mmyolo.utils import register_all_modules, switch_to_deploy
|
from mmyolo.utils import register_all_modules, switch_to_deploy
|
||||||
|
from mmyolo.utils.labelme_utils import LabelmeFormat
|
||||||
from mmyolo.utils.misc import get_file_list
|
from mmyolo.utils.misc import get_file_list
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,6 +31,10 @@ def parse_args():
|
||||||
help='Switch model to deployment mode')
|
help='Switch model to deployment mode')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
||||||
|
parser.add_argument(
|
||||||
|
'--to-labelme',
|
||||||
|
action='store_true',
|
||||||
|
help='Output labelme style label file')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@ -37,6 +42,10 @@ def parse_args():
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
if args.to_labelme and args.show:
|
||||||
|
raise RuntimeError('`--to-labelme` or `--show` only '
|
||||||
|
'can choose one at the same time.')
|
||||||
|
|
||||||
# register all modules in mmdet into the registries
|
# register all modules in mmdet into the registries
|
||||||
register_all_modules()
|
register_all_modules()
|
||||||
|
|
||||||
|
@ -56,6 +65,9 @@ def main():
|
||||||
# get file list
|
# get file list
|
||||||
files, source_type = get_file_list(args.img)
|
files, source_type = get_file_list(args.img)
|
||||||
|
|
||||||
|
# ready for labelme format if it is needed
|
||||||
|
to_label_format = LabelmeFormat(classes=model.dataset_meta.get('CLASSES'))
|
||||||
|
|
||||||
# start detector inference
|
# start detector inference
|
||||||
progress_bar = ProgressBar(len(files))
|
progress_bar = ProgressBar(len(files))
|
||||||
for file in files:
|
for file in files:
|
||||||
|
@ -70,8 +82,21 @@ def main():
|
||||||
filename = os.path.basename(file)
|
filename = os.path.basename(file)
|
||||||
out_file = None if args.show else os.path.join(args.out_dir, filename)
|
out_file = None if args.show else os.path.join(args.out_dir, filename)
|
||||||
|
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
# Get candidate predict info with score threshold
|
||||||
|
pred_instances = result.pred_instances[
|
||||||
|
result.pred_instances.scores > args.score_thr]
|
||||||
|
|
||||||
|
if args.to_labelme:
|
||||||
|
# save result to labelme files
|
||||||
|
out_file = out_file.replace(
|
||||||
|
os.path.splitext(out_file)[-1], '.json')
|
||||||
|
to_label_format(result, out_file, pred_instances)
|
||||||
|
continue
|
||||||
|
|
||||||
visualizer.add_datasample(
|
visualizer.add_datasample(
|
||||||
os.path.basename(out_file),
|
filename,
|
||||||
img,
|
img,
|
||||||
data_sample=result,
|
data_sample=result,
|
||||||
draw_gt=False,
|
draw_gt=False,
|
||||||
|
@ -79,12 +104,15 @@ def main():
|
||||||
wait_time=0,
|
wait_time=0,
|
||||||
out_file=out_file,
|
out_file=out_file,
|
||||||
pred_score_thr=args.score_thr)
|
pred_score_thr=args.score_thr)
|
||||||
progress_bar.update()
|
|
||||||
|
|
||||||
if not args.show:
|
if not args.show and not args.to_labelme:
|
||||||
print_log(
|
print_log(
|
||||||
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
|
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
|
||||||
|
|
||||||
|
elif args.to_labelme:
|
||||||
|
print_log('\nLabelme format label files '
|
||||||
|
f'had all been saved in {args.out_dir}')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import json
|
||||||
|
|
||||||
|
from mmdet.structures import DetDataSample
|
||||||
|
from mmengine.structures import InstanceData
|
||||||
|
|
||||||
|
|
||||||
|
class LabelmeFormat:
|
||||||
|
"""Predict results save into labelme file.
|
||||||
|
|
||||||
|
Base on https://github.com/wkentaro/labelme/blob/main/labelme/label_file.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
classes (tuple): Model classes name.
|
||||||
|
score_threshold (float): Predict score threshold.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, classes: tuple):
|
||||||
|
super().__init__()
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
def __call__(self, results: DetDataSample, output_path: str,
|
||||||
|
pred_instances: InstanceData):
|
||||||
|
"""Get image data field for labelme.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (DetDataSample): Predict info.
|
||||||
|
output_path (str): Image file path.
|
||||||
|
pred_instances (InstanceData): Candidate prediction info.
|
||||||
|
|
||||||
|
Labelme file eg.
|
||||||
|
{
|
||||||
|
"version": "5.0.5",
|
||||||
|
"flags": {},
|
||||||
|
"imagePath": "/data/cat/1.jpg",
|
||||||
|
"imageData": null,
|
||||||
|
"imageHeight": 3000,
|
||||||
|
"imageWidth": 4000,
|
||||||
|
"shapes": [
|
||||||
|
{
|
||||||
|
"label": "cat",
|
||||||
|
"points": [
|
||||||
|
[
|
||||||
|
1148.076923076923,
|
||||||
|
1188.4615384615383
|
||||||
|
],
|
||||||
|
[
|
||||||
|
2471.1538461538457,
|
||||||
|
2176.923076923077
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"group_id": null,
|
||||||
|
"shape_type": "rectangle",
|
||||||
|
"flags": {}
|
||||||
|
},
|
||||||
|
{...}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_path = results.metainfo['img_path']
|
||||||
|
|
||||||
|
json_info = {
|
||||||
|
'version': '5.0.5',
|
||||||
|
'flags': {},
|
||||||
|
'imagePath': image_path,
|
||||||
|
'imageData': None,
|
||||||
|
'imageHeight': results.ori_shape[0],
|
||||||
|
'imageWidth': results.ori_shape[1],
|
||||||
|
'shapes': []
|
||||||
|
}
|
||||||
|
|
||||||
|
for pred_info in pred_instances:
|
||||||
|
pred_bbox = pred_info.bboxes.cpu().numpy().tolist()[0]
|
||||||
|
pred_label = self.classes[pred_info.labels]
|
||||||
|
|
||||||
|
sub_dict = {
|
||||||
|
'label': pred_label,
|
||||||
|
'points': [pred_bbox[:2], pred_bbox[2:]],
|
||||||
|
'group_id': None,
|
||||||
|
'shape_type': 'rectangle',
|
||||||
|
'flags': {}
|
||||||
|
}
|
||||||
|
json_info['shapes'].append(sub_dict)
|
||||||
|
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f_json:
|
||||||
|
json.dump(json_info, f_json, ensure_ascii=False, indent=2)
|
Loading…
Reference in New Issue