This commit is contained in:
q.yao 2022-11-28 17:35:15 +08:00 committed by GitHub
parent 9ea8610133
commit b521e7da03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ def topk__dynamic(ctx,
sorted: bool = True):
"""Rewrite `topk` for default backend.
Cast k to tensor and makesure k is smaller than input.shape[dim].
Cast k to tensor and make sure k is smaller than input.shape[dim].
"""
if dim is None:
@ -28,7 +28,8 @@ def topk__dynamic(ctx,
k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input
if isinstance(size, torch.Tensor):
size = size.to(input.device)
# size would be treated as cpu tensor, trick to avoid that.
size = k.new_zeros(()) + size
k = torch.where(k < size, k, size)
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)