GroundingDINO/demo/create_coco_dataset.py
George Pearse e93548c805
Create create_coco_dataset.py (#17)
add create_coco_dataset demo
2023-04-06 15:04:01 +08:00

84 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import typer
from groundingdino.util.inference import load_model, load_image, predict
from tqdm import tqdm
import torchvision
import torch
import fiftyone as fo
def main(
    # Directory of input images, loaded with ``fo.Dataset.from_images_dir``.
    image_directory: str = 'test_grounding_dino',
    # Caption handed to ``predict``; the default lists two class names
    # separated by a comma.
    text_prompt: str = 'bus, car',
    # Thresholds forwarded unchanged to ``predict``.
    box_threshold: float = 0.15,
    text_threshold: float = 0.10,
    # Write the dataset to ``coco_dataset`` in COCO detection format.
    export_dataset: bool = False,
    # Open the FiftyOne app on the dataset and block until it is closed.
    view_dataset: bool = False,
    # Render the detections onto the images in ``images_with_bounding_boxes``.
    export_annotated_images: bool = True,
    # Checkpoint and model config passed to ``load_model``.
    weights_path: str = "groundingdino_swint_ogc.pth",
    config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    # If set and smaller than the dataset, only this many images are processed.
    subsample: int = None,
):
    """Label a directory of images with GroundingDINO and store the results in FiftyOne.

    Every image gets a ``detections`` field holding one detection per predicted
    box. Depending on the flags, the dataset is then opened in the FiftyOne app,
    exported as a COCO detection dataset and/or drawn to disk with its boxes.
    """
    model = load_model(config_path, weights_path)

    dataset = fo.Dataset.from_images_dir(image_directory)

    # Only shrink the dataset when the requested size is actually smaller;
    # the clone turns the selected subset into a dataset of its own.
    if subsample is not None and subsample < len(dataset):
        dataset = dataset.take(subsample).clone()

    for sample in tqdm(dataset):
        # ``load_image`` returns two values; only the second one is passed
        # on to ``predict``, so the first is discarded.
        _, image = load_image(sample.filepath)

        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        detections = []
        for box, logit, phrase in zip(boxes, logits, phrases):
            # Convert from (center x, center y, width, height) to
            # (top-left x, top-left y, width, height).
            # NOTE(review): the boxes are presumably already normalised to
            # [0, 1] as FiftyOne expects -- confirm against ``predict``.
            rel_box = torchvision.ops.box_convert(box, 'cxcywh', 'xywh')

            # FiftyOne documents ``bounding_box`` as a list of floats and
            # ``confidence`` as a float, so hand over plain Python values
            # instead of torch tensors.
            detections.append(
                fo.Detection(
                    label=phrase,
                    bounding_box=rel_box.tolist(),
                    confidence=float(logit),
                )
            )

        # Store detections in a field name of your choice
        sample["detections"] = fo.Detections(detections=detections)
        sample.save()

    # loads the voxel fiftyone UI ready for viewing the dataset.
    if view_dataset:
        session = fo.launch_app(dataset)
        session.wait()

    # exports COCO dataset ready for training
    if export_dataset:
        dataset.export(
            'coco_dataset',
            dataset_type=fo.types.COCODetectionDataset,
        )

    # saves bounding boxes plotted on the input images to disk
    if export_annotated_images:
        dataset.draw_labels(
            'images_with_bounding_boxes',
            label_fields=['detections'],
        )
# Only start the command-line entry point when this file is executed as a
# script; ``typer.run`` is handed ``main`` to invoke.
if __name__ == "__main__":
    typer.run(main)