mirror of https://github.com/open-mmlab/mmyolo.git
[Improvement] Add `--class-name` for filter when save to labelme label. (#314)
* Add `--class-name` for filter when save to labelme label. * Improve code * Add check class * Improve coding * Improve codingpull/317/head
parent
9221499af4
commit
e3f1cf93a6
|
@ -10,7 +10,7 @@ from mmengine.utils import ProgressBar
|
|||
from mmyolo.registry import VISUALIZERS
|
||||
from mmyolo.utils import register_all_modules, switch_to_deploy
|
||||
from mmyolo.utils.labelme_utils import LabelmeFormat
|
||||
from mmyolo.utils.misc import get_file_list
|
||||
from mmyolo.utils.misc import get_file_list, show_data_classes
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -31,6 +31,11 @@ def parse_args():
|
|||
help='Switch model to deployment mode')
|
||||
parser.add_argument(
|
||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
||||
parser.add_argument(
|
||||
'--class-name',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='Only Save those classes if set')
|
||||
parser.add_argument(
|
||||
'--to-labelme',
|
||||
action='store_true',
|
||||
|
@ -65,8 +70,21 @@ def main():
|
|||
# get file list
|
||||
files, source_type = get_file_list(args.img)
|
||||
|
||||
# get model class name
|
||||
dataset_classes = model.dataset_meta.get('CLASSES')
|
||||
|
||||
# ready for labelme format if it is needed
|
||||
to_label_format = LabelmeFormat(classes=model.dataset_meta.get('CLASSES'))
|
||||
to_label_format = LabelmeFormat(classes=dataset_classes)
|
||||
|
||||
# check class name
|
||||
if args.class_name is not None:
|
||||
for class_name in args.class_name:
|
||||
if class_name in dataset_classes:
|
||||
continue
|
||||
show_data_classes(dataset_classes)
|
||||
raise RuntimeError(
|
||||
'Expected args.class_name to be one of the list, '
|
||||
f'but got "{class_name}"')
|
||||
|
||||
# start detector inference
|
||||
progress_bar = ProgressBar(len(files))
|
||||
|
@ -92,7 +110,8 @@ def main():
|
|||
# save result to labelme files
|
||||
out_file = out_file.replace(
|
||||
os.path.splitext(out_file)[-1], '.json')
|
||||
to_label_format(result, out_file, pred_instances)
|
||||
to_label_format(pred_instances, result.metainfo, out_file,
|
||||
args.class_name)
|
||||
continue
|
||||
|
||||
visualizer.add_datasample(
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
|
||||
from mmdet.structures import DetDataSample
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
|
||||
|
@ -12,21 +11,21 @@ class LabelmeFormat:
|
|||
|
||||
Args:
|
||||
classes (tuple): Model classes name.
|
||||
score_threshold (float): Predict score threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, classes: tuple):
|
||||
super().__init__()
|
||||
self.classes = classes
|
||||
|
||||
def __call__(self, results: DetDataSample, output_path: str,
|
||||
pred_instances: InstanceData):
|
||||
def __call__(self, pred_instances: InstanceData, metainfo: dict,
|
||||
output_path: str, selected_classes: list):
|
||||
"""Get image data field for labelme.
|
||||
|
||||
Args:
|
||||
results (DetDataSample): Predict info.
|
||||
output_path (str): Image file path.
|
||||
pred_instances (InstanceData): Candidate prediction info.
|
||||
metainfo (dict): Meta info of prediction.
|
||||
output_path (str): Image file path.
|
||||
selected_classes (list): Selected class name.
|
||||
|
||||
Labelme file eg.
|
||||
{
|
||||
|
@ -58,21 +57,26 @@ class LabelmeFormat:
|
|||
}
|
||||
"""
|
||||
|
||||
image_path = results.metainfo['img_path']
|
||||
image_path = metainfo['img_path']
|
||||
|
||||
json_info = {
|
||||
'version': '5.0.5',
|
||||
'flags': {},
|
||||
'imagePath': image_path,
|
||||
'imageData': None,
|
||||
'imageHeight': results.ori_shape[0],
|
||||
'imageWidth': results.ori_shape[1],
|
||||
'imageHeight': metainfo['ori_shape'][0],
|
||||
'imageWidth': metainfo['ori_shape'][1],
|
||||
'shapes': []
|
||||
}
|
||||
|
||||
for pred_info in pred_instances:
|
||||
pred_bbox = pred_info.bboxes.cpu().numpy().tolist()[0]
|
||||
pred_label = self.classes[pred_info.labels]
|
||||
for pred_instance in pred_instances:
|
||||
pred_bbox = pred_instance.bboxes.cpu().numpy().tolist()[0]
|
||||
pred_label = self.classes[pred_instance.labels]
|
||||
|
||||
if selected_classes is not None and \
|
||||
pred_label not in selected_classes:
|
||||
# filter class name
|
||||
continue
|
||||
|
||||
sub_dict = {
|
||||
'label': pred_label,
|
||||
|
|
|
@ -5,6 +5,7 @@ import urllib
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmengine.utils import scandir
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from mmyolo.models import RepVGGBlock
|
||||
|
||||
|
@ -90,3 +91,26 @@ def get_file_list(source_root: str) -> [list, dict]:
|
|||
source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)
|
||||
|
||||
return source_file_path_list, source_type
|
||||
|
||||
|
||||
def show_data_classes(data_classes):
|
||||
"""When printing an error, all class names of the dataset."""
|
||||
print('\n\nThe name of the class contained in the dataset:')
|
||||
data_classes_info = PrettyTable()
|
||||
data_classes_info.title = 'Information of dataset class'
|
||||
# List Print Settings
|
||||
# If the quantity is too large, 25 rows will be displayed in each column
|
||||
if len(data_classes) < 25:
|
||||
data_classes_info.add_column('Class name', data_classes)
|
||||
elif len(data_classes) % 25 != 0 and len(data_classes) > 25:
|
||||
col_num = int(len(data_classes) / 25) + 1
|
||||
data_name_list = list(data_classes)
|
||||
for i in range(0, (col_num * 25) - len(data_classes)):
|
||||
data_name_list.append('')
|
||||
for i in range(0, len(data_name_list), 25):
|
||||
data_classes_info.add_column('Class name',
|
||||
data_name_list[i:i + 25])
|
||||
|
||||
# Align display data to the left
|
||||
data_classes_info.align['Class name'] = 'l'
|
||||
print(data_classes_info)
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
numpy
|
||||
prettytable
|
||||
|
|
|
@ -12,6 +12,7 @@ from prettytable import PrettyTable
|
|||
|
||||
from mmyolo.registry import DATASETS
|
||||
from mmyolo.utils import register_all_modules
|
||||
from mmyolo.utils.misc import show_data_classes
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -347,29 +348,6 @@ def show_data_list(args, area_rule):
|
|||
print(data_info)
|
||||
|
||||
|
||||
def show_data_classes(data_classes):
|
||||
"""When printing an error, all class names of the dataset."""
|
||||
print('\n\nThe name of the class contained in the dataset:')
|
||||
data_classes_info = PrettyTable()
|
||||
data_classes_info.title = 'Information of dataset class'
|
||||
# List Print Settings
|
||||
# If the quantity is too large, 25 rows will be displayed in each column
|
||||
if len(data_classes) < 25:
|
||||
data_classes_info.add_column('Class name', data_classes)
|
||||
elif len(data_classes) % 25 != 0 and len(data_classes) > 25:
|
||||
col_num = int(len(data_classes) / 25) + 1
|
||||
data_name_list = list(data_classes)
|
||||
for i in range(0, (col_num * 25) - len(data_classes)):
|
||||
data_name_list.append('')
|
||||
for i in range(0, len(data_name_list), 25):
|
||||
data_classes_info.add_column('Class name',
|
||||
data_name_list[i:i + 25])
|
||||
|
||||
# Align display data to the left
|
||||
data_classes_info.align['Class name'] = 'l'
|
||||
print(data_classes_info)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
|
Loading…
Reference in New Issue