diff --git a/projects/GLEE/glee/models/glee_model.py b/projects/GLEE/glee/models/glee_model.py index a433d88..1d69d36 100644 --- a/projects/GLEE/glee/models/glee_model.py +++ b/projects/GLEE/glee/models/glee_model.py @@ -231,7 +231,7 @@ class GLEE_Model(nn.Module): if 'spatial' in prompts: ## setp 1,2,3 - key_images = [ images ] #bz*[1,3,H,W] + key_images = [images.tensor[kid].unsqueeze(0) for kid in range(len(images.tensor))] #bz*[1,3,H,W] key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W] if is_train: