mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix topk (#1439)
This commit is contained in:
parent
9ea8610133
commit
b521e7da03
@ -18,7 +18,7 @@ def topk__dynamic(ctx,
|
||||
sorted: bool = True):
|
||||
"""Rewrite `topk` for default backend.
|
||||
|
||||
Cast k to tensor and makesure k is smaller than input.shape[dim].
|
||||
Cast k to tensor and make sure k is smaller than input.shape[dim].
|
||||
"""
|
||||
|
||||
if dim is None:
|
||||
@ -28,7 +28,8 @@ def topk__dynamic(ctx,
|
||||
k = torch.tensor(k, device=input.device, dtype=torch.long)
|
||||
# Always keep topk op for dynamic input
|
||||
if isinstance(size, torch.Tensor):
|
||||
size = size.to(input.device)
|
||||
# size would be treated as cpu tensor, trick to avoid that.
|
||||
size = k.new_zeros(()) + size
|
||||
k = torch.where(k < size, k, size)
|
||||
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user