Add files via upload

update
pull/380/head
szsteven008 2025-01-01 16:21:57 +08:00 committed by GitHub
parent 17f37607dc
commit 89b1ad20b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 70 additions and 13 deletions

View File

@ -12,6 +12,7 @@ import onnxruntime as ort
from groundingdino.util.inference import load_model, annotate from groundingdino.util.inference import load_model, annotate
import groundingdino.datasets.transforms as T import groundingdino.datasets.transforms as T
from groundingdino.util.utils import get_phrases_from_posmap from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__( def __init__(
@ -27,18 +28,21 @@ class Model(torch.nn.Module):
device=device device=device
).to(device) ).to(device)
self.tokenizer = self.model.tokenizer self.tokenizer = self.model.tokenizer
self.specical_tokens = self.model.specical_tokens
self.max_text_len = self.model.max_text_len
# def forward(self, samples: NestedTensor, targets: List = None, **kw): # def forward(self, samples: NestedTensor, targets: List = None, **kw):
def forward(self, def forward(self,
image: torch.Tensor, image: torch.Tensor,
input_ids: torch.Tensor, input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor,
position_ids: torch.Tensor,
text_self_attention_masks: torch.Tensor,
box_threshold: torch.Tensor, box_threshold: torch.Tensor,
text_threshold: torch.Tensor, text_threshold: torch.Tensor,
**kw): **kw):
token_type_ids = torch.zeros(input_ids.size(), dtype=torch.int32) outputs = self.model(image, input_ids, attention_mask, token_type_ids, position_ids, text_self_attention_masks)
attention_mask = (input_ids != 0).int()
outputs = self.model(image, input_ids, attention_mask, token_type_ids)
prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0) prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
prediction_boxes = outputs["pred_boxes"].squeeze(0) prediction_boxes = outputs["pred_boxes"].squeeze(0)
@ -70,26 +74,43 @@ def preprocess_caption(caption: str) -> str:
def export_onnx(model, output_dir): def export_onnx(model, output_dir):
onnx_file = output_dir + "/" + "gdino.onnx" onnx_file = output_dir + "/" + "gdino.onnx"
caption = preprocess_caption(".") caption = preprocess_caption("watermark")
tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt") tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
box_threshold = torch.tensor(0.35, dtype=torch.float32) box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32) text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
torch.onnx.export( torch.onnx.export(
model, model,
args = ( args = (
torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"), torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
tokenized["input_ids"], tokenized["input_ids"].type(torch.int).to("cpu"),
tokenized["attention_mask"].type(torch.uint8).to("cpu"),
tokenized["token_type_ids"].type(torch.int).to("cpu"),
position_ids.type(torch.int).to("cpu"),
text_self_attention_masks.type(torch.bool).to("cpu"),
box_threshold, box_threshold,
text_threshold), text_threshold),
f = onnx_file, f = onnx_file,
input_names = [ "image", "input_ids", "box_threshold", "text_threshold" ], input_names = [ "image", "input_ids", "attention_mask", "token_type_ids", "position_ids", "text_self_attention_masks", "box_threshold", "text_threshold" ],
output_names = [ "logits", "boxes", "masks" ], output_names = [ "logits", "boxes", "masks" ],
opset_version = 17, opset_version = 17,
export_params = True, export_params = True,
do_constant_folding = True, do_constant_folding = True,
dynamic_axes = { dynamic_axes = {
"input_ids": { 1: "token_num" } "input_ids": { 1: "token_num" },
"attention_mask": { 1: "token_num" },
"token_type_ids": { 1: "token_num" },
"position_ids": { 1: "token_num" },
"text_self_attention_masks": { 1: "token_num", 2: "token_num" }
}, },
) )
@ -100,15 +121,28 @@ def export_onnx(model, output_dir):
print("check model ok!") print("check model ok!")
def inference(model): def inference(model):
image = cv2.imread('asset/1.jpg') image = cv2.imread('asset/cat_dog.jpeg')
processed_image = preprocess_image(image).unsqueeze(0) processed_image = preprocess_image(image).unsqueeze(0)
caption = preprocess_caption("watermark") caption = preprocess_caption("cat. dog")
tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt") tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
box_threshold = torch.tensor(0.35, dtype=torch.float32) box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32) text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
outputs = model(processed_image, outputs = model(processed_image,
tokenized["input_ids"], tokenized["input_ids"],
tokenized["attention_mask"],
tokenized["token_type_ids"],
position_ids,
text_self_attention_masks,
box_threshold, box_threshold,
text_threshold) text_threshold)
@ -130,7 +164,7 @@ def inference(model):
cv2.waitKey() cv2.waitKey()
cv2.destroyAllWindows() cv2.destroyAllWindows()
def inference_onnx(output_dir): def inference_onnx(model, output_dir):
onnx_file = output_dir + "/" + "gdino.onnx" onnx_file = output_dir + "/" + "gdino.onnx"
session = ort.InferenceSession(onnx_file) session = ort.InferenceSession(onnx_file)
@ -141,9 +175,32 @@ def inference_onnx(output_dir):
box_threshold = torch.tensor(0.35, dtype=torch.float32) box_threshold = torch.tensor(0.35, dtype=torch.float32)
text_threshold = torch.tensor(0.25, dtype=torch.float32) text_threshold = torch.tensor(0.25, dtype=torch.float32)
specical_tokens = model.specical_tokens
(
text_self_attention_masks,
position_ids,
_,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, model.tokenizer
)
max_text_len = model.max_text_len
if text_self_attention_masks.shape[1] > max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : max_text_len, : max_text_len
]
position_ids = position_ids[:, : max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
outputs = session.run(None, { outputs = session.run(None, {
"image": processed_image.numpy().astype(np.float32) , "image": processed_image.numpy().astype(np.float32) ,
"input_ids": tokenized["input_ids"].numpy().astype(np.int64) , "input_ids": tokenized["input_ids"].numpy().astype(np.int32) ,
"attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8) ,
"token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32) ,
"position_ids": position_ids.numpy().astype(np.int32) ,
"text_self_attention_masks": text_self_attention_masks.numpy().astype(np.bool) ,
"box_threshold": box_threshold.numpy().astype(np.float32) , "box_threshold": box_threshold.numpy().astype(np.float32) ,
"text_threshold": text_threshold.numpy().astype(np.float32) "text_threshold": text_threshold.numpy().astype(np.float32)
}) })
@ -191,7 +248,7 @@ if __name__ == "__main__":
model = Model(config_file, checkpoint_path, device='cpu') model = Model(config_file, checkpoint_path, device='cpu')
if args.test: if args.test:
inference_onnx(output_dir) inference_onnx(model, output_dir)
elif args.orig: elif args.orig:
inference(model) inference(model)
else: else: