diff --git a/mmdeploy/pytorch/functions/topk.py b/mmdeploy/pytorch/functions/topk.py index 38dac1978..1c230db78 100644 --- a/mmdeploy/pytorch/functions/topk.py +++ b/mmdeploy/pytorch/functions/topk.py @@ -17,7 +17,7 @@ def topk__dynamic(input: torch.Tensor, sorted: bool = True): """Rewrite `topk` for default backend. - Cast k to tensor and makesure k is smaller than input.shape[dim]. + Cast k to tensor and make sure k is smaller than input.shape[dim]. """ ctx = FUNCTION_REWRITER.get_context() @@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor, k = torch.tensor(k, device=input.device, dtype=torch.long) # Always keep topk op for dynamic input if isinstance(size, torch.Tensor): - size = size.to(input.device) + # size would be treated as cpu tensor, trick to avoid that. + size = k.new_zeros(()) + size k = torch.where(k < size, k, size) return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)