diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 98e6c6d..b4d6ee2 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -89,11 +89,17 @@ def predict(
         tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
         tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
 
+    tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+    tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+    tokenized_for_encoder["position_ids"] = position_ids
+
+    bert = model.bert
+    bert_output = bert(**tokenized_for_encoder) # bs, 195, 768
+
     with torch.no_grad():
         outputs = model(image.unsqueeze(0),
-                        input_ids=tokenized["input_ids"],
+                        last_hidden_state=bert_output["last_hidden_state"],
                         attention_mask=tokenized["attention_mask"],
-                        token_type_ids=tokenized["token_type_ids"],
                         position_ids = position_ids,
                         text_self_attention_masks = text_self_attention_masks)
     # outputs = model(image[None], captions=[caption])
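
The net effect of this patch is that the BERT text encoder runs outside the model's forward pass, and the model receives the pre-computed `last_hidden_state` instead of raw token ids, so the encoded caption can be computed once and reused. The sketch below is an illustration of that reuse pattern, not part of the patch: it assumes the patched forward signature shown above and takes the tokenizer outputs built earlier in `predict` (unchanged by this diff) as arguments; the hypothetical helper name `predict_with_cached_text` is ours.

```python
import torch


def predict_with_cached_text(model, images, tokenized, position_ids,
                             text_self_attention_masks):
    """Encode the caption once with model.bert, then reuse it for several images.

    `tokenized`, `position_ids` and `text_self_attention_masks` are the same
    tensors built earlier in `predict`; the keyword arguments mirror the
    patched forward call above. Illustrative sketch only.
    """
    with torch.no_grad():
        # Run the standalone text encoder once, exactly as the patch does.
        tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
        tokenized_for_encoder["attention_mask"] = text_self_attention_masks
        tokenized_for_encoder["position_ids"] = position_ids
        bert_output = model.bert(**tokenized_for_encoder)  # bs, seq_len, 768

        # Reuse the cached text features for every image.
        results = []
        for image in images:
            outputs = model(image.unsqueeze(0),
                            last_hidden_state=bert_output["last_hidden_state"],
                            attention_mask=tokenized["attention_mask"],
                            position_ids=position_ids,
                            text_self_attention_masks=text_self_attention_masks)
            results.append(outputs)
    return results
```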