Update glee_model.py

pull/40/head
Junfeng Wu 2024-05-07 18:06:03 +08:00 committed by GitHub
parent 8dec8b138d
commit a2cd1f8884
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -231,7 +231,7 @@ class GLEE_Model(nn.Module):
if 'spatial' in prompts:
## setp 1,2,3
key_images = [ images ] #bz*[1,3,H,W]
key_images = [images.tensor[kid].unsqueeze(0) for kid in range(len(images.tensor))] #bz*[1,3,H,W]
key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W]
if is_train: