parent 17f37607dc
commit 89b1ad20b9

export.py
@@ -12,6 +12,7 @@ import onnxruntime as ort
 from groundingdino.util.inference import load_model, annotate
 import groundingdino.datasets.transforms as T
 from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
 
 class Model(torch.nn.Module):
     def __init__(
@@ -27,18 +28,21 @@ class Model(torch.nn.Module):
             device=device
         ).to(device)
         self.tokenizer = self.model.tokenizer
+        self.specical_tokens = self.model.specical_tokens
+        self.max_text_len = self.model.max_text_len
 
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 image: torch.Tensor,
                 input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor,
+                position_ids: torch.Tensor,
+                text_self_attention_masks: torch.Tensor,
                 box_threshold: torch.Tensor,
                 text_threshold: torch.Tensor,
                 **kw):
-        token_type_ids = torch.zeros(input_ids.size(), dtype=torch.int32)
-        attention_mask = (input_ids != 0).int()
-
-        outputs = self.model(image, input_ids, attention_mask, token_type_ids)
+        outputs = self.model(image, input_ids, attention_mask, token_type_ids, position_ids, text_self_attention_masks)
         prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
         prediction_boxes = outputs["pred_boxes"].squeeze(0)
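A note on the widened forward() signature, reading between the lines of the hunk above rather than anything stated in the commit: tensors that forward() builds for itself from input_ids.size(), as the deleted lines did, may be captured with the example input's concrete shape when torch.onnx.export traces the module, so they would not follow a dynamic token axis at runtime. Passing attention_mask, token_type_ids, position_ids and text_self_attention_masks in as graph inputs sidesteps that. A minimal sketch of the pitfall (illustrative module, not from this repo):

    import torch

    class Pitfall(torch.nn.Module):
        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Under tracing, input_ids.size() is read back as plain Python ints,
            # so this zeros() can be baked into the graph with the example shape
            # and silently ignore longer or shorter prompts after export.
            return torch.zeros(input_ids.size(), dtype=torch.int32)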
@@ -70,26 +74,43 @@ def preprocess_caption(caption: str) -> str:
 
 def export_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
-    caption = preprocess_caption(".")
+    caption = preprocess_caption("watermark")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
     torch.onnx.export(
         model,
         args = (
             torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
-            tokenized["input_ids"],
+            tokenized["input_ids"].type(torch.int).to("cpu"),
+            tokenized["attention_mask"].type(torch.uint8).to("cpu"),
+            tokenized["token_type_ids"].type(torch.int).to("cpu"),
+            position_ids.type(torch.int).to("cpu"),
+            text_self_attention_masks.type(torch.bool).to("cpu"),
             box_threshold,
             text_threshold),
         f = onnx_file,
-        input_names = [ "image", "input_ids", "box_threshold", "text_threshold" ],
+        input_names = [ "image", "input_ids", "attention_mask", "token_type_ids", "position_ids", "text_self_attention_masks", "box_threshold", "text_threshold" ],
         output_names = [ "logits", "boxes", "masks" ],
         opset_version = 17,
         export_params = True,
         do_constant_folding = True,
         dynamic_axes = {
-            "input_ids": { 1: "token_num" }
+            "input_ids": { 1: "token_num" },
+            "attention_mask": { 1: "token_num" },
+            "token_type_ids": { 1: "token_num" },
+            "position_ids": { 1: "token_num" },
+            "text_self_attention_masks": { 1: "token_num", 2: "token_num" }
         },
     )
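To confirm the dynamic token axis survived the export, listing the graph inputs is a quick check; a sketch assuming the gdino.onnx written above landed in an "output" directory (the actual output_dir is not shown in this hunk):

    import onnxruntime as ort

    session = ort.InferenceSession("output/gdino.onnx")  # path is an assumption
    for inp in session.get_inputs():
        # Dynamic dims print as the symbolic name "token_num" instead of an int.
        print(inp.name, inp.shape, inp.type)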
@@ -100,15 +121,28 @@ def export_onnx(model, output_dir):
     print("check model ok!")
 
 def inference(model):
-    image = cv2.imread('asset/1.jpg')
+    image = cv2.imread('asset/cat_dog.jpeg')
     processed_image = preprocess_image(image).unsqueeze(0)
-    caption = preprocess_caption("watermark")
+    caption = preprocess_caption("cat. dog")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
     outputs = model(processed_image,
                     tokenized["input_ids"],
+                    tokenized["attention_mask"],
+                    tokenized["token_type_ids"],
+                    position_ids,
+                    text_self_attention_masks,
                     box_threshold,
                     text_threshold)
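The "cat. dog" caption is not arbitrary: GroundingDINO treats "." as the separator between text phrases, and preprocess_caption (defined earlier in this file, per the hunk header above) normalizes the string. A sketch of the behavior that helper is expected to have, assuming the stock implementation from groundingdino.util.inference:

    def preprocess_caption(caption: str) -> str:
        # Lowercase, trim, and guarantee a trailing "." so the last
        # phrase is terminated the same way as the others.
        result = caption.lower().strip()
        return result if result.endswith(".") else result + "."

    assert preprocess_caption("cat. dog") == "cat. dog."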
@@ -130,7 +164,7 @@ def inference(model):
     cv2.waitKey()
     cv2.destroyAllWindows()
 
-def inference_onnx(output_dir):
+def inference_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
     session = ort.InferenceSession(onnx_file)
@@ -141,9 +175,32 @@ def inference_onnx(output_dir):
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
+    max_text_len = model.max_text_len
+    if text_self_attention_masks.shape[1] > max_text_len:
+        text_self_attention_masks = text_self_attention_masks[
+            :, : max_text_len, : max_text_len
+        ]
+        position_ids = position_ids[:, : max_text_len]
+        tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
+        tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
+        tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
+
     outputs = session.run(None, {
         "image": processed_image.numpy().astype(np.float32),
-        "input_ids": tokenized["input_ids"].numpy().astype(np.int64),
+        "input_ids": tokenized["input_ids"].numpy().astype(np.int32),
+        "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8),
+        "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32),
+        "position_ids": position_ids.numpy().astype(np.int32),
+        "text_self_attention_masks": text_self_attention_masks.numpy().astype(np.bool_),
+        "box_threshold": box_threshold.numpy().astype(np.float32),
+        "text_threshold": text_threshold.numpy().astype(np.float32)
     })
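Turning the session.run outputs back into labeled boxes mirrors the PyTorch path; a hedged sketch, assuming the "logits"/"boxes" output order declared in export_onnx, the get_phrases_from_posmap helper this file already imports, and the same 0.35/0.25 thresholds used above:

    import torch
    from groundingdino.util.utils import get_phrases_from_posmap

    logits = torch.from_numpy(outputs[0])  # (num_queries, num_token), already sigmoided
    boxes = torch.from_numpy(outputs[1])   # (num_queries, 4), normalized cxcywh

    keep = logits.max(dim=1)[0] > 0.35     # box_threshold
    logits, boxes = logits[keep], boxes[keep]
    phrases = [
        get_phrases_from_posmap(logit > 0.25,  # text_threshold
                                tokenized, model.tokenizer).replace(".", "")
        for logit in logits
    ]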
@@ -191,7 +248,7 @@ if __name__ == "__main__":
     model = Model(config_file, checkpoint_path, device='cpu')
 
     if args.test:
-        inference_onnx(output_dir)
+        inference_onnx(model, output_dir)
     elif args.orig:
         inference(model)
     else:
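For completeness, a minimal driver equivalent to the __main__ branches above; the config and checkpoint paths are placeholders rather than values taken from this commit:

    model = Model("groundingdino/config/GroundingDINO_SwinT_OGC.py",  # placeholder path
                  "weights/groundingdino_swint_ogc.pth",              # placeholder path
                  device="cpu")
    export_onnx(model, "output")     # writes output/gdino.onnx
    inference_onnx(model, "output")  # the args.test branch
    inference(model)                 # the args.orig branch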