From 49474ffceb5e23dace71fa7cb2d11b563a521d77 Mon Sep 17 00:00:00 2001
From: liu lili
Date: Fri, 12 Sep 2025 10:40:38 +0800
Subject: [PATCH] lll: vllm-mindspore adapt mindspore aclgraph capture

---
 vllm_mindspore/__init__.py                   |   4 +
 vllm_mindspore/config.py                     |   5 -
 .../v1/attention/backends/ms_attn.py         |  57 +++++-
 vllm_mindspore/v1/worker/gpu_model_runner.py | 165 ++++++++++++++++++
 vllm_mindspore/v1/worker/gpu_worker.py       |   5 +
 5 files changed, 227 insertions(+), 9 deletions(-)

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index dbd26f9bf..f0de61ae1 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -432,6 +432,10 @@ from vllm_mindspore.forward_context import set_forward_context
 vllm.v1.worker.gpu_model_runner.GPUModelRunner.set_forward_context = (
     set_forward_context)
 
+from vllm_mindspore.v1.worker.gpu_model_runner import capture_model
+
+vllm.v1.worker.gpu_model_runner.GPUModelRunner.capture_model = capture_model
+
 import vllm.v1.worker.block_table
 from vllm_mindspore.v1.worker.block_table import BlockTable
 
diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py
index 7f93164b6..44ca81e1b 100644
--- a/vllm_mindspore/config.py
+++ b/vllm_mindspore/config.py
@@ -98,11 +98,6 @@ def vllm_config_post_init(self):
         self.compilation_config.cudagraph_num_of_warmups = 1
         self.compilation_config.pass_config.enable_fusion = False
         self.compilation_config.pass_config.enable_noop = False
-        # When level is set to CompilationLevel.PIECEWISE, vllm will use cuda
-        # graph, which means the model inputs will be padded to cuda graph
-        # acceptable size, but it is not for mindspore.
-        # So here set to CompilationLevel.DYNAMO_AS_IS.
-        self.compilation_config.level = CompilationLevel.DYNAMO_AS_IS
         # Set a small compile_sizes for warmup. '20' is not in
         # 'cudagraph_capture_sizes'. So the warmup can be run.
         self.compilation_config.compile_sizes = [20]
diff --git a/vllm_mindspore/v1/attention/backends/ms_attn.py b/vllm_mindspore/v1/attention/backends/ms_attn.py
index d3ff6c8b0..1fab2644f 100644
--- a/vllm_mindspore/v1/attention/backends/ms_attn.py
+++ b/vllm_mindspore/v1/attention/backends/ms_attn.py
@@ -26,6 +26,7 @@ import numpy as np
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.logger import init_logger
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
@@ -178,17 +179,65 @@ class MsAttentionImpl(AttentionImpl):
 
 class MsAttentionMetadataBuilder:
 
-    def __init__(self, runner, kv_cache_spec, block_table):
-        self.runner = runner
+    def __init__(self, runner: GPUModelRunner, kv_cache_spec, block_table):
+        self.runner: GPUModelRunner = runner
         self.block_table = block_table
 
     def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+    def build(self,
+              num_reqs: int,
+              num_actual_tokens: int,
+              max_query_len: int,
+              common_prefix_len: int,
+              aclgraph_pad_size: int = -1):
         # do not manually call 'tensor.move_to("Ascend", blocking=False)' here,
         # because it will cause a certain amount of host time.
+
+        if aclgraph_pad_size != -1:
+            # Pad the attn_metadata seq_lens, block_tables and slot_mapping
+            # for MindSpore: the MindSpore aclgraph must match the captured
+            # graph input shapes, while in native vLLM the attn_metadata only
+            # feeds attention, which is not captured by the cuda graph.
+            # Currently only the pure-decode case can enable aclgraph.
+            assert num_reqs == num_actual_tokens
+            pad_num_tokens = num_reqs + aclgraph_pad_size
+
+            query_start_loc = ms.from_numpy(
+                self.runner.query_start_loc_np[:pad_num_tokens + 1])
+            max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:
+                pad_num_tokens].max()
+            num_prompt_tokens = ms.from_numpy(
+                self.runner.input_batch.num_prompt_tokens[:pad_num_tokens])
+            slot_mapping = ms.from_numpy(
+                self.block_table.slot_mapping_np[:pad_num_tokens])
+            seq_lens_np = self.runner.seq_lens_np[:pad_num_tokens]
+            max_seq_len = seq_lens_np.max()
+            seq_lens = ms.from_numpy(seq_lens_np)
+            context_lens = ms.from_numpy(
+                self.runner.input_batch.
+                num_computed_tokens_cpu[:pad_num_tokens])
+            q_seq_lens_np = np.diff(
+                self.runner.query_start_loc_np[:pad_num_tokens + 1])
+
+            attn_metadata = MsAttentionMetadata(
+                seq_lens=seq_lens,
+                seq_lens_np=seq_lens_np,
+                block_tables=(
+                    self.block_table.get_device_tensor()[:pad_num_tokens]),
+                slot_mapping=slot_mapping,
+                q_seq_lens_np=q_seq_lens_np,
+                max_seq_len=max_seq_len,
+                context_lens=context_lens,
+                max_context_lens=max_context_lens,
+                query_start_loc=query_start_loc,
+                num_prompt_tokens=num_prompt_tokens)
+
+            return attn_metadata
+
         query_start_loc = ms.from_numpy(
             self.runner.query_start_loc_np[:num_reqs + 1])
         max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[:
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index 54fb277cb..814649334 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 import traceback
 from typing import Any, Optional
 
@@ -27,6 +28,9 @@ import torch
 from mindspore import Generator as msGenerator
 from mindspore import Tensor, mint, mutable, ops
 from vllm.attention import AttentionType
+from vllm.config import CompilationLevel
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingType
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -35,6 +39,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing
 
 from vllm_mindspore.model_executor.layers.rotary_embedding import (
@@ -56,6 +61,29 @@ def _prepare_inputs(
     num_reqs = self.input_batch.num_reqs
     assert num_reqs > 0
 
+    aclgraph_pad_size = -1
+
+    if not hasattr(self, "use_cuda_graph_config"):
+        # The first call has no use_cuda_graph_config attr yet; record the
+        # original (configured) use_cuda_graph value in it.
+        self.use_cuda_graph_config = self.use_cuda_graph
+
+    if self.use_cuda_graph_config:
+        # cuda graph is enabled by config;
+        # turn it on only for the pure-decode case.
+        if total_num_scheduled_tokens == num_reqs:
+            self.use_cuda_graph = True
+        else:
+            self.use_cuda_graph = False
+    else:
+        self.use_cuda_graph = False
+
+    if (self.use_cuda_graph
+            and total_num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+        pad_num_tokens = self.vllm_config.pad_for_cudagraph(
+            total_num_scheduled_tokens)
+        aclgraph_pad_size = pad_num_tokens - num_reqs
+
     # OPTIMIZATION: Start copying the block table first.
     # This way, we can overlap the copy with the following CPU operations.
     self.input_batch.block_table.commit(num_reqs)
@@ -177,6 +205,7 @@ def _prepare_inputs(
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
                 common_prefix_len=common_prefix_len,
+                aclgraph_pad_size=aclgraph_pad_size,
             ))
         for layer_name in kv_cache_group_spec.layer_names:
             attn_metadata[layer_name] = attn_metadata_i
@@ -854,3 +883,139 @@ def get_dp_padding(self, num_tokens: int):
     # padded based on `num_tokens_across_dp`, while the model only accepts
     # inputs with actual shape.
     return 0, None
+
+
+def _aclgraph_capture_dummy_run(
+    self: GPUModelRunner,
+    num_tokens: int,
+    skip_attn: bool = True,
+):
+    # Padding for DP
+    num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens=num_tokens)
+    num_tokens += num_pad
+
+    # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+    # for dummy run with LoRA so that the num_reqs collectively
+    # has num_tokens in total.
+
+    assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+    max_num_reqs = self.scheduler_config.max_num_seqs
+    num_reqs = min(num_tokens, max_num_reqs)
+    min_tokens_per_seq = num_tokens // num_reqs
+    num_scheduled_tokens_list = [min_tokens_per_seq] * num_reqs
+    num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+    assert sum(num_scheduled_tokens_list) == num_tokens
+    assert len(num_scheduled_tokens_list) == num_reqs
+    num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
+    if skip_attn:
+        attn_metadata: Optional[dict[str, Any]] = None
+    else:
+        # Make sure max_model_len is used at the graph capture time.
+        self.seq_lens_np[:num_reqs] = self.max_model_len
+        self.seq_lens_np[num_reqs:] = 0
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+
+        attn_metadata = {}
+
+        # Prepare the attention metadata for each KV cache group and make
+        # layers in the same group share the same metadata.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            attn_metadata_i = (
+                self.attn_metadata_builders[kv_cache_group_id].build(
+                    num_reqs=num_reqs,
+                    num_actual_tokens=num_tokens,
+                    max_query_len=num_tokens,
+                    common_prefix_len=common_prefix_len,
+                    aclgraph_pad_size=0,
+                ))
+            # Disable prefill by setting max_context_lens != 0.
+            attn_metadata_i.max_context_lens = 1
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i
+
+    with self.maybe_dummy_run_with_lora(
+            self.lora_config, num_scheduled_tokens=num_scheduled_tokens):
+        model = self.model
+        input_ids = self.input_ids[:num_tokens]
+        inputs_embeds = None
+        positions = self.positions[:num_tokens]
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            if self.intermediate_tensors is None:
+                self.intermediate_tensors = (
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=self.max_num_tokens,
+                        dtype=self.model_config.dtype,
+                        device=self.device))
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_tokens, None, False)
+
+        with self.maybe_randomize_inputs(input_ids), set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp):
+            outputs = model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+            )
+        hidden_states = outputs
+
+    logit_indices = np.cumsum(num_scheduled_tokens) - 1
+    return hidden_states[logit_indices]
+
+
+def capture_model(self: GPUModelRunner) -> None:
+    # This is a trick in vllm-mindspore: we do not want to override
+    # GPUModelRunner.execute_model, so use_cuda_graph_config is added to
+    # save the real configured use_cuda_graph value. In
+    # GPUModelRunner._prepare_inputs, use_cuda_graph is then switched
+    # dynamically according to the inference case, e.g. only the
+    # pure-decode case enables use_cuda_graph.
+    use_cuda_graph = False
+    if hasattr(self, "use_cuda_graph_config"):
+        use_cuda_graph = self.use_cuda_graph_config
+    else:
+        use_cuda_graph = self.use_cuda_graph
+
+    if not use_cuda_graph:
+        logger.warning(
+            "Skipping acl graph capture. Please add "
+            "-O %s to use aclgraph", CompilationLevel.PIECEWISE)
+        return
+
+    start_time = time.perf_counter()
+    start_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
+    # Trigger aclgraph capture for specific shapes.
+    # Capture the large shapes first so that the smaller shapes
+    # can reuse the memory pool allocated for the large shapes.
+
+    # vllm-mindspore captures the aclgraph as a full graph,
+    # so set skip_attn to False.
+    skip_attn = False
+    # enable mindspore graph capture
+    ms.set_kernel_launch_capture(True)
+    self.cudagraph_batch_sizes = [64, 128]
+    for num_tokens in reversed(self.cudagraph_batch_sizes):
+        for _ in range(
+                self.vllm_config.compilation_config.cudagraph_num_of_warmups):
+            _aclgraph_capture_dummy_run(self, num_tokens, skip_attn=skip_attn)
+        _aclgraph_capture_dummy_run(self, num_tokens, skip_attn=skip_attn)
+
+    end_time = time.perf_counter()
+    end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+    elapsed_time = end_time - start_time
+    cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
+    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
+                elapsed_time, cuda_graph_size / (1 << 30))
+
+    # Disable mindspore graph capture (replay of the captured graphs still
+    # works).
+    ms.set_kernel_launch_capture(False)
diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py
index 91d2564a8..ce09e2b28 100644
--- a/vllm_mindspore/v1/worker/gpu_worker.py
+++ b/vllm_mindspore/v1/worker/gpu_worker.py
@@ -31,6 +31,11 @@ def compile_or_warm_up_model(self) -> None:
     default_max_num_reqs = 1  # For MindSpore, we only do one more decode here.
     # Only pp_last_rank requires _dummy_sampler_run,
     # and only pp_last_rank can _dummy_sampler_run.
+
+    # capture the decode aclgraph
+    if not self.model_config.enforce_eager:
+        self.model_runner.capture_model()
+
     if get_pp_group().is_last_rank:
         self.model_runner._dummy_sampler_run(
             self.model_runner._dummy_run(num_tokens=default_max_num_reqs))
-- 
Gitee