From 0c2931b8cd0d4d76e7c598ab071e3a831a0d5725 Mon Sep 17 00:00:00 2001
From: SkalskiP <piotr.skalski92@gmail.com>
Date: Tue, 28 Mar 2023 00:55:25 +0200
Subject: [PATCH] Test fix for #11

---
 demo/inference_on_a_image.py    |  2 +-
 groundingdino/util/inference.py |  2 +-
 groundingdino/util/utils.py     | 25 ++++++-------------------
 3 files changed, 8 insertions(+), 21 deletions(-)

diff --git a/demo/inference_on_a_image.py b/demo/inference_on_a_image.py
index 8c91899..62546d7 100644
--- a/demo/inference_on_a_image.py
+++ b/demo/inference_on_a_image.py
@@ -108,7 +108,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     # build pred
     pred_phrases = []
     for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, caption)
+        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
         if with_logits:
             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
         else:
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 72c84ed..73d3b9c 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -71,7 +71,7 @@ def predict(
     tokenized = tokenizer(caption)
 
     phrases = [
-        get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
+        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
         for logit
         in logits
     ]
diff --git a/groundingdino/util/utils.py b/groundingdino/util/utils.py
index c0d4268..e9f0318 100644
--- a/groundingdino/util/utils.py
+++ b/groundingdino/util/utils.py
@@ -7,6 +7,7 @@ from typing import Any, Dict, List
 
 import numpy as np
 import torch
+from transformers import AutoTokenizer
 
 from groundingdino.util.slconfig import SLConfig
 
@@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
     ]
 
 
-def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str):
+def get_phrases_from_posmap(
+    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
+):
     assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
     if posmap.dim() == 1:
         non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
-        words_list = caption.split()
-
-        # build word idx list
-        words_idx_used_list = []
-        for idx in non_zero_idx:
-            word_idx = tokenlized.token_to_word(idx)
-            if word_idx is not None:
-                words_idx_used_list.append(word_idx)
-        words_idx_used_list = set(words_idx_used_list)
-
-        # build phrase
-        words_used_list = []
-        for idx, word in enumerate(words_list):
-            if idx in words_idx_used_list:
-                words_used_list.append(word)
-
-        sentence_res = " ".join(words_used_list)
-        return sentence_res
+        token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
+        return tokenizer.decode(token_ids)
     else:
         raise NotImplementedError("posmap must be 1-dim")