diff --git a/demo/inference_on_a_image.py b/demo/inference_on_a_image.py
index 8c91899..62546d7 100644
--- a/demo/inference_on_a_image.py
+++ b/demo/inference_on_a_image.py
@@ -108,7 +108,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     # build pred
     pred_phrases = []
     for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, caption)
+        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
         if with_logits:
             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
         else:
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 72c84ed..73d3b9c 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -71,7 +71,7 @@ def predict(
     tokenized = tokenizer(caption)

     phrases = [
-        get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
+        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
         for logit in logits
     ]

diff --git a/groundingdino/util/utils.py b/groundingdino/util/utils.py
index c0d4268..e9f0318 100644
--- a/groundingdino/util/utils.py
+++ b/groundingdino/util/utils.py
@@ -7,6 +7,7 @@ from typing import Any, Dict, List

 import numpy as np
 import torch
+from transformers import AutoTokenizer

 from groundingdino.util.slconfig import SLConfig

@@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
     ]


-def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str):
+def get_phrases_from_posmap(
+    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
+):
     assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
     if posmap.dim() == 1:
         non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
-        words_list = caption.split()
-
-        # build word idx list
-        words_idx_used_list = []
-        for idx in non_zero_idx:
-            word_idx = tokenlized.token_to_word(idx)
-            if word_idx is not None:
-                words_idx_used_list.append(word_idx)
-        words_idx_used_list = set(words_idx_used_list)
-
-        # build phrase
-        words_used_list = []
-        for idx, word in enumerate(words_list):
-            if idx in words_idx_used_list:
-                words_used_list.append(word)
-
-        sentence_res = " ".join(words_used_list)
-        return sentence_res
+        token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
+        return tokenizer.decode(token_ids)
     else:
         raise NotImplementedError("posmap must be 1-dim")
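
For reference, a minimal sketch of how the revised helper is called after this change. The caption, the posmap contents, and the token positions below are fabricated purely for illustration; the tokenizer checkpoint is assumed to be the bert-base-uncased model that GroundingDINO loads as its default text encoder.

import torch
from transformers import AutoTokenizer

from groundingdino.util.utils import get_phrases_from_posmap

# Assumption: bert-base-uncased is the text encoder in use.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

caption = "a cat . a remote control ."
tokenized = tokenizer(caption)  # BatchEncoding with "input_ids" for the caption

# Fabricated posmap for illustration: mark token positions 1-2
# ("a", "cat"; position 0 is [CLS]) as having cleared the text threshold.
num_tokens = len(tokenized["input_ids"])
posmap = torch.zeros(num_tokens, dtype=torch.bool)
posmap[1:3] = True

# The helper now decodes the selected token ids directly instead of
# mapping them back to whitespace-split words of the caption.
phrase = get_phrases_from_posmap(posmap, tokenized, tokenizer)
print(phrase)  # -> "a cat"

Decoding the selected token ids through the tokenizer keeps sub-word pieces intact, whereas the old word-index path could misattribute positions for captions whose words split into multiple tokens.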