mirror of
https://github.com/facebookresearch/segment-anything.git
synced 2025-06-03 14:59:27 +08:00
fixing image_encoder to work with cuda_graphs
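Summary: get_rel_pos combined tensors on multiple devices (the torch.arange coordinate tensors were created without a device argument, while rel_pos lives on its own device), which was preventing CUDA graphs from correctly optimizing the image encoder. The coordinates are now created on rel_pos.device. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 2256f130bb8249403710e1048ef69385ff71aed2 Pull Request resolved: https://github.com/facebookresearch/segment-anything/pull/393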
This commit is contained in:
parent
6fdee8f272
commit
43c910f431
@ -315,8 +315,8 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
q_coords = (torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0))
|
||||
k_coords = (torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0))
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
Loading…
x
Reference in New Issue
Block a user