diff --git a/vllm_mindspore/model_executor/layers/utils.py b/vllm_mindspore/model_executor/layers/utils.py index 3081182f0622ee424c7df22bd7bacc6550b2c929..475809788a7856c4248da4226b569b0425030715 100644 --- a/vllm_mindspore/model_executor/layers/utils.py +++ b/vllm_mindspore/model_executor/layers/utils.py @@ -36,6 +36,23 @@ def get_token_bin_counts_and_mask( return bin_counts, mask +def get_repetition_penalties_mask( + prompt_tokens: ms.Tensor, + output_tokens: ms.Tensor, + vocab_size: int, + num_seqs: int, +) -> ms.Tensor: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = ms.mint.zeros((num_seqs, vocab_size + 1), dtype=ms.int64) + bin_counts.scatter_add_(1, prompt_tokens, ms.mint.ones_like(prompt_tokens)) + bin_counts.scatter_add_(1, output_tokens, ms.mint.ones_like(output_tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return mask + + def apply_penalties(logits: ms.Tensor, prompt_tokens_tensor: ms.Tensor, output_tokens_tensor: ms.Tensor, presence_penalties: ms.Tensor, @@ -57,26 +74,34 @@ def apply_penalties(logits: ms.Tensor, prompt_tokens_tensor: ms.Tensor, if logits.numel() <= 0: return logits num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) - output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - # use 'broadcast_to' to replace 'tensor.repeat' to improve performance - # when tensor shape is (num,seqs, 1), then 'tensor.repeat(1, vocab_size)' - # is equal to 'broadcast_to(tensor, (num_seqs, vocab_size))' - repetition_penalties = ms.mint.broadcast_to( - repetition_penalties.unsqueeze(dim=1), (num_seqs, vocab_size)) + if repetition_penalties is not None: + mask = get_repetition_penalties_mask( + prompt_tokens_tensor, + output_tokens_tensor, + vocab_size, + num_seqs, + ) + # use 'broadcast_to' to replace 'tensor.repeat' to improve performance + # when tensor 
shape is (num_seqs, 1), 'tensor.repeat(1, vocab_size)' + # is equal to 'broadcast_to(tensor, (num_seqs, vocab_size))' + repetition_penalties = ms.mint.broadcast_to( + repetition_penalties.unsqueeze(dim=1), (num_seqs, vocab_size)) - # use out of place computation instead of inplace setitem to improve - # performance 'tensor[tensor > 0]' will result in setitem, which is slow. - mask = prompt_mask | output_mask - logits = ms.mint.where(mask & (logits > 0), logits / repetition_penalties, - logits) - logits = ms.mint.where(mask & (logits <= 0), logits * repetition_penalties, - logits) + # use out of place computation instead of inplace setitem to improve + # performance 'tensor[tensor > 0]' will result in setitem, + # which is slow. + logits = ms.mint.where(mask & (logits > 0), + logits / repetition_penalties, logits) + logits = ms.mint.where(mask & (logits <= 0), + logits * repetition_penalties, logits) # We follow the definition in OpenAI API. - # Refer to https://platform.openai.com/docs/api-reference/parameter-details + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + if frequency_penalties is not None or presence_penalties is not None: + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + if frequency_penalties is not None: + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + if presence_penalties is not None: + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits diff --git a/vllm_mindspore/v1/worker/gpu_input_batch.py b/vllm_mindspore/v1/worker/gpu_input_batch.py index f23dc4df53d6f536505b75964b3497ff6503b58e..0ed0cb125dbd339c61d86e08bcf07d2e764bf598 100644 --- a/vllm_mindspore/v1/worker/gpu_input_batch.py +++ b/vllm_mindspore/v1/worker/gpu_input_batch.py @@ -44,24 +44,33 @@ def _make_sampling_metadata(self)
-> SamplingMetadata: _copy_slice_from_np(self.top_k_cpu, self.top_k, num_reqs) if not self.no_min_p: _copy_slice_from_np(self.min_p_cpu, self.min_p, num_reqs) - + frequency_penalties = None + presence_penalties = None + repetition_penalties = None + prompt_token_ids = None if not self.no_penalties: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. - _copy_slice_from_np(self.frequency_penalties_cpu, - self.frequency_penalties, num_reqs) - _copy_slice_from_np(self.presence_penalties_cpu, - self.presence_penalties, num_reqs) - _copy_slice_from_np(self.repetition_penalties_cpu, - self.repetition_penalties, num_reqs) - - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. + apply_freq = not np.all(self.frequency_penalties_cpu[:num_reqs] == 0.0) + apply_pres = not np.all(self.presence_penalties_cpu[:num_reqs] == 0.0) + apply_rep = not np.all(self.repetition_penalties_cpu[:num_reqs] == 1.0) prompt_token_ids = self._make_prompt_token_ids_tensor() - else: - prompt_token_ids = None + + if apply_freq: + _copy_slice_from_np(self.frequency_penalties_cpu, + self.frequency_penalties, num_reqs) + frequency_penalties = self.frequency_penalties[:num_reqs] + + if apply_pres: + _copy_slice_from_np(self.presence_penalties_cpu, + self.presence_penalties, num_reqs) + presence_penalties = self.presence_penalties[:num_reqs] + + if apply_rep: + _copy_slice_from_np(self.repetition_penalties_cpu, + self.repetition_penalties, num_reqs) + repetition_penalties = self.repetition_penalties[:num_reqs] allowed_token_ids_mask: Optional[Tensor] = None if not self.no_allowed_token_ids: @@ -82,9 +91,9 @@ def _make_sampling_metadata(self) -> SamplingMetadata: generators=self.generators, max_num_logprobs=self.max_num_logprobs, prompt_token_ids=prompt_token_ids, - 
frequency_penalties=self.frequency_penalties[:num_reqs], - presence_penalties=self.presence_penalties[:num_reqs], - repetition_penalties=self.repetition_penalties[:num_reqs], + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, output_token_ids=cast(list[list[int]], self.req_output_token_ids), min_tokens=self.min_tokens, no_penalties=self.no_penalties,