Cherry-pick #1439 to fix 'topk' on different devices for onnxruntime-gpu inference (#1603)

Co-authored-by: grimoire <yaoqian@sensetime.com>
pull/1514/merge
hanrui1sensetime 2023-01-04 23:10:19 +08:00 committed by GitHub
parent c67e2db68e
commit 8a05b8d62d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 2 deletions

View File

@ -17,7 +17,7 @@ def topk__dynamic(input: torch.Tensor,
sorted: bool = True):
"""Rewrite `topk` for default backend.
Cast k to tensor and makesure k is smaller than input.shape[dim].
Cast k to tensor and make sure k is smaller than input.shape[dim].
"""
ctx = FUNCTION_REWRITER.get_context()
@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor,
k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input
if isinstance(size, torch.Tensor):
size = size.to(input.device)
# size would be treated as cpu tensor, trick to avoid that.
size = k.new_zeros(()) + size
k = torch.where(k < size, k, size)
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)