mirror of https://github.com/FoundationVision/GLEE
Update glee_model.py
parent
8dec8b138d
commit
a2cd1f8884
|
@ -231,7 +231,7 @@ class GLEE_Model(nn.Module):
|
|||
|
||||
if 'spatial' in prompts:
|
||||
## setp 1,2,3
|
||||
key_images = [ images ] #bz*[1,3,H,W]
|
||||
key_images = [images.tensor[kid].unsqueeze(0) for kid in range(len(images.tensor))] #bz*[1,3,H,W]
|
||||
key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W]
|
||||
|
||||
if is_train:
|
||||
|
|
Loading…
Reference in New Issue