From a74e7fd9912ac60452def1fa89d282bbb2bd2ca5 Mon Sep 17 00:00:00 2001
From: superxf
Date: Mon, 22 Sep 2025 10:33:32 +0800
Subject: [PATCH] opt topk

---
 .../v1/sample/ops/topk_topp_sampler.py        | 38 +++++++++++++++----
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py b/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py
index 694ce11a..9bde914a 100644
--- a/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py
@@ -26,26 +26,48 @@ from mindspore import mint
 
 
 def apply_top_k_top_p_ms(logits, k, p):
     """
-    Apply top-k and top-p masks to the logits for high performance.
-    which is reference from 'apply_top_k_top_p_tpu' in vllm.
+    Apply top-k and top-p masks to the logits.
+
+    This optimized version performs top-p calculations *only* on the
+    top-k elements, significantly reducing sort and cumsum overhead.
     """
-    if k is not None:
-        # use `apply_top_k_only` defined in this file.
-        logits = apply_top_k_only(logits, k)
-    if p is not None:
+    if k is None:
+        if p is None:
+            return logits
+
         probs = logits.softmax(dim=-1)
         probs_sort, _ = mint.sort(probs, dim=-1, descending=False)
         cumprob = mint.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
-
         top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
         top_p_cutoff = probs_sort.gather(-1, top_p_count)
         elements_to_discard = probs < top_p_cutoff
         logits.masked_fill_(elements_to_discard, -float("inf"))
+        return logits
 
-    return logits
+    no_top_k_mask = (k == logits.shape[1])
+    k_for_topk = k.masked_fill(no_top_k_mask, 1)
+    max_top_k = k_for_topk.max()
+    int_max_top_k = max_top_k.item()
+    top_k_logits, top_k_indices = logits.topk(int_max_top_k, dim=-1)
+    indices_to_keep = mint.arange(int_max_top_k).unsqueeze(0)
+    k_mask = indices_to_keep < k.unsqueeze(1)
+    k_mask = k_mask | no_top_k_mask.unsqueeze(1)
+    top_k_logits.masked_fill_(~k_mask, -float("inf"))
+
+    if p is not None:
+        probs_k = top_k_logits.softmax(dim=-1)
+        cumprob_k = mint.cumsum(probs_k, dim=-1)
+        shifted_cumprob_k = cumprob_k - probs_k
+        top_p_mask = shifted_cumprob_k > p.unsqueeze(dim=1)
+        top_k_logits.masked_fill_(top_p_mask, -float("inf"))
+
+    final_logits = mint.full_like(logits, -float("inf"))
+    final_logits.scatter_(dim=-1, index=top_k_indices, src=top_k_logits)
+    final_logits = mint.where(no_top_k_mask.unsqueeze(1), logits, final_logits)
+    return final_logits
 
 
 def random_sample(
-- 
Gitee
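
A minimal usage sketch of the patched function follows; it is not part of the patch itself. It assumes MindSpore and vllm_mindspore are installed with this commit applied, and the batch size, vocabulary size, and the k/p values are made-up illustrations. A per-row k equal to the vocabulary size takes the "no top-k" path in the new code, while p gives each row's nucleus (top-p) threshold.

    # Illustrative sketch only; shapes and values below are hypothetical.
    import numpy as np
    from mindspore import Tensor

    from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_ms

    batch_size, vocab_size = 2, 8
    logits = Tensor(np.random.randn(batch_size, vocab_size).astype(np.float32))

    # Per-request top-k; a value equal to vocab_size disables top-k for that row.
    k = Tensor(np.array([4, vocab_size], dtype=np.int32))
    # Per-request top-p (nucleus) thresholds.
    p = Tensor(np.array([0.9, 0.8], dtype=np.float32))

    masked = apply_top_k_top_p_ms(logits, k, p)
    # Entries outside each row's top-k/top-p nucleus are now -inf.
    print(masked)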