diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml index 8aad948e1bd7c5e79c66a9af0e4448981c26ae36..5f4d0525c7520cd649a4f8d7b76cbe8e064a84fd 100644 --- a/.jenkins/test/config/dependent_packages.yaml +++ b/.jenkins/test/config/dependent_packages.yaml @@ -1,11 +1,11 @@ mindspore: - 'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250927/master_20250927010016_1bd70be392807891ccb1c5f247beff2e7c3403a4_newest/' + 'https://repo.mindspore.cn/mindspore/mindspore/version/202510/20251011/master_20251011121011_b98662acda7a876f70f117c237ab47cb125ec9ab_newest/' mindspore_gs: 'https://repo.mindspore.cn/mindspore/golden-stick/version/202509/20250901/master_20250901221800_3e34fd43040b0c5d296e6bc1a82212deae3ee041_newest/' msadapter: - 'https://repo.mindspore.cn/mindspore/msadapter/version/202509/20250929/master_20250929203817_1b4f3bc61383eab75bd823ba591e15fd09afa24a_newest/' + 'https://repo.mindspore.cn/mindspore/msadapter/version/202510/20251011/master_20251011010017_2e481ca185a5fefc603873d35d09a786f4077a2b_newest/' vllm: 'https://repo.mindspore.cn/mirrors/vllm/version/202507/20250715/v0.9.1/' diff --git a/tests/mindformers b/tests/mindformers index cd39c1543c3c29b57da6e305f0902096a046be35..3d1e55ce5636fcef7472935086b60f2554cf3bdb 160000 --- a/tests/mindformers +++ b/tests/mindformers @@ -1 +1 @@ -Subproject commit cd39c1543c3c29b57da6e305f0902096a046be35 +Subproject commit 3d1e55ce5636fcef7472935086b60f2554cf3bdb diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index efd917b0a3cef0cc0d11212dc25de9dc60166260..831529801464cae6e1dfdab60c07a70614c104f6 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -391,6 +391,10 @@ from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs +from vllm_mindspore.v1.worker.gpu_model_runner import _dummy_run as v1_dummy_run + +vllm.v1.worker.gpu_model_runner.GPUModelRunner._dummy_run = v1_dummy_run + from vllm_mindspore.v1.worker.gpu_model_runner import _calc_mrope_positions vllm.v1.worker.gpu_model_runner.GPUModelRunner._calc_mrope_positions = \ @@ -430,6 +434,10 @@ from vllm_mindspore.forward_context import set_forward_context vllm.v1.worker.gpu_model_runner.GPUModelRunner.set_forward_context = ( set_forward_context) +from vllm_mindspore.forward_context import make as dp_metadata_make + +vllm.forward_context.DPMetadata.make = staticmethod(dp_metadata_make) + import vllm.v1.worker.block_table from vllm_mindspore.v1.worker.block_table import BlockTable diff --git a/vllm_mindspore/forward_context.py b/vllm_mindspore/forward_context.py index 89fd3ce920040695166b8af853d62943f55dc553..979998e808d542eafbeb74af1a55357ed2aa3c6d 100644 --- a/vllm_mindspore/forward_context.py +++ b/vllm_mindspore/forward_context.py @@ -26,16 +26,54 @@ from typing import Any, Optional import torch import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.forward_context import DPMetadata, ForwardContext from vllm.logger import init_logger +from vllm_mindspore.utils import convert_pin + logger = init_logger(__name__) track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 batchsize_forward_time: defaultdict = defaultdict(list) +def make(parallel_config: ParallelConfig, + attn_metadata: Any, + num_tokens: int, + num_tokens_across_dp: Optional[torch.Tensor] = None) -> "DPMetadata": + + assert parallel_config.data_parallel_size > 1 + 
dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + if attn_metadata is not None and hasattr(attn_metadata, + "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends or no attn_metadata + batchsize = num_tokens + + # If num_tokens_across_dp is None, it will be computed by all_reduce + # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize + assert (num_tokens_across_dp is None + or num_tokens_across_dp[dp_rank] == batchsize) + if num_tokens_across_dp is None: + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + batchsize, dp_size, dp_rank) + + # NOTE: vLLM-MindSpore Plugin: + # In asynchronous copying scenarios, if the data is not pin_memory, + # it will revert back to synchronous copying and cause interruption, + # resulting in slower execution performance. + num_tokens_across_dp = convert_pin(num_tokens_across_dp) + + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) + cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) + return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + + @contextmanager def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py index 7a687019cfb66c30220826a7e62c38ac824a2c42..99d1dfd5af8197a66bc4577ff284d3601e404099 100644 --- a/vllm_mindspore/lora/models.py +++ b/vllm_mindspore/lora/models.py @@ -33,6 +33,7 @@ from vllm.model_executor.models.utils import WeightsMapper from vllm.utils import is_pin_memory_available from vllm_mindspore.lora.layers import BaseLayerWithLoRA +from vllm_mindspore.utils import convert_pin _GLOBAL_LORA_ID = 0 @@ -77,10 +78,9 @@ def from_lora_tensors( if embeddings_module: lora_embeddings_tensor = embeddings[ embedding_modules[embeddings_module]] - if pin_memory and \ - lora_embeddings_tensor._ms_device == "CPU": - lora_embeddings_tensor = ( - lora_embeddings_tensor.pin_memory()) + if pin_memory: + lora_embeddings_tensor = convert_pin( + lora_embeddings_tensor) loras[module_name] = LoRALayerWeights.from_config( module_name, peft_helper, lora_embeddings_tensor) @@ -88,14 +88,14 @@ def from_lora_tensors( # vllm-mindspore remove tensor device loras[module_name].bias = tensor.to(dtype=dtype).t() bias = tensor.to(dtype=dtype).t() - if pin_memory and bias._ms_device == "CPU": - bias = bias.pin_memory() + if pin_memory: + bias = convert_pin(bias) loras[module_name].bias = bias elif is_lora_a: loras[module_name].lora_a = tensor.to(dtype=dtype).t() - if pin_memory and loras[module_name].lora_a._ms_device == "CPU": - loras[module_name].lora_a = loras[ - module_name].lora_a.pin_memory() + if pin_memory: + loras[module_name].lora_a = convert_pin( + loras[module_name].lora_a) else: loras[module_name].lora_b = tensor.to(dtype=dtype).t() assert embedding_padding_modules is not None @@ -106,9 +106,9 @@ def from_lora_tensors( addition = target_embedding_padding - lora_b.shape[1] loras[module_name].lora_b = mint.nn.functional.pad( lora_b, (0, addition)) - if pin_memory and loras[module_name].lora_b._ms_device == "CPU": - loras[module_name].lora_b = loras[ - module_name].lora_b.pin_memory() + if pin_memory: + loras[module_name].lora_b = convert_pin( + loras[module_name].lora_b) for lora in loras.values(): lora.optimize() diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py 
b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index ebd9628e1e0b6454370d9443bd5636c68c63b235..ba7d82c884e4e43a6914b891ebc703559a658921 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -39,8 +39,8 @@ from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config from vllm_mindspore.model_executor.models.model_base import ( AttentionWrapper, MLAAttentionWrapper, MsModelBase) from vllm_mindspore.model_executor.models.utils import ( - convert_pin, is_use_ringmla, make_empty_intermediate_tensors_factory) -from vllm_mindspore.utils import is_310p + is_use_ringmla, make_empty_intermediate_tensors_factory) +from vllm_mindspore.utils import convert_pin, is_310p logger = init_logger(__name__) @@ -211,7 +211,11 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): if dp_size == 1 or (dp_size == ep_size and tp_size == 1): return None, None, None, None - tokens_len_per_dp = q_seq_len.sum().reshape(-1) + # NOTE: vLLM-MindSpore Plugin: + # In asynchronous copying scenarios, if the data is not pin_memory, + # it will revert back to synchronous copying and cause interruption, + # resulting in slower execution performance. + tokens_len_per_dp = convert_pin(q_seq_len).sum().reshape(-1) tokens_len_per_dp = get_dp_group().all_gather(tokens_len_per_dp) tokens_len_per_dp = tokens_len_per_dp.asnumpy() @@ -306,6 +310,16 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): else: # V1 is_prefill = attn_metadata.max_context_lens == 0 + + # NOTE: vLLM-MindSpore Plugin: + # In asynchronous copying scenarios, if the data is not pin_memory, + # it will revert back to synchronous copying and cause interruption, + # resulting in slower execution performance. 
+ attn_metadata.context_lens = convert_pin( + attn_metadata.context_lens) + attn_metadata.num_prompt_tokens = convert_pin( + attn_metadata.num_prompt_tokens) + is_ringmla_chunked = \ self.use_ringmla and not is_prefill and \ bool((attn_metadata.context_lens - \ diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 1e4627c6b45044ac404f40bbfb5115a81bc66c19..0aa14b53fde462f12c1b8d2c1589a455ef055db5 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -32,10 +32,10 @@ from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.models.attention_mask import ( LowerTriangularMask) -from vllm_mindspore.model_executor.models.utils import (convert_pin, - is_use_ringmla) +from vllm_mindspore.model_executor.models.utils import is_use_ringmla from vllm_mindspore.model_executor.utils import set_model_context -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE, create_kv_cache +from vllm_mindspore.utils import (STR_DTYPE_TO_MS_DTYPE, convert_pin, + create_kv_cache) from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 2a4b3e51c62c649aa2587216d3e7e7caecab6c6e..5d9e56d3431257e81e07c47515b6084b5fa38f0a 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -23,7 +23,7 @@ from dataclasses import dataclass, field from typing import Optional, Union import mindspore as ms -from mindspore import Tensor, mint, ops +from mindspore import mint, ops from vllm import envs from vllm.sequence import IntermediateTensors @@ -34,23 +34,6 @@ WeightsMapping = Mapping[str, Optional[str]] """If a key maps to a value of `None`, the corresponding weight is ignored.""" -def convert_pin(input_tensor): - """Convert tensor to pinned memory if it's on CPU and not already pinned. - - Args: - input_tensor: Input tensor to convert - - Returns: - Tensor with pinned memory if applicable, otherwise original tensor - """ - if not isinstance(input_tensor, Tensor): - return input_tensor - if input_tensor._ms_device == "CPU" and not input_tensor.is_pinned(): - input_pined = input_tensor.pin_memory() - return input_pined - return input_tensor - - @dataclass class WeightsMapper: """Maps the name of each weight if they match the following patterns.""" diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 12915b87a1ed413216d07a1398dd0f56a1d03f94..34a7c9118490a291bf00eb25c6a4fdd77b550b6e 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -39,6 +39,7 @@ else: Library = None import mindspore as ms +from mindspore import Tensor from mindspore import dtype as mstype from mindspore.common.initializer import Zero from vllm.logger import init_logger @@ -453,3 +454,20 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def convert_pin(input_tensor): + """Convert tensor to pinned memory if it's on CPU and not already pinned. 
+ + Args: + input_tensor: Input tensor to convert + + Returns: + Tensor with pinned memory if applicable, otherwise original tensor + """ + if not isinstance(input_tensor, Tensor): + return input_tensor + if input_tensor._ms_device == "CPU": + input_pined = input_tensor.pin_memory() + return input_pined + return input_tensor diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 5e70af7aa83c53bcd1d51d76fae4d7dadfbfcc29..33bf5d8d0a0f56afd16d62fa70197cc0926ab790 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -27,11 +27,15 @@ import torch from mindspore import Generator as msGenerator from mindspore import Tensor, mint, mutable from vllm.attention import AttentionType +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.sampling_params import SamplingType +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState @@ -40,7 +44,7 @@ from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from vllm_mindspore.model_executor.layers.rotary_embedding import ( InferMRotaryEmbedding as MRotaryEmbedding) from vllm_mindspore.model_executor.models.utils import is_use_ringmla -from vllm_mindspore.utils import (create_kv_cache, get_dtype_size, +from vllm_mindspore.utils import (convert_pin, create_kv_cache, get_dtype_size, get_valid_dtype, is_310p) from vllm_mindspore.v1.kv_cache_interface import MLAQuantFullAttentionSpec @@ -211,6 +215,116 @@ def _prepare_inputs( return attn_metadata, logits_indices, spec_decode_metadata +@torch.inference_mode() +def _dummy_run( + self, + num_tokens: int, + skip_attn: bool = True, +) -> torch.Tensor: + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + if skip_attn: + attn_metadata: Optional[dict[str, Any]] = None + else: + query_start_loc = self.query_start_loc[:num_reqs + 1] + # Make sure max_model_len is used at the graph capture time. 
+ self.seq_lens_np[:num_reqs] = self.max_model_len + self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + seq_lens = self.seq_lens[:num_reqs] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) + + attn_metadata = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + )) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) + + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + outputs = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + + # NOTE: vLLM-MindSpore Plugin: + # In asynchronous copying scenarios, if the data is not pin_memory, + # it will revert back to synchronous copying and cause interruption, + # resulting in slower execution performance. + logit_indices = convert_pin(ms.from_numpy(logit_indices)) + + return hidden_states[logit_indices] + + def create_block(shape, dtype, name=None, device=None): blocks = mint.empty(shape, dtype=dtype, device=device) return blocks
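
The NOTE comments repeated throughout this patch all rest on one behaviour: an asynchronous host-to-device copy only overlaps with compute when the source CPU tensor sits in pinned (page-locked) memory; with pageable memory the runtime silently falls back to a synchronous copy, which stalls the stream and slows execution. Below is a minimal, self-contained sketch of that idea in plain PyTorch, mirroring the convert_pin() helper this patch moves into vllm_mindspore/utils.py. The names convert_pin_sketch and to_device_async, and the "cuda" device, are illustrative assumptions and not part of the patch.

import torch


def convert_pin_sketch(t: torch.Tensor) -> torch.Tensor:
    # Pin a pageable CPU tensor so later copies with non_blocking=True
    # can actually run asynchronously; leave everything else untouched.
    if t.device.type == "cpu" and not t.is_pinned():
        return t.pin_memory()
    return t


def to_device_async(t: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Hypothetical helper: pin first, then issue the non-blocking copy.
    return convert_pin_sketch(t).to(device, non_blocking=True)


if __name__ == "__main__":
    x = torch.arange(8, dtype=torch.int32)
    if torch.cuda.is_available():  # pinning needs an accelerator runtime
        x = convert_pin_sketch(x)
        print(x.is_pinned())  # True: H2D copies of x can now overlap compute
        print(to_device_async(x).device)
    else:
        print("no accelerator available; tensor stays in pageable memory")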