pull/40/merge
Jack Kolb 2024-10-21 15:34:41 +08:00 committed by GitHub
commit 456c25b8ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 4 deletions

7
app.py
View File

@ -166,14 +166,13 @@ def segment_image(img,prompt_mode, categoryname, custom_category, expressiong, r
# prompt_list = [mask_ori[0]]
prompt_list = []
with torch.no_grad():
(outputs,_) = GLEEmodel(infer_image, prompt_list, task="coco", batch_name_list=batch_category_name, is_train=False)
(outputs,_,_) = GLEEmodel(infer_image, prompt_list, task="coco", batch_name_list=batch_category_name, is_train=False)
topK_instance = max(num_inst_select,1)
else:
topK_instance = 1
prompt_list = {'grounding':[expressiong]}
with torch.no_grad():
(outputs,_) = GLEEmodel(infer_image, prompt_list, task="grounding", batch_name_list=[], is_train=False)
(outputs,_,_) = GLEEmodel(infer_image, prompt_list, task="grounding", batch_name_list=[], is_train=False)
mask_pred = outputs['pred_masks'][0]
mask_cls = outputs['pred_logits'][0]
@ -309,7 +308,7 @@ def segment_image(img,prompt_mode, categoryname, custom_category, expressiong, r
prompt_list = {'spatial':[visual_prompt]}
with torch.no_grad():
(outputs,_) = GLEEmodel(infer_image, prompt_list, task="coco", batch_name_list=['object'], is_train=False, visual_prompt_type=prompt_mode )
(outputs,_,_) = GLEEmodel(infer_image, prompt_list, task="coco", batch_name_list=['object'], is_train=False, visual_prompt_type=prompt_mode )
mask_pred = outputs['pred_masks'][0]
mask_cls = outputs['pred_logits'][0]