From a2cd1f88846c9faa115ec89b652716183f06c7d3 Mon Sep 17 00:00:00 2001 From: Junfeng Wu <44546837+wjf5203@users.noreply.github.com> Date: Tue, 7 May 2024 18:06:03 +0800 Subject: [PATCH] Update glee_model.py --- projects/GLEE/glee/models/glee_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/GLEE/glee/models/glee_model.py b/projects/GLEE/glee/models/glee_model.py index a433d88..1d69d36 100644 --- a/projects/GLEE/glee/models/glee_model.py +++ b/projects/GLEE/glee/models/glee_model.py @@ -231,7 +231,7 @@ class GLEE_Model(nn.Module): if 'spatial' in prompts: ## setp 1,2,3 - key_images = [ images ] #bz*[1,3,H,W] + key_images = [images.tensor[kid].unsqueeze(0) for kid in range(len(images.tensor))] #bz*[1,3,H,W] key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W] if is_train: