Co-authored-by: grimoire <yaoqian@sensetime.com>pull/1514/merge
parent
c67e2db68e
commit
8a05b8d62d
|
@ -17,7 +17,7 @@ def topk__dynamic(input: torch.Tensor,
|
|||
sorted: bool = True):
|
||||
"""Rewrite `topk` for default backend.
|
||||
|
||||
Cast k to tensor and makesure k is smaller than input.shape[dim].
|
||||
Cast k to tensor and make sure k is smaller than input.shape[dim].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
|
@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor,
|
|||
k = torch.tensor(k, device=input.device, dtype=torch.long)
|
||||
# Always keep topk op for dynamic input
|
||||
if isinstance(size, torch.Tensor):
|
||||
size = size.to(input.device)
|
||||
# size would be treated as cpu tensor, trick to avoid that.
|
||||
size = k.new_zeros(()) + size
|
||||
k = torch.where(k < size, k, size)
|
||||
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
|
||||
|
||||
|
|
Loading…
Reference in New Issue