283 lines
11 KiB
Python
283 lines
11 KiB
Python
import argparse
|
|
import os
|
|
import torch
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
from typing import Tuple, List
|
|
from torchvision.ops import box_convert
|
|
import onnx
|
|
import onnxruntime as ort
|
|
|
|
from groundingdino.util.inference import load_model, annotate
|
|
import groundingdino.datasets.transforms as T
|
|
from groundingdino.util.utils import get_phrases_from_posmap
|
|
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens
|
|
|
|
def preprocess_caption(caption: str) -> str:
    """Normalize a caption for Grounding DINO.

    The caption is lower-cased and stripped of surrounding whitespace, and a
    trailing period is appended unless one is already present.
    """
    normalized = caption.lower().strip()
    return normalized if normalized.endswith(".") else f"{normalized}."
|
|
|
|
class Encoder(torch.nn.Module):
    """Export wrapper around the BERT text backbone of a Grounding DINO model.

    Besides ``bert``, it re-exposes the wrapped model's ``tokenizer`` and
    ``specical_tokens`` (spelling mirrors the source model) so that the
    export/inference helpers can tokenize captions through this wrapper.
    """

    def __init__(self, model):
        super().__init__()
        self.tokenizer = model.tokenizer
        self.bert = model.bert
        self.specical_tokens = model.specical_tokens

    def forward(self,
                input_ids: torch.Tensor,
                token_type_ids: torch.Tensor,
                text_self_attention_masks: torch.Tensor,
                position_ids: torch.Tensor):
        """Run BERT on pre-tokenized input and return ``last_hidden_state``.

        ``text_self_attention_masks`` is cast to bool and handed to BERT as its
        ``attention_mask``. The output is presumably (batch, tokens, hidden),
        e.g. (bs, 195, 768) per an earlier note; confirm against the backbone.
        """
        encoder_inputs = dict(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=text_self_attention_masks.type(torch.bool),
            position_ids=position_ids,
        )
        return self.bert(**encoder_inputs)["last_hidden_state"]
|
|
|
|
class Decoder(torch.nn.Module):
    """Export wrapper around the Grounding DINO detection path with thresholding.

    ``model`` is called with precomputed text features (presumably a variant
    of GroundingDINO whose forward accepts ``last_hidden_state`` directly;
    confirm). Box/text thresholding is performed inside the graph so that the
    exported ONNX model returns only the kept detections.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.tokenizer = model.tokenizer
        self.specical_tokens = model.specical_tokens

    def forward(self,
                image: torch.Tensor,
                last_hidden_state: torch.Tensor,
                attention_mask: torch.Tensor,
                position_ids: torch.Tensor,
                text_self_attention_masks: torch.Tensor,
                box_threshold: torch.Tensor,
                text_threshold: torch.Tensor):
        """Run the model and keep queries whose best token score beats ``box_threshold``.

        Assumes batch size 1 (the batch axis is squeezed away and re-added).

        Returns:
            A tuple ``(scores, boxes, token_mask)``, each with a leading batch
            axis of size 1:
            - ``scores``: best sigmoid token score of each kept query;
            - ``boxes``: ``pred_boxes`` rows of the kept queries;
            - ``token_mask``: per kept query, which token scores exceed
              ``text_threshold``.
        """
        raw = self.model(image,
                         last_hidden_state,
                         attention_mask,
                         position_ids,
                         text_self_attention_masks.type(torch.bool))
        scores = raw["pred_logits"].sigmoid().squeeze(0)
        boxes = raw["pred_boxes"].squeeze(0)

        keep = scores.max(dim=1).values > box_threshold
        kept_scores = scores[keep]
        kept_boxes = boxes[keep]
        token_hits = kept_scores > text_threshold

        best_scores = kept_scores.max(dim=1).values
        return best_scores.unsqueeze(0), kept_boxes.unsqueeze(0), token_hits.unsqueeze(0)
|
|
|
|
def export_encoder(model, output):
    """Export the text-encoder wrapper to ``gdino.encoder.onnx`` and validate it.

    A fixed sample caption ("watermark") is tokenized to build the tracing
    inputs. The token dimension of every input and of the output is declared
    dynamic, so the exported graph is not tied to the sample caption's length.

    Args:
        model: ``Encoder`` wrapper exposing ``tokenizer`` and ``specical_tokens``.
        output: Directory the ONNX file is written into.
    """
    onnx_file = os.path.join(output, "gdino.encoder.onnx")
    caption = preprocess_caption("watermark")
    tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")

    # Self-attention masks and position ids derived from the special tokens
    # by the Grounding DINO helper.
    (
        text_self_attention_masks,
        position_ids
    ) = generate_masks_with_special_tokens(tokenized, model.specical_tokens, model.tokenizer)

    torch.onnx.export(
        model,
        args=(
            tokenized["input_ids"].type(torch.int).to("cpu"),
            tokenized["token_type_ids"].type(torch.int).to("cpu"),
            text_self_attention_masks.type(torch.uint8).to("cpu"),
            position_ids.type(torch.int).to("cpu"),
        ),
        f=onnx_file,
        input_names=["input_ids", "token_type_ids", "text_self_attention_masks", "position_ids"],
        output_names=["last_hidden_state"],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
        dynamic_axes={
            "input_ids": {1: "token_num"},
            "token_type_ids": {1: "token_num"},
            "text_self_attention_masks": {1: "token_num", 2: "token_num"},
            "position_ids": {1: "token_num"},
            "last_hidden_state": {1: "token_num"},
        },
    )

    print("export gdino.encoder.onnx ok!")

    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)

    print("check gdino.encoder.onnx ok!")
|
|
|
|
def export_decoder(model, output, encoder):
    """Export the detection wrapper to ``gdino.decoder.onnx`` and validate it.

    Text features for tracing are produced by the already-exported encoder
    (``gdino.encoder.onnx`` in ``output``), so ``export_encoder`` must run
    first. The image is traced at a fixed 1x3x800x800. Token-count dims are
    dynamic on inputs and outputs, and so is the number of kept detections,
    which depends on the thresholds and the data.

    Args:
        model: ``Decoder`` wrapper to export.
        output: Directory holding ``gdino.encoder.onnx``; the decoder file is
            written here too.
        encoder: Wrapper providing ``tokenizer``/``specical_tokens`` for the
            sample caption.
    """
    onnx_file = os.path.join(output, "gdino.decoder.onnx")
    caption = preprocess_caption("watermark")

    tokenized, last_hidden_state = inference_encoder_onnx(encoder, output, caption)

    # Thresholds are graph inputs (0-d tensors), so they can be changed at
    # inference time without re-exporting.
    box_threshold = torch.tensor(0.35, dtype=torch.float32)
    text_threshold = torch.tensor(0.25, dtype=torch.float32)

    torch.onnx.export(
        model,
        args=(
            torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
            last_hidden_state,
            tokenized["attention_mask"].type(torch.uint8).to("cpu"),
            tokenized["position_ids"].type(torch.int).to("cpu"),
            tokenized["text_self_attention_masks"].type(torch.uint8).to("cpu"),
            box_threshold,
            text_threshold,
        ),
        f=onnx_file,
        input_names=["image", "last_hidden_state", "attention_mask",
                     "position_ids", "text_self_attention_masks",
                     "box_threshold", "text_threshold"],
        output_names=["logits", "boxes", "masks"],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
        dynamic_axes={
            "last_hidden_state": {1: "token_num"},
            "attention_mask": {1: "token_num"},
            "position_ids": {1: "token_num"},
            "text_self_attention_masks": {1: "token_num", 2: "token_num"},
            # The number of boxes surviving box_threshold varies per input.
            "logits": {1: "box_num"},
            "boxes": {1: "box_num"},
            "masks": {1: "box_num", 2: "token_num"},
        },
    )

    print("export gdino.decoder.onnx ok!")

    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)

    print("check gdino.decoder.onnx ok!")
|
|
|
|
def inference_encoder_onnx(model, output, caption: "str | None" = None):
    """Tokenize a caption and run it through the exported ``gdino.encoder.onnx``.

    Args:
        model: Wrapper exposing ``tokenizer`` and ``specical_tokens``
            (``Encoder`` or ``Decoder``).
        output: Directory containing ``gdino.encoder.onnx``.
        caption: Caption to encode. Any falsy value falls back to
            ``"watermark. cat. dog"``. ``None`` additionally prints the raw
            ONNX outputs (used by the ``--encode`` smoke test).

    Returns:
        ``(tokenized, last_hidden_state)``: the tokenizer output extended with
        ``text_self_attention_masks`` and ``position_ids`` entries, and the
        encoder output as a float32 torch tensor.
    """
    onnx_file = os.path.join(output, "gdino.encoder.onnx")
    # Since ORT 1.9, builds with more than one execution provider (e.g.
    # onnxruntime-gpu) raise ValueError unless providers are set explicitly.
    # The whole script runs on CPU, so pin the CPU provider.
    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])

    proc_caption = preprocess_caption(caption if caption else "watermark. cat. dog")
    tokenized = model.tokenizer(proc_caption, padding="longest", return_tensors="pt")

    (
        text_self_attention_masks,
        position_ids
    ) = generate_masks_with_special_tokens(tokenized, model.specical_tokens, model.tokenizer)

    # Stored on ``tokenized`` so callers (e.g. the decoder) can reuse them.
    tokenized["text_self_attention_masks"] = text_self_attention_masks
    tokenized["position_ids"] = position_ids

    # dtypes match the tensors used when the encoder was exported.
    outputs = session.run(None, {
        "input_ids": tokenized["input_ids"].numpy().astype(np.int32),
        "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32),
        "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8),
        "position_ids": tokenized["position_ids"].numpy().astype(np.int32),
    })

    if caption is None:
        print(outputs)

    last_hidden_state = torch.from_numpy(outputs[0]).type(torch.float32)
    return tokenized, last_hidden_state
|
|
|
|
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
    """Convert an OpenCV BGR image into a normalized RGB tensor for the decoder.

    The image is first resized to a fixed 800x800, presumably to match the
    static 1x3x800x800 image shape traced when the decoder is exported
    (confirm before changing). It then goes through the Grounding DINO
    transforms: resize (presumably a no-op at 800x800), to-tensor, and
    ImageNet mean/std normalization.
    """
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    resized_bgr = cv2.resize(image_bgr, (800, 800))
    rgb = cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB)
    tensor, _ = pipeline(Image.fromarray(rgb), None)
    return tensor
|
|
|
|
def inference_decoder_onnx(model, output,
                           image_path: str = "asset/1.jpg",
                           caption: str = "watermark. glasses",
                           box_threshold: float = 0.35,
                           text_threshold: float = 0.25):
    """Run the exported encoder and decoder ONNX models on an image and show the result.

    Args:
        model: Wrapper exposing ``tokenizer`` and ``specical_tokens``; the
            tokenizer is also used to decode detected phrases.
        output: Directory containing ``gdino.encoder.onnx`` and ``gdino.decoder.onnx``.
        image_path: Image file to run detection on.
        caption: Text prompt. It is passed to ``inference_encoder_onnx``, which
            normalizes it.
        box_threshold: Minimum best-token score for a box to be kept.
        text_threshold: Minimum token score for a token to join a box's phrase.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    processed_image = preprocess_image(image).unsqueeze(0)

    tokenized, last_hidden_state = inference_encoder_onnx(model, output, caption)

    print(tokenized)
    print(last_hidden_state)

    onnx_file = os.path.join(output, "gdino.decoder.onnx")
    # Explicit provider: multi-EP ORT builds reject sessions without one.
    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])

    decode_outputs = session.run(None, {
        "image": processed_image.numpy().astype(np.float32),
        "last_hidden_state": last_hidden_state.numpy().astype(np.float32),
        "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8),
        "position_ids": tokenized["position_ids"].numpy().astype(np.int32),
        "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8),
        # 0-d float32 arrays, matching the scalar thresholds used at export time.
        "box_threshold": np.asarray(box_threshold, dtype=np.float32),
        "text_threshold": np.asarray(text_threshold, dtype=np.float32),
    })

    prediction_logits = torch.from_numpy(decode_outputs[0])
    prediction_boxes = torch.from_numpy(decode_outputs[1])
    prediction_masks = torch.from_numpy(decode_outputs[2])

    # For each kept box, decode the tokens whose mask entry is set into a
    # phrase, dropping the '.' separators.
    input_ids = tokenized["input_ids"][0].tolist()
    phrases = []
    for mask in prediction_masks[0]:
        prediction_token_ids = [input_ids[i] for i in mask.nonzero(as_tuple=True)[0].tolist()]
        phrases.append(model.tokenizer.decode(prediction_token_ids).replace('.', ''))

    # annotate() receives the frame in RGB order.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    annotated = annotate(image_rgb, prediction_boxes[0], prediction_logits[0], phrases)

    cv2.imshow("image", annotated)
    cv2.waitKey()
    cv2.destroyAllWindows()
|
|
|
|
if __name__ == "__main__":
    # Without --encode/--decode both models are exported. --encode or --decode
    # instead smoke-tests an already exported model (--encode wins if both are set).
    parser = argparse.ArgumentParser("Export Grounding DINO Model to ONNX", add_help=True)
    parser.add_argument("--encode", "-e", help="test encoder.onnx model", action="store_true")
    parser.add_argument("--decode", "-d", help="test decoder.onnx model", action="store_true")
    parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
    )
    # Optional: falls back to "outputs" when omitted. Combining required=True
    # with a default would make the default unreachable.
    parser.add_argument(
        "--output_dir", "-o", type=str, default="outputs", help="output directory"
    )

    args = parser.parse_args()

    config_file = args.config_file  # path of the model config file
    checkpoint_path = args.checkpoint_path  # path of the model checkpoint
    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    source_model = load_model(model_config_path=config_file,
                              model_checkpoint_path=checkpoint_path,
                              device="cpu").to("cpu")

    encoder = Encoder(source_model)
    decoder = Decoder(source_model)

    if args.encode:
        inference_encoder_onnx(encoder, output_dir)
    elif args.decode:
        inference_decoder_onnx(decoder, output_dir)
    else:
        export_encoder(encoder, output_dir)
        export_decoder(decoder, output_dir, encoder)
|