commit 8a05b8d62d (pull/1514/merge)
parent c67e2db68e
Co-authored-by: grimoire <yaoqian@sensetime.com>
@@ -17,7 +17,7 @@ def topk__dynamic(input: torch.Tensor,
                   sorted: bool = True):
     """Rewrite `topk` for default backend.
 
-    Cast k to tensor and makesure k is smaller than input.shape[dim].
+    Cast k to tensor and make sure k is smaller than input.shape[dim].
     """
     ctx = FUNCTION_REWRITER.get_context()
 
@@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor,
         k = torch.tensor(k, device=input.device, dtype=torch.long)
     # Always keep topk op for dynamic input
     if isinstance(size, torch.Tensor):
-        size = size.to(input.device)
+        # size would be treated as cpu tensor, trick to avoid that.
+        size = k.new_zeros(()) + size
     k = torch.where(k < size, k, size)
     return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
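For readers outside the mmdeploy codebase, below is a minimal, self-contained sketch of the clamping trick this patch introduces. It assumes plain PyTorch only: the FUNCTION_REWRITER context and origin_func dispatch are mmdeploy machinery omitted here, and the helper name clamp_topk_k is invented for illustration, not part of the patch.

import torch

def clamp_topk_k(input: torch.Tensor, k: int, dim: int = -1) -> torch.Tensor:
    # Mirror of the patched logic: cast k to a tensor and make sure it is
    # no larger than input.shape[dim].
    size = input.shape[dim]
    if not isinstance(k, torch.Tensor):
        k = torch.tensor(k, device=input.device, dtype=torch.long)
    if isinstance(size, torch.Tensor):
        # Under tracing, the patch notes size "would be treated as cpu
        # tensor"; adding it to a 0-dim tensor made from k is the trick
        # that keeps the result on k's device and in the traced graph.
        size = k.new_zeros(()) + size
    return torch.where(k < size, k, size)

x = torch.rand(3, 100)
k = clamp_topk_k(x, 200)            # 200 requested, clamped to 100
values, indices = x.topk(int(k))    # eager topk still wants a Python int
assert values.shape[-1] == 100

The design point is that `k.new_zeros(()) + size` moves size next to k via ordinary tensor arithmetic, which tracing records as a graph op, whereas the replaced `size.to(input.device)` left size behaving as a CPU tensor during export.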