diff --git a/export.py b/export.py
index 99d07d3..be4554f 100644
--- a/export.py
+++ b/export.py
@@ -12,6 +12,7 @@ import onnxruntime as ort
 from groundingdino.util.inference import load_model, annotate
 import groundingdino.datasets.transforms as T
 from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
 
 class Model(torch.nn.Module):
     def __init__(
@@ -27,18 +28,21 @@ class Model(torch.nn.Module):
             device=device
         ).to(device)
         self.tokenizer = self.model.tokenizer
+        self.specical_tokens = self.model.specical_tokens
+        self.max_text_len = self.model.max_text_len
 
     # def forward(self, samples: NestedTensor, targets: List = None, **kw):
     def forward(self,
                 image: torch.Tensor,
                 input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor,
+                position_ids: torch.Tensor,
+                text_self_attention_masks: torch.Tensor,
                 box_threshold: torch.Tensor,
                 text_threshold: torch.Tensor,
                 **kw):
-        token_type_ids = torch.zeros(input_ids.size(), dtype=torch.int32)
-        attention_mask = (input_ids != 0).int()
-
-        outputs = self.model(image, input_ids, attention_mask, token_type_ids)
+        outputs = self.model(image, input_ids, attention_mask, token_type_ids, position_ids, text_self_attention_masks)
 
         prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
         prediction_boxes = outputs["pred_boxes"].squeeze(0)
@@ -70,26 +74,43 @@ def preprocess_caption(caption: str) -> str:
 def export_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
 
-    caption = preprocess_caption(".")
+    caption = preprocess_caption("watermark")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
     torch.onnx.export(
         model,
         args = (
             torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"),
-            tokenized["input_ids"],
+            tokenized["input_ids"].type(torch.int).to("cpu"),
+            tokenized["attention_mask"].type(torch.uint8).to("cpu"),
+            tokenized["token_type_ids"].type(torch.int).to("cpu"),
+            position_ids.type(torch.int).to("cpu"),
+            text_self_attention_masks.type(torch.bool).to("cpu"),
             box_threshold,
             text_threshold),
         f = onnx_file,
-        input_names = [ "image", "input_ids", "box_threshold", "text_threshold" ],
+        input_names = [ "image", "input_ids", "attention_mask", "token_type_ids", "position_ids", "text_self_attention_masks", "box_threshold", "text_threshold" ],
         output_names = [ "logits", "boxes", "masks" ],
         opset_version = 17,
         export_params = True,
         do_constant_folding = True,
         dynamic_axes = {
-            "input_ids": { 1: "token_num" }
+            "input_ids": { 1: "token_num" },
+            "attention_mask": { 1: "token_num" },
+            "token_type_ids": { 1: "token_num" },
+            "position_ids": { 1: "token_num" },
+            "text_self_attention_masks": { 1: "token_num", 2: "token_num" }
         },
     )
 
@@ -100,15 +121,28 @@ def export_onnx(model, output_dir):
     print("check model ok!")
 
 def inference(model):
-    image = cv2.imread('asset/1.jpg')
+    image = cv2.imread('asset/cat_dog.jpeg')
     processed_image = preprocess_image(image).unsqueeze(0)
 
-    caption = preprocess_caption("watermark")
+    caption = preprocess_caption("cat. dog")
     tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
     outputs = model(processed_image, tokenized["input_ids"],
+                    tokenized["attention_mask"],
+                    tokenized["token_type_ids"],
+                    position_ids,
+                    text_self_attention_masks,
                     box_threshold,
                     text_threshold)
 
@@ -130,7 +164,7 @@ def inference(model):
     cv2.waitKey()
     cv2.destroyAllWindows()
 
-def inference_onnx(output_dir):
+def inference_onnx(model, output_dir):
     onnx_file = output_dir + "/" + "gdino.onnx"
 
     session = ort.InferenceSession(onnx_file)
@@ -141,9 +175,32 @@ def inference_onnx(output_dir):
     box_threshold = torch.tensor(0.35, dtype=torch.float32)
     text_threshold = torch.tensor(0.25, dtype=torch.float32)
 
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, model.tokenizer
+    )
+
+    max_text_len = model.max_text_len
+    if text_self_attention_masks.shape[1] > max_text_len:
+        text_self_attention_masks = text_self_attention_masks[
+            :, : max_text_len, : max_text_len
+        ]
+        position_ids = position_ids[:, : max_text_len]
+        tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
+        tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
+        tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
+
     outputs = session.run(None, {
         "image": processed_image.numpy().astype(np.float32) ,
-        "input_ids": tokenized["input_ids"].numpy().astype(np.int64) ,
+        "input_ids": tokenized["input_ids"].numpy().astype(np.int32) ,
+        "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8) ,
+        "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32) ,
+        "position_ids": position_ids.numpy().astype(np.int32) ,
+        "text_self_attention_masks": text_self_attention_masks.numpy().astype(bool) ,
         "box_threshold": box_threshold.numpy().astype(np.float32) ,
         "text_threshold": text_threshold.numpy().astype(np.float32) })
 
@@ -191,7 +248,7 @@ if __name__ == "__main__":
     model = Model(config_file, checkpoint_path, device='cpu')
 
     if args.test:
-        inference_onnx(output_dir)
+        inference_onnx(model, output_dir)
     elif args.orig:
         inference(model)
     else: