GroundingDINO/export.py

283 lines
11 KiB
Python

import argparse
import os
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Tuple, List
from torchvision.ops import box_convert
import onnx
import onnxruntime as ort
from groundingdino.util.inference import load_model, annotate
import groundingdino.datasets.transforms as T
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens
def preprocess_caption(caption: str) -> str:
    """Normalise a text prompt before it is tokenized.

    The prompt is lower-cased and stripped of surrounding whitespace. A single
    trailing "." is appended unless the prompt already ends with one.

    Args:
        caption: Raw prompt, e.g. ``"Watermark. Cat"``.

    Returns:
        The normalised prompt, e.g. ``"watermark. cat."``.
    """
    normalized = caption.lower().strip()
    if not normalized.endswith("."):
        normalized += "."
    return normalized
class Encoder(torch.nn.Module):
    """Text-encoder half of Grounding DINO, wrapped as a module for ONNX export.

    Keeps references (not copies) to the source model's ``tokenizer``,
    ``bert`` and ``specical_tokens``. The attribute name is spelled exactly as
    on the source model. ``forward`` runs only the BERT text backbone.
    """

    def __init__(self, model):
        super().__init__()
        self.tokenizer = model.tokenizer
        self.bert = model.bert
        self.specical_tokens = model.specical_tokens

    def forward(self,
                input_ids: torch.Tensor,
                token_type_ids: torch.Tensor,
                text_self_attention_masks: torch.Tensor,
                position_ids: torch.Tensor) -> torch.Tensor:
        """Return BERT's ``last_hidden_state`` for one tokenized prompt.

        ``text_self_attention_masks`` may be any integer/bool mask. It is cast
        to ``bool`` here and handed to BERT as its ``attention_mask``.
        NOTE(review): the output is presumably (batch, token_num, hidden);
        an earlier comment said "bs, 195, 768" — TODO confirm.
        """
        hidden = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=text_self_attention_masks.to(torch.bool),
            position_ids=position_ids,
        )
        return hidden["last_hidden_state"]
class Decoder(torch.nn.Module):
    """Detection half of Grounding DINO with threshold filtering, for ONNX export.

    Wraps the full source model. ``forward`` calls it with five positional
    arguments (image, text features, masks, position ids). NOTE(review): this
    needs a model whose forward accepts that signature — presumably a patched
    GroundingDINO; TODO confirm.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.tokenizer = model.tokenizer
        self.specical_tokens = model.specical_tokens

    def forward(self,
                image: torch.Tensor,
                last_hidden_state: torch.Tensor,
                attention_mask: torch.Tensor,
                position_ids: torch.Tensor,
                text_self_attention_masks: torch.Tensor,
                box_threshold: torch.Tensor,
                text_threshold: torch.Tensor):
        """Run detection and keep only confident boxes.

        Batch size 1 is assumed: the batch axis is squeezed away before
        filtering and re-added on every output.

        Returns:
            A 3-tuple, each element with a leading batch axis of 1:
            - best sigmoid score per kept query, shape (1, K);
            - the kept ``pred_boxes`` rows, shape (1, K, box_dim);
            - bool mask of tokens whose score exceeds ``text_threshold``,
              shape (1, K, num_tokens).
            A query is kept when its best token score exceeds ``box_threshold``.
        """
        raw = self.model(image,
                         last_hidden_state,
                         attention_mask,
                         position_ids,
                         text_self_attention_masks.to(torch.bool))
        scores = torch.sigmoid(raw["pred_logits"]).squeeze(0)
        boxes = raw["pred_boxes"].squeeze(0)

        keep = scores.max(dim=1).values > box_threshold
        kept_scores = scores[keep]
        kept_boxes = boxes[keep]
        token_mask = kept_scores > text_threshold

        best_kept = kept_scores.max(dim=1).values
        return best_kept.unsqueeze(0), kept_boxes.unsqueeze(0), token_mask.unsqueeze(0)
def export_encoder(model, output):
    """Export the text encoder to ``<output>/gdino.encoder.onnx`` and validate it.

    The dummy caption "watermark." is tokenized to trace the graph. The token
    axis is declared dynamic so the exported model accepts variable-length
    prompts: dim 1 of every input and of the output, plus dim 2 of the square
    self-attention mask.

    Args:
        model: ``Encoder`` wrapper. It is the module being traced, and its
            ``tokenizer`` / ``specical_tokens`` build the dummy inputs.
        output: Existing directory the ONNX file is written into.
    """
    onnx_file = os.path.join(output, "gdino.encoder.onnx")
    caption = preprocess_caption("watermark")
    tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
    text_self_attention_masks, position_ids = generate_masks_with_special_tokens(
        tokenized, model.specical_tokens, model.tokenizer
    )
    torch.onnx.export(
        model,
        args=(
            # Integer inputs are exported as int32 and the mask as uint8.
            # The inference helpers feed these exact dtypes.
            tokenized["input_ids"].type(torch.int).to("cpu"),
            tokenized["token_type_ids"].type(torch.int).to("cpu"),
            text_self_attention_masks.type(torch.uint8).to("cpu"),
            position_ids.type(torch.int).to("cpu"),
        ),
        f=onnx_file,
        input_names=["input_ids", "token_type_ids", "text_self_attention_masks", "position_ids"],
        output_names=["last_hidden_state"],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
        dynamic_axes={
            "input_ids": {1: "token_num"},
            "token_type_ids": {1: "token_num"},
            "text_self_attention_masks": {1: "token_num", 2: "token_num"},
            "position_ids": {1: "token_num"},
            "last_hidden_state": {1: "token_num"},
        },
    )
    print("export gdino.encoder.onnx ok!")
    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print("check gdino.encoder.onnx ok!")
def export_decoder(model, output, encoder):
    """Export the detection decoder to ``<output>/gdino.decoder.onnx`` and validate it.

    The text-side tracing inputs come from running the already exported
    encoder (``gdino.encoder.onnx`` in the same directory) on the dummy caption
    "watermark.", so ``export_encoder`` must have run first. The image input is
    a random ``(1, 3, 800, 800)`` tensor. Both thresholds become graph inputs
    (0-d float32 tensors), so they can be changed at inference time.

    Args:
        model: ``Decoder`` wrapper being traced.
        output: Directory that holds the encoder and receives the decoder.
        encoder: ``Encoder`` wrapper; only its tokenizer and special tokens are used.
    """
    onnx_file = os.path.join(output, "gdino.decoder.onnx")
    caption = preprocess_caption("watermark")
    tokenized, last_hidden_state = inference_encoder_onnx(encoder, output, caption)
    box_threshold = torch.tensor(0.35, dtype=torch.float32)
    text_threshold = torch.tensor(0.25, dtype=torch.float32)
    torch.onnx.export(
        model,
        args=(
            torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
            last_hidden_state,
            tokenized["attention_mask"].type(torch.uint8).to("cpu"),
            tokenized["position_ids"].type(torch.int).to("cpu"),
            tokenized["text_self_attention_masks"].type(torch.uint8).to("cpu"),
            box_threshold,
            text_threshold,
        ),
        f=onnx_file,
        input_names=["image", "last_hidden_state", "attention_mask",
                     "position_ids", "text_self_attention_masks",
                     "box_threshold", "text_threshold"],
        output_names=["logits", "boxes", "masks"],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
        dynamic_axes={
            "last_hidden_state": {1: "token_num"},
            "attention_mask": {1: "token_num"},
            "position_ids": {1: "token_num"},
            "text_self_attention_masks": {1: "token_num", 2: "token_num"},
            # The number of kept boxes depends on the image and on the
            # threshold inputs, and "masks" also spans the token axis. Without
            # these entries torch.onnx fixes the output shapes to whatever the
            # random-image trace happened to produce.
            "logits": {1: "box_num"},
            "boxes": {1: "box_num"},
            "masks": {1: "box_num", 2: "token_num"},
        },
    )
    print("export gdino.decoder.onnx ok!")
    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print("check gdino.decoder.onnx ok!")
def inference_encoder_onnx(model, output, caption: str = None):
    """Run the exported text encoder (``<output>/gdino.encoder.onnx``) on a caption.

    Args:
        model: ``Encoder`` or ``Decoder`` wrapper. Only its ``tokenizer`` and
            ``specical_tokens`` are used.
        output: Directory containing ``gdino.encoder.onnx``.
        caption: Prompt to encode. If falsy, the built-in test prompt
            "watermark. cat. dog" is used instead. If ``None``, the raw session
            outputs are also printed (the ``--encode`` self-test path).

    Returns:
        ``(tokenized, last_hidden_state)``. ``tokenized`` is the tokenizer
        output extended with ``text_self_attention_masks`` and
        ``position_ids``. ``last_hidden_state`` is the encoder output as a
        float32 tensor.
    """
    onnx_file = os.path.join(output, "gdino.encoder.onnx")
    # Since onnxruntime 1.9, builds with more than the CPU provider refuse to
    # create a session unless `providers` is given. This script runs on CPU.
    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
    proc_caption = preprocess_caption(caption if caption else "watermark. cat. dog")
    tokenized = model.tokenizer(proc_caption, padding="longest", return_tensors="pt")
    text_self_attention_masks, position_ids = generate_masks_with_special_tokens(
        tokenized, model.specical_tokens, model.tokenizer
    )
    tokenized["text_self_attention_masks"] = text_self_attention_masks
    tokenized["position_ids"] = position_ids
    # dtypes must match those used when the encoder was exported.
    outputs = session.run(None, {
        "input_ids": tokenized["input_ids"].numpy().astype(np.int32),
        "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32),
        "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8),
        "position_ids": tokenized["position_ids"].numpy().astype(np.int32),
    })
    if caption is None:
        print(outputs)
    last_hidden_state = torch.from_numpy(outputs[0]).type(torch.float32)
    return tokenized, last_hidden_state
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
    """Convert an OpenCV BGR image into the normalised tensor fed to the decoder.

    Steps:
    - Resize to 800x800. This matches the ``(1, 3, 800, 800)`` dummy image
      used by ``export_decoder``.
    - Convert BGR -> RGB and wrap the result as a PIL image.
    - Run the Grounding DINO transform pipeline: ``RandomResize([800],
      max_size=1333)`` (with a single size this is presumably deterministic —
      TODO confirm), ``ToTensor``, then per-channel mean/std normalisation.

    Returns:
        A CHW tensor without a batch axis; callers ``unsqueeze(0)`` it.
    """
    resized_bgr = cv2.resize(image_bgr, (800, 800))
    rgb_pil = Image.fromarray(cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB))
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor, _ = pipeline(rgb_pil, None)
    return tensor
def inference_decoder_onnx(model, output, image_path="asset/1.jpg",
                           caption="watermark. glasses",
                           box_threshold=0.35, text_threshold=0.25):
    """End-to-end smoke test of the exported encoder + decoder on one image.

    Encodes ``caption`` with ``gdino.encoder.onnx`` and runs
    ``gdino.decoder.onnx`` on the image. Each kept box's tokens are decoded
    into a phrase, and the annotated result is shown in an OpenCV window
    (blocks until a key is pressed).

    Args:
        model: ``Decoder`` (or ``Encoder``) wrapper. Its ``tokenizer`` is used
            both for encoding and for decoding phrases.
        output: Directory holding both ONNX files.
        image_path: Image to run on (read with ``cv2.imread``).
        caption: Text prompt.
        box_threshold: Minimum best-token score for a box to be kept.
        text_threshold: Minimum score for a token to contribute to a phrase.

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    processed_image = preprocess_image(image).unsqueeze(0)
    tokenized, last_hidden_state = inference_encoder_onnx(model, output, caption)
    print(tokenized)
    print(last_hidden_state)

    onnx_file = os.path.join(output, "gdino.decoder.onnx")
    # Explicit providers are required on onnxruntime >= 1.9 builds that
    # ship more than the CPU provider.
    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
    decode_outputs = session.run(None, {
        "image": processed_image.numpy().astype(np.float32),
        "last_hidden_state": last_hidden_state.numpy().astype(np.float32),
        "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8),
        "position_ids": tokenized["position_ids"].numpy().astype(np.int32),
        "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8),
        # Thresholds are 0-d float32 graph inputs, as exported.
        "box_threshold": np.asarray(box_threshold, dtype=np.float32),
        "text_threshold": np.asarray(text_threshold, dtype=np.float32),
    })
    prediction_logits = torch.from_numpy(decode_outputs[0])
    prediction_boxes = torch.from_numpy(decode_outputs[1])
    prediction_masks = torch.from_numpy(decode_outputs[2])

    # Each mask row selects the prompt tokens that fired for one kept box.
    input_ids = tokenized["input_ids"][0].tolist()
    phrases = []
    for mask in prediction_masks[0]:
        prediction_token_ids = [input_ids[i] for i in mask.nonzero(as_tuple=True)[0].tolist()]
        phrases.append(model.tokenizer.decode(prediction_token_ids).replace('.', ''))

    # NOTE(review): `annotate` is given an RGB image and its result is shown
    # with cv2.imshow, so it presumably returns BGR — TODO confirm.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = annotate(image, prediction_boxes[0], prediction_logits[0], phrases)
    cv2.imshow("image", image)
    cv2.waitKey()
    cv2.destroyAllWindows()
if __name__ == "__main__":
    # Modes:
    #   (no flag) export encoder then decoder;
    #   -e        run the exported encoder;
    #   -d        run the full exported pipeline on a sample image.
    # If both -e and -d are given, only -e runs.
    # The first positional ArgumentParser argument is `prog`, so the
    # human-readable text is passed as `description` instead.
    parser = argparse.ArgumentParser(
        description="Export Grounding DINO Model to ONNX", add_help=True
    )
    parser.add_argument("--encode", "-e", help="test encoder.onnx model", action="store_true")
    parser.add_argument("--decode", "-d", help="test decoder.onnx model", action="store_true")
    parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
    )
    # Not required: with required=True the "outputs" default could never apply.
    parser.add_argument(
        "--output_dir", "-o", type=str, default="outputs", help="output directory"
    )
    args = parser.parse_args()

    config_file = args.config_file
    checkpoint_path = args.checkpoint_path
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    source_model = load_model(model_config_path=config_file,
                              model_checkpoint_path=checkpoint_path,
                              device="cpu").to("cpu")
    encoder = Encoder(source_model)
    decoder = Decoder(source_model)
    if args.encode:
        inference_encoder_onnx(encoder, output_dir)
    elif args.decode:
        inference_decoder_onnx(decoder, output_dir)
    else:
        export_encoder(encoder, output_dir)
        export_decoder(decoder, output_dir, encoder)