parent 858efccbad
commit c974f60d73
@@ -108,7 +108,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     # build pred
     pred_phrases = []
     for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, caption)
+        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
         if with_logits:
             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
         else:
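Note: `tokenlizer` (spelled that way in the demo code) replaces the raw `caption` string as the third argument. For context, a minimal sketch of the expected call-site setup; the `model.tokenizer` attribute access is an assumption based on the surrounding demo code, not something shown in this diff:

    # Sketch under the assumptions above, not verbatim from this commit.
    tokenlizer = model.tokenizer     # BERT-style tokenizer used for the caption
    tokenized = tokenlizer(caption)  # BatchEncoding carrying "input_ids" per token
    # `logit` scores every text token; thresholding it yields the posmap
    pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)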
@@ -71,7 +71,7 @@ def predict(
     tokenized = tokenizer(caption)

     phrases = [
-        get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
+        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
         for logit
         in logits
     ]
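Note: GroundingDINO prompts separate phrases with periods, and `tokenizer.decode` will re-emit a `.` token whose position clears `text_threshold`, so the pre-existing `.replace('.', '')` cleanup still applies after this change (e.g., a posmap covering "dog" plus a following separator decodes to "dog ." before the replace).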
@@ -7,6 +7,7 @@ from typing import Any, Dict, List

 import numpy as np
 import torch
+from transformers import AutoTokenizer

 from groundingdino.util.slconfig import SLConfig
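Note: the new `AutoTokenizer` import is only used to type-annotate the reworked `get_phrases_from_posmap` signature in the hunk below.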
@@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
     ]


-def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str):
+def get_phrases_from_posmap(
+    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
+):
     assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
     if posmap.dim() == 1:
         non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
-        words_list = caption.split()
-
-        # build word idx list
-        words_idx_used_list = []
-        for idx in non_zero_idx:
-            word_idx = tokenlized.token_to_word(idx)
-            if word_idx is not None:
-                words_idx_used_list.append(word_idx)
-        words_idx_used_list = set(words_idx_used_list)
-
-        # build phrase
-        words_used_list = []
-        for idx, word in enumerate(words_list):
-            if idx in words_idx_used_list:
-                words_used_list.append(word)
-
-        sentence_res = " ".join(words_used_list)
-        return sentence_res
+        token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
+        return tokenizer.decode(token_ids)
     else:
         raise NotImplementedError("posmap must be 1-dim")
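The rewrite drops the old `token_to_word` / `caption.split()` round-trip, whose word indices could drift out of alignment whenever the tokenizer's pre-tokenization (which treats punctuation as separate words) disagreed with whitespace splitting; decoding the selected `input_ids` directly also keeps sub-word pieces intact. A minimal runnable sketch of the new path, assuming a stock `bert-base-uncased` tokenizer in place of the model's own, with a made-up caption and posmap:

    # Assumptions: stock HF tokenizer, hand-picked token indices for illustration.
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenized = tokenizer("a running dog .")
    # token sequence: [CLS] a running dog . [SEP]

    posmap = torch.zeros(len(tokenized["input_ids"]), dtype=torch.bool)
    posmap[2] = posmap[3] = True  # pretend the model fired on "running" and "dog"

    non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
    token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
    print(tokenizer.decode(token_ids))  # -> "running dog"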