From 36d06f7c5df5b7b3f8cb4c82be056b119f3d4ae2 Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Mon, 9 Dec 2024 09:57:46 -0800
Subject: [PATCH 1/2] Optimize for topk=1 case if we do not handle duplicates

---
 vllm/model_executor/layers/sampler.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 8aa6646c5dcea..d148c0e4e0298 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -405,6 +405,14 @@ def __init__(self, increment: int):
         self._increment = increment
 
     def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k==1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                -float("inf"),
+                device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
         if k > ApplyToppTopkScalar._padded_k:
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])

From be5f32b0fec2216a7b9425188bbaa798c67523b0 Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Tue, 10 Dec 2024 16:31:15 -0800
Subject: [PATCH 2/2] style

---
 vllm/model_executor/layers/sampler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index d148c0e4e0298..6af14c4cf6aee 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -405,10 +405,10 @@ def __init__(self, increment: int):
         self._increment = increment
 
     def __call__(self, logits: torch.Tensor, p: float, k: int):
-        if k==1 and not ApplyToppTopkScalar._handle_duplicates:
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
             new_logits = torch.full(logits.shape,
-                -float("inf"),
-                device=logits.device)
+                                    -float("inf"),
+                                    device=logits.device)
             vals, idx = torch.max(logits, keepdim=True, dim=1)
             new_logits.scatter_(1, idx,
                                 vals.to(new_logits.dtype))
             return new_logits