Cherry-pick #1439 to fix 'topk' on different devices for onnxruntime-gpu inference (#1603)

Co-authored-by: grimoire <yaoqian@sensetime.com>
This commit is contained in:
hanrui1sensetime 2023-01-04 23:10:19 +08:00 committed by GitHub
parent c67e2db68e
commit 8a05b8d62d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor,
     k = torch.tensor(k, device=input.device, dtype=torch.long)
     # Always keep topk op for dynamic input
     if isinstance(size, torch.Tensor):
-        size = size.to(input.device)
+        # size would be treated as cpu tensor, trick to avoid that.
+        size = k.new_zeros(()) + size
     k = torch.where(k < size, k, size)
     return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)