From 605775ba55a86c2823d38dd1c98b5657787dfe25 Mon Sep 17 00:00:00 2001
From: xiaocy1997
Date: Thu, 27 Nov 2025 04:11:35 +0000
Subject: [PATCH] Add cpu_model_runner and a custom model execution path
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/vllm_adapter/model_executor/__init__.py  |  14 ++
 src/vllm_adapter/model_executor/model_map.py |  20 ++
 src/vllm_adapter/model_executor/qwen2.py     |  60 ++++++
 src/vllm_adapter/syshax/syshax_config.py     |  12 ++
 src/vllm_adapter/v1/engine/core.py           |  72 ++++++-
 .../v1/worker/cpu_model_runner.py            | 196 ++++++++++++++++++
 src/vllm_adapter/v1/worker/cpu_worker.py     |   9 +
 7 files changed, 382 insertions(+), 1 deletion(-)
 create mode 100644 src/vllm_adapter/model_executor/__init__.py
 create mode 100644 src/vllm_adapter/model_executor/model_map.py
 create mode 100644 src/vllm_adapter/model_executor/qwen2.py
 create mode 100644 src/vllm_adapter/v1/worker/cpu_model_runner.py

diff --git a/src/vllm_adapter/model_executor/__init__.py b/src/vllm_adapter/model_executor/__init__.py
new file mode 100644
index 0000000..bb5db81
--- /dev/null
+++ b/src/vllm_adapter/model_executor/__init__.py
@@ -0,0 +1,14 @@
+"""
+Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+
+sysHAX-adapter is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+    http://license.coscl.org.cn/MulanPSL2
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+PURPOSE.
+See the Mulan PSL v2 for more details.
+Created: 2025-11-27
+Desc: src.vllm_adapter.model_executor.__init__.py
+"""
\ No newline at end of file
diff --git a/src/vllm_adapter/model_executor/model_map.py b/src/vllm_adapter/model_executor/model_map.py
new file mode 100644
index 0000000..f533854
--- /dev/null
+++ b/src/vllm_adapter/model_executor/model_map.py
@@ -0,0 +1,20 @@
+"""
+Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+
+sysHAX-adapter is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+    http://license.coscl.org.cn/MulanPSL2
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+PURPOSE.
+See the Mulan PSL v2 for more details.
+Created: 2025-11-27
+Desc: src.vllm_adapter.model_executor.model_map.py
+"""
+
+from src.vllm_adapter.model_executor.qwen2 import SyshaxQwen2ForCausalLM
+
+g_model_map = {
+    "qwen2": SyshaxQwen2ForCausalLM,
+}
\ No newline at end of file
diff --git a/src/vllm_adapter/model_executor/qwen2.py b/src/vllm_adapter/model_executor/qwen2.py
new file mode 100644
index 0000000..d0ee3bd
--- /dev/null
+++ b/src/vllm_adapter/model_executor/qwen2.py
@@ -0,0 +1,60 @@
+"""
+Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+
+sysHAX-adapter is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+    http://license.coscl.org.cn/MulanPSL2
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+PURPOSE.
+See the Mulan PSL v2 for more details.
+Created: 2025-11-27
+Desc: src.vllm_adapter.model_executor.qwen2.py
+"""
+
+from typing import Optional, Union
+import torch
+from torch import nn
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
+from vllm.distributed import get_pp_group
+
+from src.utils.logger import Logger
+
+
+class SyshaxQwen2ForCausalLM(nn.Module):
+    """
+    Syshax variant of Qwen2ForCausalLM, used for inference with the NUMA weight storage scheme.
+    Weights are fetched from the global store; the forward pass is implemented by the user.
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        self.config = vllm_config.model_config.hf_config
+        # todo
+
+        Logger.info("Initialized SyshaxQwen2ForCausalLM (using NUMA weight storage)")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        """
+        The forward pass is implemented by the user; only the interface is provided here.
+        """
+        Logger.debug("SyshaxQwen2ForCausalLM.forward called")
+        # The forward pass is implemented by the user; nothing is done here.
+        raise NotImplementedError(
+            "Forward function should be implemented by user. "
+            "Access weights via weight_loader.get_weight(name)."
+        )
\ No newline at end of file
diff --git a/src/vllm_adapter/syshax/syshax_config.py b/src/vllm_adapter/syshax/syshax_config.py
index 121b0aa..511a19c 100644
--- a/src/vllm_adapter/syshax/syshax_config.py
+++ b/src/vllm_adapter/syshax/syshax_config.py
@@ -22,6 +22,7 @@ from src.utils.logger import Logger
 class SyshaxConfig:
     ENABLE_AUTO_PD_OFFLOAD: bool = False
     USE_GREDDY: bool = True
+    MODEL_LOADING_SCHEME: int = 0
 
     def __post_init__(self):
         enable_auto_pd_offload = os.getenv("ENABLE_AUTO_PD_OFFLOAD")
@@ -31,6 +32,7 @@ class SyshaxConfig:
             Logger.info("开启了 ENABLE_AUTO_PD_OFFLOAD 选项")
         else:
             Logger.info("未启动 ENABLE_AUTO_PD_OFFLOAD 选项")
+
         use_greddy = os.getenv("USE_GREDDY")
         if use_greddy is not None:
             self.USE_GREDDY = self._parse_bool(use_greddy)
@@ -39,6 +41,14 @@ class SyshaxConfig:
         else:
             Logger.info("未启动 USE_GREDDY 选项")
+        model_loading_scheme = os.getenv("MODEL_LOADING_SCHEME")
+        if model_loading_scheme is not None:
+            self.MODEL_LOADING_SCHEME = int(model_loading_scheme)
+        if self.MODEL_LOADING_SCHEME == 0:
+            Logger.info("Using the default model loading scheme")
+        else:
+            Logger.info(f"Using a custom model loading scheme, MODEL_LOADING_SCHEME = {self.MODEL_LOADING_SCHEME}")
+
 
     @staticmethod
     def _parse_bool(value: str) -> bool:
         return value.lower() in ("1", "yes", "true", "on")
@@ -55,3 +65,5 @@ class SyshaxConfig:
 
     def use_greddy(self) -> bool:
         return self.USE_GREDDY
+    def model_loading_scheme(self) -> int:
+        return self.MODEL_LOADING_SCHEME
diff --git a/src/vllm_adapter/v1/engine/core.py b/src/vllm_adapter/v1/engine/core.py
index 6fd04c8..c569059 100644
--- a/src/vllm_adapter/v1/engine/core.py
+++ b/src/vllm_adapter/v1/engine/core.py
@@ -26,11 +26,21 @@ from vllm.v1.engine.core import EngineCoreProc, DPEngineCoreProc
 from vllm.v1.serial_utils import MsgpackDecoder
 from vllm.v1.engine import EngineCoreRequestType
 from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import (
+    KVCacheConfig,
+    KVCacheTensor,
+    KVCacheGroupSpec,
+    FullAttentionSpec,
+)
+import torch
+from vllm.v1.core.kv_cache_utils import unify_kv_cache_configs
 
 from src.utils.logger import Logger, init_logger
 from src.vllm_adapter.v1.engine import SyshaxEngineCoreRequest
 from src.vllm_adapter.v1.core.sched.scheduler import SyshaxScheduler
 from src.vllm_adapter.syshax.shared_memory_manager import SharedMemoryManager
+from src.vllm_adapter.syshax.syshax_config import SyshaxConfig
+from src.vllm_adapter.model_loader.default_loader import SyshaxDefaultModelLoader
 
 class SyshaxEngineCoreProc(EngineCoreProc):
     def __init__(
@@ -43,6 +53,10 @@ class SyshaxEngineCoreProc(EngineCoreProc):
         engine_index: int = 0,
     ):
         init_logger()
+        self.syshax_config = SyshaxConfig.instance()
+        if self.syshax_config is None:
+            raise RuntimeError("Failed to create syshax_config.")
+
         self._check_and_update_config(vllm_config)
         super().__init__(
             vllm_config=vllm_config,
@@ -150,7 +164,63 @@ class SyshaxEngineCoreProc(EngineCoreProc):
         elif vllm_config.device_config.device_type == "cpu":
             vllm_config.parallel_config.worker_cls = "src.vllm_adapter.v1.worker.cpu_worker.SyshaxCPUWorker"
         else:
-            raise NotImplementedError(f"{self.device_config.device_type} is not supported on sysHAX")
+            raise NotImplementedError(f"{vllm_config.device_config.device_type} is not supported on sysHAX")
+        vllm_config.load_config.load_format = SyshaxDefaultModelLoader
+
+    def _initialize_kv_caches(self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
+        """
+        For CPU devices, skip vLLM's own KV cache initialization and return a minimal
+        placeholder config; the KV cache space is managed by sysHAX itself.
+        """
+        if vllm_config.device_config.device_type == "cpu" and \
+                self.syshax_config.model_loading_scheme() == 1:
+            # Query available memory (used for sizing only; nothing is actually allocated)
+            available_gpu_memory = self.model_executor.determine_available_memory()
+
+            # Create a minimal placeholder config for each worker
+            kv_cache_configs = []
+            for available_memory_one_worker in available_gpu_memory:
+                kv_cache_config = self._create_minimal_placeholder_config(
+                    vllm_config, available_memory_one_worker)
+                kv_cache_configs.append(kv_cache_config)
+            unify_kv_cache_configs(kv_cache_configs)
+
+            assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs)
+            num_gpu_blocks = kv_cache_configs[0].num_blocks
+            num_cpu_blocks = 0
+            scheduler_kv_cache_config = kv_cache_configs[0]
+            return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
+
+        return super()._initialize_kv_caches(vllm_config)
+
+    def _create_minimal_placeholder_config(self, vllm_config: VllmConfig, available_memory: int) -> KVCacheConfig:
+        """Create a minimal placeholder KV cache config, used when kv_cache_spec is empty."""
+        # Read the basic parameters from the model config
+        model_config = vllm_config.model_config
+        block_size = vllm_config.cache_config.block_size
+
+        # Try to read the number of KV heads and the head size from the model config
+        num_kv_heads = getattr(model_config, 'num_key_value_heads', 32)
+        head_size = getattr(model_config, 'head_dim', 128)
+        dtype = getattr(model_config, 'torch_dtype', torch.float16)
+
+        # Build a minimal FullAttentionSpec
+        placeholder_spec = FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads,
+                                             head_size=head_size, dtype=dtype, use_mla=False)
+
+        num_blocks = 1  # minimal number of blocks (a single block)
+        per_layer_size = placeholder_spec.page_size_bytes * num_blocks
+        layer_name = "model.layers.0.self_attn"  # placeholder layer name
+        # Build the KV cache tensor and group
+        kv_cache_tensors = [KVCacheTensor(size=per_layer_size, shared_by=[layer_name])]
+        kv_cache_groups = [KVCacheGroupSpec(layer_names=[layer_name], kv_cache_spec=placeholder_spec)]
+        kv_cache_config = KVCacheConfig(num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors,
+                                        kv_cache_groups=kv_cache_groups)
+
+        Logger.info(
+            f"Created minimal placeholder KV cache config: "
+            f"{num_blocks} blocks, {num_kv_heads} kv_heads, {head_size} head_size")
+        return kv_cache_config
 
     def execute_model(self, scheduler_output: SchedulerOutput):
         """Inject KV cache from shared memory before model execution (v1)."""
diff --git a/src/vllm_adapter/v1/worker/cpu_model_runner.py b/src/vllm_adapter/v1/worker/cpu_model_runner.py
new file mode 100644
index 0000000..35d00ab
--- /dev/null
+++ b/src/vllm_adapter/v1/worker/cpu_model_runner.py
@@ -0,0 +1,196 @@
+"""
+Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+
+sysHAX-adapter is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+    http://license.coscl.org.cn/MulanPSL2
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+PURPOSE.
+See the Mulan PSL v2 for more details.
+Created: 2025-11-18
+Desc: src.vllm_adapter.v1.worker.cpu_model_runner.py
+"""
+
+from typing import TYPE_CHECKING, Dict, Optional, Union
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model
+from vllm.sequence import IntermediateTensors
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.cpu_model_runner import CPUModelRunner
+
+from src.utils.logger import Logger
+from src.vllm_adapter.model_loader import cpu_weight_loader
+from src.vllm_adapter.model_executor.model_map import g_model_map
+
+
+class SyshaxCPUModelRunner(CPUModelRunner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device):
+        super().__init__(vllm_config, device)
+        # Weight dict, populated when the NUMA weight storage scheme is used
+        self.loaded_weights: Dict[str, Union[torch.Tensor, Dict[int, torch.Tensor]]] = {}
+        self.use_numa_weight_storage = False  # whether the NUMA weight storage scheme is in use
+        self.model_class = None
+
+    def load_model(self) -> None:
+        """Load model, handling both traditional model loading and NUMA weight storage."""
+        Logger.info(f"Starting to load model {self.model_config.model}...")
+
+        # get_model may return an nn.Module or a (model name, weight dict) tuple
+        model_or_weights = get_model(vllm_config=self.vllm_config)
+        # Check the return type
+        if isinstance(model_or_weights, tuple):
+            # A (model name, weight dict) tuple means the NUMA weight storage scheme is active
+            Logger.info("Detected NUMA weight storage scheme, weights are stored in global storage")
+            self.use_numa_weight_storage = True
+            self.model_class, self.loaded_weights = model_or_weights
+            model_class = g_model_map.get(self.model_class)
+            if model_class is None:
+                raise ValueError(f"Model class {self.model_class} is not supported")
+            self.model = model_class(vllm_config=self.vllm_config)
+        else:
+            # An nn.Module means the traditional loading path was taken
+            Logger.info("Using traditional model loading scheme")
+            self.model = model_or_weights
+            self.use_numa_weight_storage = False
+
+        # Handle LoRA (if enabled)
+        if self.lora_config:
+            self.model = self.load_lora_model(
+                self.model,
+                self.model_config,
+                self.scheduler_config,
+                self.lora_config,
+                self.device,
+            )
+
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        """
+        Execute model inference.
+
+        If using NUMA weight storage scheme, uses loaded_weights directly.
+        Otherwise, falls back to parent class implementation.
+ """ + if self.use_numa_weight_storage: + # 使用 NUMA 权重存储方案,直接使用 tensor 进行推理 + return self._execute_model_with_tensor_weights(scheduler_output, intermediate_tensors) + else: + # 使用传统方式,调用父类方法 + return super().execute_model(scheduler_output, intermediate_tensors) + + def _execute_model_with_tensor_weights( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + """ + 使用 tensor 权重进行模型推理的接口。 + + 权重可以通过以下方式获取: + - self.loaded_weights: 存储的权重字典 + - weight_loader.get_weight(name): 从全局存储中获取权重 + - weight_loader.get_all_weight_names(): 获取所有权重名称 + + Args: + scheduler_output: 调度器输出 + intermediate_tensors: 中间张量(可选) + + Returns: + ModelRunnerOutput 或 IntermediateTensors + """ + # 更新状态 + self._update_states(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + # 如果没有需要处理的 token,返回空的输出 + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT + return EMPTY_MODEL_RUNNER_OUTPUT + + # 准备输入(复用父类的逻辑) + attn_metadata, logits_indices, spec_decode_metadata = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = num_scheduled_tokens + + # 准备 input_ids 和 positions + input_ids = self.input_ids[:num_input_tokens] + positions = self.positions[:num_input_tokens] + inputs_embeds = None + + # 处理 pipeline parallelism + from vllm.distributed import get_pp_group, get_tp_group + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True + ) + + # 调用 SyshaxQwen2ForCausalLM 的 forward 方法 + from vllm.attention.backends.utils import set_forward_context + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + ): + model_output = self.model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + # 处理输出(与父类保持一致) + if isinstance(model_output, IntermediateTensors): + return model_output + + # 处理 pipeline parallelism 的输出 + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + + # 处理 pipeline parallelism + if not get_pp_group().is_last_rank: + # 对于中间 pipeline 阶段,返回 hidden states + return hidden_states + + # 计算 logits + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + + # 应用 structured output bitmasks(如果存在) + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + # 采样下一个 token + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # Speculative decoding + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + bonus_logits=bonus_logits, + bonus_logits_indices=spec_decode_metadata.bonus_logits_indices, + ) + + # 构建并返回 ModelRunnerOutput + from vllm.v1.outputs import ModelRunnerOutput + return ModelRunnerOutput( + outputs=sampler_output, + spec_decode_worker_metrics=spec_decode_metadata.worker_metrics + if spec_decode_metadata else None, + ) \ No newline at end of file diff --git a/src/vllm_adapter/v1/worker/cpu_worker.py b/src/vllm_adapter/v1/worker/cpu_worker.py 
index 392a9b6..8cc7716 100644
--- a/src/vllm_adapter/v1/worker/cpu_worker.py
+++ b/src/vllm_adapter/v1/worker/cpu_worker.py
@@ -23,6 +23,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.attention.ops.paged_attn import PagedAttention
 
 from src.vllm_adapter.syshax.shared_memory_manager import SharedMemoryManager
+from src.vllm_adapter.v1.worker.cpu_model_runner import SyshaxCPUModelRunner
 from src.utils.logger import Logger, init_logger
 
 
@@ -44,6 +45,14 @@ class SyshaxCPUWorker(V1CPUWorker):
         super().__init__(vllm_config, local_rank, rank, distributed_init_method, is_driver_worker)
         self.shared_memory_manager = SharedMemoryManager.instance()
 
+    def init_device(self):
+        """
+        Override init_device so that the default CPUModelRunner is replaced with SyshaxCPUModelRunner.
+        """
+        # Run the rest of the device initialization in the parent class first
+        super().init_device()
+        self.model_runner = SyshaxCPUModelRunner(self.vllm_config, torch.device("cpu"))
+
     def copy_block_from_sharememory(
         self,
         request_id: str,
-- 
Gitee
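
A minimal sketch of how a user-implemented forward pass might consume the NUMA weight store that SyshaxCPUModelRunner exposes as a plain name -> tensor mapping (loaded_weights). The weight names and shapes below are hypothetical placeholders, not values taken from this patch; the real implementation could equally fetch tensors lazily via weight_loader.get_weight(name):

import torch
import torch.nn.functional as F

# Hypothetical stand-in for SyshaxCPUModelRunner.loaded_weights.
loaded_weights = {
    "model.embed_tokens.weight": torch.randn(151936, 1024),  # vocab_size x hidden_size (example sizes)
    "lm_head.weight": torch.randn(151936, 1024),
}

def toy_forward(input_ids: torch.Tensor) -> torch.Tensor:
    # Embedding lookup directly from the stored tensor.
    hidden = F.embedding(input_ids, loaded_weights["model.embed_tokens.weight"])
    # ... decoder layers would run here, fetching per-layer weights by name ...
    # Project hidden states back to vocabulary logits.
    return hidden @ loaded_weights["lm_head.weight"].t()

logits = toy_forward(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3, 151936])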