diff --git a/export.py b/export.py
new file mode 100644
index 0000000..a304f0f
--- /dev/null
+++ b/export.py
@@ -0,0 +1,282 @@
+import argparse
+import os
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+from typing import Tuple, List
+from torchvision.ops import box_convert
+import onnx
+import onnxruntime as ort
+
+from groundingdino.util.inference import load_model, annotate
+import groundingdino.datasets.transforms as T
+from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens
+
+def preprocess_caption(caption: str) -> str:
+    result = caption.lower().strip()
+    if result.endswith("."):
+        return result
+    return result + "."
+
+class Encoder(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.tokenizer = model.tokenizer
+        self.bert = model.bert
+        self.specical_tokens = model.specical_tokens
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                token_type_ids: torch.Tensor,
+                text_self_attention_masks: torch.Tensor,
+                position_ids: torch.Tensor):
+        # extract text embeddings
+        tokenized_for_encoder = {}
+        tokenized_for_encoder["input_ids"] = input_ids
+        tokenized_for_encoder["token_type_ids"] = token_type_ids
+        tokenized_for_encoder["attention_mask"] = text_self_attention_masks.type(torch.bool)
+        tokenized_for_encoder["position_ids"] = position_ids
+
+        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
+
+        return bert_output["last_hidden_state"]
+
+class Decoder(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.tokenizer = model.tokenizer
+        self.specical_tokens = model.specical_tokens
+
+    def forward(self,
+                image: torch.Tensor,
+                last_hidden_state: torch.Tensor,
+                attention_mask: torch.Tensor,
+                position_ids: torch.Tensor,
+                text_self_attention_masks: torch.Tensor,
+                box_threshold: torch.Tensor,
+                text_threshold: torch.Tensor):
+        outputs = self.model(image,
+                             last_hidden_state,
+                             attention_mask,
+                             position_ids,
+                             text_self_attention_masks.type(torch.bool))
+        prediction_logits = outputs["pred_logits"].sigmoid().squeeze(0)
+        prediction_boxes = outputs["pred_boxes"].squeeze(0)
+
+        mask = prediction_logits.max(dim=1)[0] > box_threshold
+        prediction_logits = prediction_logits[mask]
+        prediction_input_ids_mask = prediction_logits > text_threshold
+        prediction_boxes = prediction_boxes[mask]
+
+        return (prediction_logits.max(dim=1)[0].unsqueeze(0),
+                prediction_boxes.unsqueeze(0),
+                prediction_input_ids_mask.unsqueeze(0))
+
+def export_encoder(model, output):
+    onnx_file = output + "/" + "gdino.encoder.onnx"
+    caption = preprocess_caption("watermark")
+    tokenized = model.tokenizer(caption, padding="longest", return_tensors="pt")
+
+    (
+        text_self_attention_masks,
+        position_ids
+    ) = generate_masks_with_special_tokens(tokenized, model.specical_tokens, model.tokenizer)
+
+    torch.onnx.export(
+        model,
+        args = (
+            tokenized["input_ids"].type(torch.int).to("cpu"),
+            tokenized["token_type_ids"].type(torch.int).to("cpu"),
+            text_self_attention_masks.type(torch.uint8).to("cpu"),
+            position_ids.type(torch.int).to("cpu"),
+        ),
+        f = onnx_file,
+        input_names = [ "input_ids", "token_type_ids", "text_self_attention_masks", "position_ids" ],
+        output_names = [ "last_hidden_state" ],
+        opset_version = 17,
+        export_params = True,
+        do_constant_folding = True,
+        dynamic_axes = {
+            "input_ids": { 1: "token_num" },
+            "token_type_ids": { 1: "token_num" },
+            "text_self_attention_masks": { 1: "token_num", 2: "token_num" },
+ "position_ids": { 1: "token_num" }, + "last_hidden_state": { 1: "token_num" } + }, + ) + + print("export gdino.encoder.onnx ok!") + + onnx_model = onnx.load(onnx_file) + onnx.checker.check_model(onnx_model) + print("check gdino.encoder.onnx ok!") + +def export_decoder(model, output, encoder): + onnx_file = output + "/" + "gdino.decoder.onnx" + caption = preprocess_caption("watermark") + + tokenized, last_hidden_state = inference_encoder_onnx(encoder, output, caption) + + box_threshold = torch.tensor(0.35, dtype=torch.float32) + text_threshold = torch.tensor(0.25, dtype=torch.float32) + + torch.onnx.export( + model, + args = ( + torch.rand(1, 3, 800, 800).type(torch.float32).to("cpu"), + last_hidden_state, + tokenized["attention_mask"].type(torch.uint8).to("cpu"), + tokenized["position_ids"].type(torch.int).to("cpu"), + tokenized["text_self_attention_masks"].type(torch.uint8).to("cpu"), + box_threshold, + text_threshold), + f = onnx_file, + input_names = [ "image", "last_hidden_state", "attention_mask", + "position_ids", "text_self_attention_masks", + "box_threshold", "text_threshold" ], + output_names = [ "logits", "boxes", "masks" ], + opset_version = 17, + export_params = True, + do_constant_folding = True, + dynamic_axes = { + "last_hidden_state": { 1: "token_num" }, + "attention_mask": { 1: "token_num" }, + "position_ids": { 1: "token_num" }, + "text_self_attention_masks": { 1: "token_num", 2: "token_num" } + }, + ) + + print("export gdino.decoder.onnx ok!") + + onnx_model = onnx.load(onnx_file) + onnx.checker.check_model(onnx_model) + print("check gdino.decoder.onnx ok!") + +def inference_encoder_onnx(model, output, caption: str = None): + onnx_file = output + "/" + "gdino.encoder.onnx" + session = ort.InferenceSession(onnx_file) + + if caption: + proc_caption = preprocess_caption(caption) + else: + proc_caption = preprocess_caption("watermark. cat. dog") + tokenized = model.tokenizer(proc_caption, padding="longest", return_tensors="pt") + + ( + text_self_attention_masks, + position_ids + ) = generate_masks_with_special_tokens(tokenized, model.specical_tokens, model.tokenizer) + + tokenized["text_self_attention_masks"] = text_self_attention_masks + tokenized["position_ids"] = position_ids + + outputs = session.run(None, { + "input_ids": tokenized["input_ids"].numpy().astype(np.int32), + "token_type_ids": tokenized["token_type_ids"].numpy().astype(np.int32), + "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8), + "position_ids": tokenized["position_ids"].numpy().astype(np.int32) + }) + + if caption == None: + print(outputs) + + last_hidden_state = torch.from_numpy(outputs[0]).type(torch.float32) + return tokenized, last_hidden_state + +def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor: + image_bgr = cv2.resize(image_bgr, (800, 800)) + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)) + image_transformed, _ = transform(image_pillow, None) + return image_transformed + +def inference_decoder_onnx(model, output): + image = cv2.imread('asset/1.jpg') + processed_image = preprocess_image(image).unsqueeze(0) + + caption = "watermark. 
glasses" + + tokenized, last_hidden_state = inference_encoder_onnx(model, output, caption) + + print(tokenized) + print(last_hidden_state) + + onnx_file = output + "/" + "gdino.decoder.onnx" + session = ort.InferenceSession(onnx_file) + + box_threshold = torch.tensor(0.35, dtype=torch.float32) + text_threshold = torch.tensor(0.25, dtype=torch.float32) + + decode_outputs = session.run(None, { + "image": processed_image.numpy().astype(np.float32), + "last_hidden_state": last_hidden_state.numpy().astype(np.float32), + "attention_mask": tokenized["attention_mask"].numpy().astype(np.uint8), + "position_ids": tokenized["position_ids"].numpy().astype(np.int32), + "text_self_attention_masks": tokenized["text_self_attention_masks"].numpy().astype(np.uint8), + "box_threshold": box_threshold.numpy().astype(np.float32), + "text_threshold": text_threshold.numpy().astype(np.float32) + }) + + prediction_logits = torch.from_numpy(decode_outputs[0]) + prediction_boxes = torch.from_numpy(decode_outputs[1]) + prediction_masks = torch.from_numpy(decode_outputs[2]) + + input_ids = tokenized["input_ids"][0].tolist() + phrases = [] + for mask in prediction_masks[0]: + prediction_token_ids = [input_ids[i] for i in mask.nonzero(as_tuple=True)[0].tolist()] + phrases.append(model.tokenizer.decode(prediction_token_ids).replace('.', '')) + + with torch.no_grad(): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = annotate(image, prediction_boxes[0], prediction_logits[0], phrases) + + cv2.imshow("image", image) + cv2.waitKey() + cv2.destroyAllWindows() + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Export Grounding DINO Model to ONNX", add_help=True) + parser.add_argument("--encode", "-e", help="test encoder.onnx model", action="store_true") + parser.add_argument("--decode", "-d", help="test decoder.onnx model", action="store_true") + parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file") + parser.add_argument( + "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file" + ) + parser.add_argument( + "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" + ) + + args = parser.parse_args() + + # cfg + config_file = args.config_file # change the path of the model config file + checkpoint_path = args.checkpoint_path # change the path of the model + output_dir = args.output_dir + + # make dir + os.makedirs(output_dir, exist_ok=True) + + source_model = load_model(model_config_path = config_file, + model_checkpoint_path = checkpoint_path, + device = "cpu").to("cpu") + + encoder = Encoder(source_model) + decoder = Decoder(source_model) + + if args.encode: + inference_encoder_onnx(encoder, output_dir) + elif args.decode: + inference_decoder_onnx(decoder, output_dir) + else: + export_encoder(encoder, output_dir) + export_decoder(decoder, output_dir, encoder) diff --git a/groundingdino/config/GroundingDINO_SwinB_cfg.py b/groundingdino/config/GroundingDINO_SwinB_cfg.py index f490c4b..320bf31 100644 --- a/groundingdino/config/GroundingDINO_SwinB_cfg.py +++ b/groundingdino/config/GroundingDINO_SwinB_cfg.py @@ -34,8 +34,8 @@ max_text_len = 256 text_encoder_type = "bert-base-uncased" use_text_enhancer = True use_fusion_layer = True -use_checkpoint = True -use_transformer_ckpt = True +use_checkpoint = False #True +use_transformer_ckpt = False #True use_text_cross_attention = True text_dropout = 0.0 fusion_dropout = 0.0 diff --git a/groundingdino/config/GroundingDINO_SwinT_OGC.py 
diff --git a/groundingdino/config/GroundingDINO_SwinT_OGC.py b/groundingdino/config/GroundingDINO_SwinT_OGC.py
index 9158d5f..a1196a2 100644
--- a/groundingdino/config/GroundingDINO_SwinT_OGC.py
+++ b/groundingdino/config/GroundingDINO_SwinT_OGC.py
@@ -34,8 +34,8 @@ max_text_len = 256
 text_encoder_type = "bert-base-uncased"
 use_text_enhancer = True
 use_fusion_layer = True
-use_checkpoint = True
-use_transformer_ckpt = True
+use_checkpoint = False #True
+use_transformer_ckpt = False #True
 use_text_cross_attention = True
 text_dropout = 0.0
 fusion_dropout = 0.0
diff --git a/groundingdino/models/GroundingDINO/bertwarper.py b/groundingdino/models/GroundingDINO/bertwarper.py
index f0cf977..9dedbf5 100644
--- a/groundingdino/models/GroundingDINO/bertwarper.py
+++ b/groundingdino/models/GroundingDINO/bertwarper.py
@@ -190,6 +190,7 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
+#        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
         special_tokens_mask |= input_ids == special_token

     # idxs: each row is a list of indices of special tokens
@@ -234,6 +235,7 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
     # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
     special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
     for special_token in special_tokens_list:
+#        special_tokens_mask = torch.logical_or(special_tokens_mask, input_ids == special_token)
         special_tokens_mask |= input_ids == special_token

     # idxs: each row is a list of indices of special tokens
diff --git a/groundingdino/models/GroundingDINO/groundingdino.py b/groundingdino/models/GroundingDINO/groundingdino.py
index cd97028..37c3b95 100644
--- a/groundingdino/models/GroundingDINO/groundingdino.py
+++ b/groundingdino/models/GroundingDINO/groundingdino.py
@@ -19,7 +19,7 @@ from typing import List

 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
 from torchvision.ops.boxes import nms
 from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast

@@ -224,7 +224,14 @@ class GroundingDINO(nn.Module):
     def init_ref_points(self, use_num_queries):
         self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

-    def forward(self, samples: NestedTensor, targets: List = None, **kw):
+#    def forward(self, samples: NestedTensor, targets: List = None, **kw):
+    def forward(self,
+                samples: NestedTensor,
+                last_hidden_state: Tensor,
+                attention_mask: Tensor,
+                position_ids: Tensor,
+                text_self_attention_masks: Tensor,
+                **kw):
         """The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
@@ -239,6 +246,7 @@ class GroundingDINO(nn.Module):
            - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                             dictionnaries containing the two above keys for each decoder layer.
""" + """ if targets is None: captions = kw["captions"] else: @@ -248,38 +256,48 @@ class GroundingDINO(nn.Module): tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to( samples.device ) - ( - text_self_attention_masks, - position_ids, - cate_to_token_mask_list, - ) = generate_masks_with_special_tokens_and_transfer_map( - tokenized, self.specical_tokens, self.tokenizer - ) + """ +# tokenized = { +# "input_ids": input_ids, +# "attention_mask": attention_mask, +# "token_type_ids": token_type_ids, +# } - if text_self_attention_masks.shape[1] > self.max_text_len: - text_self_attention_masks = text_self_attention_masks[ - :, : self.max_text_len, : self.max_text_len - ] - position_ids = position_ids[:, : self.max_text_len] - tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] - tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len] - tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len] +# ( +# text_self_attention_masks, +# position_ids, +# cate_to_token_mask_list, +# ) = generate_masks_with_special_tokens_and_transfer_map( +# tokenized, self.specical_tokens, self.tokenizer +# ) - # extract text embeddings - if self.sub_sentence_present: - tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"} - tokenized_for_encoder["attention_mask"] = text_self_attention_masks - tokenized_for_encoder["position_ids"] = position_ids - else: - # import ipdb; ipdb.set_trace() - tokenized_for_encoder = tokenized +# if text_self_attention_masks.shape[1] > self.max_text_len: +# text_self_attention_masks = text_self_attention_masks[ +# :, : self.max_text_len, : self.max_text_len +# ] +# position_ids = position_ids[:, : self.max_text_len] +# tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] +# tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len] +# tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len] - bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768 +# # extract text embeddings +# if self.sub_sentence_present: +# tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"} +# tokenized_for_encoder["attention_mask"] = text_self_attention_masks +# tokenized_for_encoder["position_ids"] = position_ids +# else: +# # import ipdb; ipdb.set_trace() +# tokenized_for_encoder = tokenized - encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model - text_token_mask = tokenized.attention_mask.bool() # bs, 195 +# bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768 + +# encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model +# text_token_mask = tokenized["attention_mask"].bool() # bs, 195 +# text_token_mask = tokenized.attention_mask.bool() # bs, 195 # text_token_mask: True for nomask, False for mask # text_self_attention_masks: True for nomask, False for mask + encoded_text = self.feat_map(last_hidden_state) # bs, 195, d_model + text_token_mask = attention_mask.bool() # bs, 195 if encoded_text.shape[1] > self.max_text_len: encoded_text = encoded_text[:, : self.max_text_len, :] diff --git a/groundingdino/models/GroundingDINO/transformer.py b/groundingdino/models/GroundingDINO/transformer.py index fcb8742..c04d440 100644 --- a/groundingdino/models/GroundingDINO/transformer.py +++ b/groundingdino/models/GroundingDINO/transformer.py @@ -859,7 +859,8 @@ class DeformableTransformerDecoderLayer(nn.Module): return tensor if pos 
         return tensor if pos is None else tensor + pos

     def forward_ffn(self, tgt):
-        with torch.cuda.amp.autocast(enabled=False):
+#        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(str(tgt.device), enabled=False):
             tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
         tgt = self.norm3(tgt)
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 84a962e..b4d6ee2 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -13,6 +13,7 @@ from groundingdino.models import build_model
 from groundingdino.util.misc import clean_state_dict
 from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map

 # ----------------------------------------------------------------------------------------------------------------------
 # OLD API
@@ -64,8 +65,44 @@ def predict(
     model = model.to(device)
     image = image.to(device)

+    tokenizer = model.tokenizer
+    tokenized = tokenizer([caption], padding="longest", return_tensors="pt").to(
+        device
+    )
+
+    specical_tokens = model.specical_tokens
+    (
+        text_self_attention_masks,
+        position_ids,
+        _,
+    ) = generate_masks_with_special_tokens_and_transfer_map(
+        tokenized, specical_tokens, tokenizer
+    )
+
+    max_text_len = model.max_text_len
+    if text_self_attention_masks.shape[1] > max_text_len:
+        text_self_attention_masks = text_self_attention_masks[
+            :, : max_text_len, : max_text_len
+        ]
+        position_ids = position_ids[:, : max_text_len]
+        tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
+        tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
+        tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
+
+    tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+    tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+    tokenized_for_encoder["position_ids"] = position_ids
+
+    bert = model.bert
+    bert_output = bert(**tokenized_for_encoder)  # bs, 195, 768
+
     with torch.no_grad():
-        outputs = model(image[None], captions=[caption])
+        outputs = model(image.unsqueeze(0),
+                        last_hidden_state=bert_output["last_hidden_state"],
+                        attention_mask=tokenized["attention_mask"],
+                        position_ids = position_ids,
+                        text_self_attention_masks = text_self_attention_masks)
+#        outputs = model(image[None], captions=[caption])

     prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
     prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)
@@ -74,7 +111,7 @@ def predict(
     logits = prediction_logits[mask]  # logits.shape = (n, 256)
     boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

-    tokenizer = model.tokenizer
+#    tokenizer = model.tokenizer
     tokenized = tokenizer(caption)

     if remove_combined:
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000..6c680ee
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,26 @@
+import numpy as np
+import cv2
+import supervision as sv
+
+from groundingdino.util.inference import Model, annotate
+
+image = cv2.imread("asset/cat_dog.jpeg")
+caption = "cat . dog"
dog" + +model = Model("groundingdino/config/GroundingDINO_SwinT_OGC.py", + "weights/groundingdino_swint_ogc.pth", + "cpu") + +detections, phrases = model.predict_with_caption(image, caption) + +labels = [ f"{phrase}" for phrase in phrases ] + +bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX) +label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX) +annotated_frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) +annotated_frame = bbox_annotator.annotate(scene=image, detections=detections) +annotated_frame = label_annotator.annotate(scene=image, detections=detections, labels=labels) + +cv2.imshow("image", annotated_frame) +cv2.waitKey() +cv2.destroyAllWindows()