parent 17f37607dc
commit 89b1ad20b9

export.py | 83 lines changed

@@ -12,6 +12,7 @@ import onnxruntime as ort
 from groundingdino.util.inference import load_model, annotate
 import groundingdino.datasets.transforms as T
 from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
 
 class Model(torch.nn.Module):
     def __init__(
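
The only change in this hunk is the new import; the hunks below build on it. A minimal sketch of what the helper returns, assuming GroundingDINO registers "[CLS]", "[SEP]", "." and "?" as its special tokens (that list mirrors model.specical_tokens upstream, but treat it as an assumption of this example):

    from transformers import BertTokenizer
    from groundingdino.models.GroundingDINO.bertwarper import (
        generate_masks_with_special_tokens_and_transfer_map,
    )

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenized = tokenizer("cat . dog .", padding="longest", return_tensors="pt")
    special_ids = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

    # Returns a (bs, n, n) block-diagonal self-attention mask that keeps each
    # phrase from attending to the others, and position ids that restart at 0
    # inside every phrase; the third value (a per-phrase token map) is unused here.
    text_self_attention_masks, position_ids, _ = (
        generate_masks_with_special_tokens_and_transfer_map(
            tokenized, special_ids, tokenizer
        )
    )
    print(text_self_attention_masks.shape, position_ids.shape)  # e.g. (1, n, n) and (1, n)
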
@@ -27,18 +28,21 @@ class Model(torch.nn.Module):
             device=device
         ).to(device)
         self.tokenizer = self.model.tokenizer
+        self.specical_tokens = self.model.specical_tokens
+        self.max_text_len = self.model.max_text_len
 
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 image: torch.Tensor,
                 input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor,
+                position_ids: torch.Tensor,
+                text_self_attention_masks: torch.Tensor,
                 box_threshold: torch.Tensor,
                 text_threshold: torch.Tensor,
                 **kw):
-        token_type_ids = torch.zeros(input_ids.size(), dtype=torch.int32)
-        attention_mask = (input_ids != 0).int()
-
-        outputs = self.model(image, input_ids, attention_mask, token_type_ids)
+        outputs = self.model(image, input_ids, attention_mask, token_type_ids, position_ids, text_self_attention_masks)
 
         prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
         prediction_boxes = outputs["pred_boxes"].squeeze(0)
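
The deleted lines synthesized attention_mask and token_type_ids inside the traced forward, which bakes a plain padding mask into the ONNX graph; GroundingDINO's text branch needs the phrase-aware mask instead, so all text-side tensors become explicit graph inputs. An illustrative contrast (the ids stand for "[CLS] cat . dog . [SEP]" and are assumptions of this sketch):

    import torch

    input_ids = torch.tensor([[101, 4937, 1012, 3899, 1012, 102]])

    # Old in-graph derivation: a flat mask where every non-pad token attends to
    # every other token, plus all-zero token types frozen at trace time.
    attention_mask = (input_ids != 0).int()       # shape (1, 6), all ones here
    token_type_ids = torch.zeros_like(input_ids)

    # The text branch instead needs text_self_attention_masks of shape
    # (1, 6, 6), block-diagonal per phrase; that cannot be recovered from
    # input_ids by elementwise ops, so it is now supplied by the caller.
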
@@ -70,26 +74,43 @@ def preprocess_caption(caption: str) -> str:
 
 def export_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
-    caption = preprocess_caption(".")
+    caption = preprocess_caption("watermark")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
+
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
 
     torch.onnx.export(
         model,
         args = (
             torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
-            tokenized["input_ids"],
+            tokenized["input_ids"].type(torch.int).to("cpu"),
+            tokenized["attention_mask"].type(torch.uint8).to("cpu"),
+            tokenized["token_type_ids"].type(torch.int).to("cpu"),
+            position_ids.type(torch.int).to("cpu"),
+            text_self_attention_masks.type(torch.bool).to("cpu"),
             box_threshold,
             text_threshold),
         f = onnx_file,
-        input_names = [ "image", "input_ids", "box_threshold", "text_threshold" ],
+        input_names = [ "image", "input_ids", "attention_mask", "token_type_ids", "position_ids", "text_self_attention_masks", "box_threshold", "text_threshold" ],
         output_names = [ "logits", "boxes", "masks" ],
         opset_version = 17,
         export_params = True,
         do_constant_folding = True,
         dynamic_axes = {
-            "input_ids": { 1: "token_num" }
+            "input_ids": { 1: "token_num" },
+            "attention_mask": { 1: "token_num" },
+            "token_type_ids": { 1: "token_num" },
+            "position_ids": { 1: "token_num" },
+            "text_self_attention_masks": { 1: "token_num", 2: "token_num" }
         },
     )
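
Once exported, the new inputs and the token_num dynamic axis can be read back from the graph itself; a minimal sketch, assuming the onnx package is installed and output_dir was "output":

    import onnx

    proto = onnx.load("output/gdino.onnx")
    onnx.checker.check_model(proto)
    for inp in proto.graph.input:
        dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
        print(inp.name, dims)
    # input_ids, attention_mask, token_type_ids and position_ids should report
    # [1, 'token_num']; text_self_attention_masks [1, 'token_num', 'token_num'].
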
@@ -100,15 +121,28 @@ def export_onnx(model, output_dir):
     print("check model ok!")
 
 def inference(model):
-    image = cv2.imread('asset/1.jpg')
+    image = cv2.imread('asset/cat_dog.jpeg')
     processed_image = preprocess_image(image).unsqueeze(0)
-    caption = preprocess_caption("watermark")
+    caption = preprocess_caption("cat. dog")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
+
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
 
     outputs = model(processed_image,
                     tokenized["input_ids"],
+                    tokenized["attention_mask"],
+                    tokenized["token_type_ids"],
+                    position_ids,
+                    text_self_attention_masks,
                     box_threshold,
                     text_threshold)
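
For context, the two thresholds feed the usual GroundingDINO filtering step further down in inference(); a sketch of that downstream logic (the output order and logit width are assumptions, the real code sits outside this hunk):

    logits, boxes = outputs[0], outputs[1]   # (n_queries, 256), (n_queries, 4)

    keep = logits.max(dim=1).values > box_threshold   # box-level confidence
    logits, boxes = logits[keep], boxes[keep]

    # text_threshold then picks the token positions that name each box, via
    # get_phrases_from_posmap(logit > text_threshold, tokenized, model.tokenizer).
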
@@ -130,7 +164,7 @@ def inference(model):
     cv2.waitKey()
     cv2.destroyAllWindows()
 
-def inference_onnx(output_dir):
+def inference_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
     session = ort.InferenceSession(onnx_file)
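
Passing model in gives inference_onnx access to the tokenizer, specical_tokens and max_text_len; the PyTorch weights are not used on this path. Worth noting: the session above falls back to the default execution provider, and can be pinned explicitly (a sketch, not part of this commit):

    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
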
@@ -141,9 +175,32 @@ def inference_onnx(output_dir):
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
+
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
+    max_text_len = model.max_text_len
+    if text_self_attention_masks.shape[1] > max_text_len:
+        text_self_attention_masks = text_self_attention_masks[
+            :, : max_text_len, : max_text_len
+        ]
+        position_ids = position_ids[:, : max_text_len]
+        tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
+        tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
+        tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
 
     outputs = session.run(None, {
         "image": processed_image.numpy().astype(np.float32),
-        "input_ids": tokenized["input_ids"].numpy().astype(np.int64),
+        "input_ids": tokenized["input_ids"].numpy().astype(np.int32),
+        "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8),
+        "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32),
+        "position_ids": position_ids.numpy().astype(np.int32),
+        "text_self_attention_masks": text_self_attention_masks.numpy().astype(np.bool_),
         "box_threshold": box_threshold.numpy().astype(np.float32),
         "text_threshold": text_threshold.numpy().astype(np.float32)
     })
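
The truncation block mirrors what GroundingDINO's own forward does when a caption tokenizes past max_text_len, so the feed can never exceed the exported mask size. The .astype() casts must match the dtypes the graph was exported with; they can be cross-checked from the session (sketch):

    for node in session.get_inputs():
        print(node.name, node.type, node.shape)
    # e.g. "input_ids tensor(int32) [1, 'token_num']", which is why the feed
    # casts to np.int32 rather than the tokenizer's default int64.
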
@@ -191,7 +248,7 @@ if __name__ == "__main__":
     model = Model(config_file, checkpoint_path, device='cpu')
 
     if args.test:
-        inference_onnx(output_dir)
+        inference_onnx(model, output_dir)
     elif args.orig:
         inference(model)
     else:
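
Presumed invocations, inferred from args.test and args.orig (the argparse setup sits outside this diff):

    python export.py            # export gdino.onnx into output_dir
    python export.py --test     # run the exported ONNX model
    python export.py --orig     # run the original PyTorch model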