support phrase grounding mode
parent
beeb4c29cb
commit
a0cc07e12f
Binary file not shown.
After Width: | Height: | Size: 120 KiB |
18
README.md
18
README.md
|
@ -151,13 +151,27 @@ nvidia-smi
|
|||
Replace `{GPU ID}`, `image_you_want_to_detect.jpg`, and `"dir you want to save the output"` with appropriate values in the following command
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
|
||||
-c /GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||
-p /GroundingDINO/weights/groundingdino_swint_ogc.pth \
|
||||
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||
-p weights/groundingdino_swint_ogc.pth \
|
||||
-i image_you_want_to_detect.jpg \
|
||||
-o "dir you want to save the output" \
|
||||
-t "chair"
|
||||
[--cpu-only] # open it for cpu mode
|
||||
```
|
||||
|
||||
If you would like to specify the phrases to detect, here is a demo:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
|
||||
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
||||
-p /comp_robot/liushilong/data/pretrained/grounding_pretrain/groundingdino_swint_ogc.pth \
|
||||
-i .asset/cat_dog.jpeg \
|
||||
-o logs/1111 \
|
||||
-t "There is a cat and a dog in the image ." \
|
||||
--token_spans "[[[9, 10], [11, 14]], [[19, 20], [21, 24]]]"
|
||||
[--cpu-only] # open it for cpu mode
|
||||
```
|
||||
The token_spans specify the start and end positions of a phrases. For example, the first phrase is `[[9, 10], [11, 14]]`. `"There is a cat and a dog in the image ."[9:10] = 'a'`, `"There is a cat and a dog in the image ."[11:14] = 'cat'`. Hence it refere to the phrase `a cat` .
|
||||
|
||||
See the `demo/inference_on_a_image.py` for more details.
|
||||
|
||||
**Running with Python:**
|
||||
|
|
|
@ -11,6 +11,7 @@ from groundingdino.models import build_model
|
|||
from groundingdino.util import box_ops
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
||||
from groundingdino.util.vl_utils import create_positive_map_from_span
|
||||
|
||||
|
||||
def plot_boxes_to_image(image_pil, tgt):
|
||||
|
@ -80,7 +81,8 @@ def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
|
|||
return model
|
||||
|
||||
|
||||
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
|
||||
def get_grounding_output(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False, token_spans=None):
|
||||
assert text_threshold is not None or token_spans is not None, "text_threshould and token_spans should not be None at the same time!"
|
||||
caption = caption.lower()
|
||||
caption = caption.strip()
|
||||
if not caption.endswith("."):
|
||||
|
@ -90,29 +92,56 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
|
|||
image = image.to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
||||
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
||||
logits.shape[0]
|
||||
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
|
||||
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
||||
|
||||
# filter output
|
||||
logits_filt = logits.clone()
|
||||
boxes_filt = boxes.clone()
|
||||
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
||||
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
||||
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
||||
logits_filt.shape[0]
|
||||
if token_spans is None:
|
||||
logits_filt = logits.cpu().clone()
|
||||
boxes_filt = boxes.cpu().clone()
|
||||
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
||||
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
||||
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
||||
|
||||
# get phrase
|
||||
tokenlizer = model.tokenizer
|
||||
tokenized = tokenlizer(caption)
|
||||
# build pred
|
||||
pred_phrases = []
|
||||
for logit, box in zip(logits_filt, boxes_filt):
|
||||
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
||||
if with_logits:
|
||||
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
||||
else:
|
||||
pred_phrases.append(pred_phrase)
|
||||
else:
|
||||
# given-phrase mode
|
||||
positive_maps = create_positive_map_from_span(
|
||||
model.tokenizer(text_prompt),
|
||||
token_span=token_spans
|
||||
).to(image.device) # n_phrase, 256
|
||||
|
||||
logits_for_phrases = positive_maps @ logits.T # n_phrase, nq
|
||||
all_logits = []
|
||||
all_phrases = []
|
||||
all_boxes = []
|
||||
for (token_span, logit_phr) in zip(token_spans, logits_for_phrases):
|
||||
# get phrase
|
||||
phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span])
|
||||
# get mask
|
||||
filt_mask = logit_phr > box_threshold
|
||||
# filt box
|
||||
all_boxes.append(boxes[filt_mask])
|
||||
# filt logits
|
||||
all_logits.append(logit_phr[filt_mask])
|
||||
if with_logits:
|
||||
logit_phr_num = logit_phr[filt_mask]
|
||||
all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num])
|
||||
else:
|
||||
all_phrases.extend([phrase for _ in range(len(filt_mask))])
|
||||
boxes_filt = torch.cat(all_boxes, dim=0).cpu()
|
||||
pred_phrases = all_phrases
|
||||
|
||||
# get phrase
|
||||
tokenlizer = model.tokenizer
|
||||
tokenized = tokenlizer(caption)
|
||||
# build pred
|
||||
pred_phrases = []
|
||||
for logit, box in zip(logits_filt, boxes_filt):
|
||||
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
||||
if with_logits:
|
||||
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
||||
else:
|
||||
pred_phrases.append(pred_phrase)
|
||||
|
||||
return boxes_filt, pred_phrases
|
||||
|
||||
|
@ -132,6 +161,12 @@ if __name__ == "__main__":
|
|||
|
||||
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
|
||||
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
|
||||
parser.add_argument("--token_spans", type=str, default=None, help=
|
||||
"The positions of start and end positions of phrases of interest. \
|
||||
For example, a caption is 'a cat and a dog', \
|
||||
if you would like to detect 'cat', the token_spans should be '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'. \
|
||||
if you would like to detect 'a cat', the token_spans should be '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a', and 'a cat and a dog'[2:5] is 'cat'. \
|
||||
")
|
||||
|
||||
parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
|
||||
args = parser.parse_args()
|
||||
|
@ -144,6 +179,7 @@ if __name__ == "__main__":
|
|||
output_dir = args.output_dir
|
||||
box_threshold = args.box_threshold
|
||||
text_threshold = args.text_threshold
|
||||
token_spans = args.token_spans
|
||||
|
||||
# make dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
@ -155,9 +191,15 @@ if __name__ == "__main__":
|
|||
# visualize raw image
|
||||
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
|
||||
|
||||
# set the text_threshold to None if token_spans is set.
|
||||
if token_spans is not None:
|
||||
text_threshold = None
|
||||
print("Using token_spans. Set the text_threshold to None.")
|
||||
|
||||
|
||||
# run model
|
||||
boxes_filt, pred_phrases = get_grounding_output(
|
||||
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
|
||||
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, token_spans=eval(token_spans)
|
||||
)
|
||||
|
||||
# visualize pred
|
||||
|
|
Loading…
Reference in New Issue