Add files via upload

update
pull/380/head
szsteven008 2025-01-01 16:21:57 +08:00 committed by GitHub
parent 17f37607dc
commit 89b1ad20b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 70 additions and 13 deletions

View File

@ -12,6 +12,7 @@ import onnxruntime as ort
from groundingdino.util.inference import load_model, annotate
import groundingdino.datasets.transforms as T
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
class Model(torch.nn.Module):
def __init__(
@ -27,18 +28,21 @@ class Model(torch.nn.Module):
device=device
).to(device)
self.tokenizer = self.model.tokenizer
self.specical_tokens = self.model.specical_tokens
self.max_text_len = self.model.max_text_len
# def forward(self, samples: NestedTensor, targets: List = None, **kw):
def forward(self,
image: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor,
position_ids: torch.Tensor,
text_self_attention_masks: torch.Tensor,
box_threshold: torch.Tensor,
text_threshold: torch.Tensor,
**kw):
token_type_ids = torch.zeros(input_ids.size(), dtype=torch.int32)
attention_mask = (input_ids != 0).int()
outputs = self.model(image, input_ids, attention_mask, token_type_ids)
outputs = self.model(image, input_ids, attention_mask, token_type_ids, position_ids, text_self_attention_masks)
prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
prediction_boxes = outputs["pred_boxes"].squeeze(0)
@ -70,26 +74,43 @@ def preprocess_caption(caption: str) -> str:
def export_onnx(model, output_dir):
onnx_file = output_dir + "/" + "gdino.onnx"
caption = preprocess_caption(".")
caption = preprocess_caption("watermark")
tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
torch.onnx.export(
model,
args = (
torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
tokenized["input_ids"],
tokenized["input_ids"].type(torch.int).to("cpu"),
tokenized["attention_mask"].type(torch.uint8).to("cpu"),
tokenized["token_type_ids"].type(torch.int).to("cpu"),
position_ids.type(torch.int).to("cpu"),
text_self_attention_masks.type(torch.bool).to("cpu"),
box_threshold,
text_threshold),
f = onnx_file,
input_names = [ "image", "input_ids", "box_threshold", "text_threshold" ],
input_names = [ "image", "input_ids", "attention_mask", "token_type_ids", "position_ids", "text_self_attention_masks", "box_threshold", "text_threshold" ],
output_names = [ "logits", "boxes", "masks" ],
opset_version = 17,
export_params = True,
do_constant_folding = True,
dynamic_axes = {
"input_ids": { 1: "token_num" }
"input_ids": { 1: "token_num" },
"attention_mask": { 1: "token_num" },
"token_type_ids": { 1: "token_num" },
"position_ids": { 1: "token_num" },
"text_self_attention_masks": { 1: "token_num", 2: "token_num" }
},
)
@ -100,15 +121,28 @@ def export_onnx(model, output_dir):
print("check model ok!")
def inference(model):
image = cv2.imread('asset/1.jpg')
image = cv2.imread('asset/cat_dog.jpeg')
processed_image = preprocess_image(image).unsqueeze(0)
caption = preprocess_caption("watermark")
caption = preprocess_caption("cat. dog")
tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
outputs = model(processed_image,
tokenized["input_ids"],
tokenized["attention_mask"],
tokenized["token_type_ids"],
position_ids,
text_self_attention_masks,
box_threshold,
text_threshold)
@ -130,7 +164,7 @@ def inference(model):
cv2.waitKey()
cv2.destroyAllWindows()
def inference_onnx(output_dir):
def inference_onnx(model, output_dir):
onnx_file = output_dir + "/" + "gdino.onnx"
session = ort.InferenceSession(onnx_file)
@ -141,9 +175,32 @@ def inference_onnx(output_dir):
box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
max_text_len = model.max_text_len
if text_self_attention_masks.shape[1] > max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : max_text_len, : max_text_len
]
position_ids = position_ids[:, : max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
outputs = session.run(None, {
"image": processed_image.numpy().astype(np.float32) ,
"input_ids": tokenized["input_ids"].numpy().astype(np.int64) ,
"input_ids": tokenized["input_ids"].numpy().astype(np.int32) ,
"attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8) ,
"token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32) ,
"position_ids": position_ids.numpy().astype(np.int32) ,
"text_self_attention_masks": text_self_attention_masks.numpy().astype(np.bool) ,
"box_threshold": box_threshold.numpy().astype(np.float32) ,
"text_threshold": text_threshold.numpy().astype(np.float32)
})
@ -191,7 +248,7 @@ if __name__ == "__main__":
model = Model(config_file, checkpoint_path, device='cpu')
if args.test:
inference_onnx(output_dir)
inference_onnx(model, output_dir)
elif args.orig:
inference(model)
else: