diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6ecf6e2a945200639eec710eb96487858d06b5
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/mf_models/config.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import types
+
+from mindformers.models.configuration_utils import PretrainedConfig
+from mindformers.tools.register.config import MindFormerConfig
+from vllm.config import VllmConfig
+
+
+MF_CTX_MAPPING = {
+    'run_mode': (None, "predict"),
+    'use_legacy': (None, False),
+    'load_ckpt_format': (None, 'safetensors'),
+    'auto_trans_ckpt': (None, True),
+}
+
+
+MF_PARALLEL_MAPPING = {
+    'parallel_mode': (None, 'STAND_ALONE'),
+    'parallel_config.model_parallel': ('parallel_config.tensor_parallel_size', None),
+    'parallel_config.pipeline_stage': ('parallel_config.pipeline_parallel_size', None),
+    'parallel_config.vocab_emb_dp': (None, False)
+}
+
+
+# Common model config
+MODEL_COMMON_MAPPING = {
+    'seq_length': ('model_config.max_model_len', None),
+
+    'use_flash_attention': (None, True),
+
+    "compute_dtype": ('model_config.hf_config.torch_dtype', 'bfloat16'),
+
+    'architectures': ('model_config.hf_config.architectures', None),
+    'bos_token_id': ('model_config.hf_config.bos_token_id', None),
+    'eos_token_id': ('model_config.hf_config.eos_token_id', None),
+    'model_type': ('model_config.hf_config.model_type', None),
+
+    # transformer_config
+    'attention_dropout': ('model_config.hf_config.attention_dropout', None),
+    'hidden_act': ('model_config.hf_config.hidden_act', None),
+    'hidden_size': ('model_config.hf_config.hidden_size', None),
+    'intermediate_size': ('model_config.hf_config.intermediate_size', None),
+    'max_position_embeddings': ('model_config.hf_config.max_position_embeddings', None),
+    'num_attention_heads': ('model_config.hf_config.num_attention_heads', None),
+    'rms_norm_eps': ('model_config.hf_config.rms_norm_eps', None),
+
+    'num_hidden_layers': ('model_config.hf_config.num_hidden_layers', None),
+    'num_layers': ('model_config.hf_config.num_layers', None),
+
+    'num_key_value_heads': ('model_config.hf_config.num_key_value_heads', None),
+    'n_kv_heads': ('model_config.hf_config.n_kv_heads', None),
+
+    'head_dim': ('model_config.hf_config.head_dim', None),
+
+    'rope_theta': ('model_config.hf_config.rope_theta', None),
+    'tie_word_embeddings': ('model_config.hf_config.tie_word_embeddings', None),
+    'vocab_size': ('model_config.hf_config.vocab_size', None),
+}
+
+
+# model default config
+MODEL_RELATED_MAPPING = {
+    'qwen2': {
+        "gated_linear_unit": True,
+        'params_dtype': 'float32',  # need an input
+        'add_qkv_bias': True,
+    },
+    'qwen3': {
+        "gated_linear_unit": True,
+        'params_dtype': 'float32',  # need an input
+        'add_qkv_bias': False,
+    }
+    # Add another model type...
+}
+
+
+def get_nested_attr(obj, path: str, default=None):
+    """Get a nested attr from obj."""
+    current = obj
+    for attr in path.split('.'):
+        if not hasattr(current, attr):
+            return default
+        current = getattr(current, attr)
+    return current
+
+
+def set_nested_attr(obj, path: str, value):
+    """Set a nested attr of MindFormerConfig."""
+    attrs = path.split('.')
+
+    current = obj
+    for attr in attrs[:-1]:
+        if not hasattr(current, attr) or getattr(current, attr) is None:
+            setattr(current, attr, MindFormerConfig())
+        current = getattr(current, attr)
+
+    setattr(current, attrs[-1], value)
+
+
+def transform_config(mapping_table: dict, vllm_config: VllmConfig, target_config):
+    for target_path, mapping in mapping_table.items():
+        src_path, transform = mapping
+
+        src_value = get_nested_attr(vllm_config, src_path) if src_path is not None else None
+
+        if src_value is not None:
+            transformed_value = src_value
+        elif transform and isinstance(transform, (types.FunctionType, types.BuiltinFunctionType)):
+            transformed_value = transform(src_value)
+        else:
+            transformed_value = transform
+
+        if transformed_value is not None:
+            set_nested_attr(target_config, target_path, transformed_value)
+
+
+def gen_model_related_config(model_type):
+    return MODEL_RELATED_MAPPING.get(model_type)
+
+
+def gen_model_config_dict(vllm_config: VllmConfig):
+    target_config = MindFormerConfig()
+
+    transform_config(MODEL_COMMON_MAPPING, vllm_config, target_config)
+
+    model_type = vllm_config.model_config.hf_config.model_type
+    model_related_config = gen_model_related_config(model_type)
+    target_config.update(model_related_config)
+
+    return target_config
+
+
+def gen_mf_config(vllm_config: VllmConfig):
+    target_config = MindFormerConfig()
+    transform_config(MF_CTX_MAPPING, vllm_config, target_config)
+    transform_config(MF_PARALLEL_MAPPING, vllm_config, target_config)
+    target_config.set_value(
+        'model.model_config',
+        MindFormerConfig(**gen_model_config_dict(vllm_config))
+    )
+    return target_config
+
+
+def gen_model_config(mf_config: MindFormerConfig, model_config_type: PretrainedConfig):
+    model_config = model_config_type(
+        **mf_config.model.model_config, parallel_config=mf_config.parallel_config
+    )
+    model_config.post_process = False
+    return model_config
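The mapping tables above drive a simple walk: each key is a dotted path to create on the MindFormers config, and its (src_path, default) tuple says where in VllmConfig to read from, or which constant to fall back to. A self-contained sketch of that resolution logic, with types.SimpleNamespace standing in for VllmConfig (a simplified re-implementation for illustration, not the module above):

    from types import SimpleNamespace

    def get_nested(obj, path, default=None):
        # Walk a dotted attribute path, returning default on the first miss.
        for attr in path.split('.'):
            if not hasattr(obj, attr):
                return default
            obj = getattr(obj, attr)
        return obj

    # Fake vllm config: only max_model_len is present.
    vllm_cfg = SimpleNamespace(model_config=SimpleNamespace(max_model_len=4096))

    mapping = {
        'seq_length': ('model_config.max_model_len', None),  # read from source
        'run_mode': (None, 'predict'),                       # constant default
    }

    target = {}
    for target_path, (src_path, default) in mapping.items():
        value = get_nested(vllm_cfg, src_path) if src_path else None
        if value is None:
            value = default
        if value is not None:
            target[target_path] = value

    print(target)  # {'seq_length': 4096, 'run_mode': 'predict'}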
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 8a5a07778cd2095b24289db1ac309da8c5ad28e4..922613340f73cee38861bd4bd8f6f3ce199ead27 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -53,7 +53,13 @@ class MfModelBase(MsModelBase):
             vllm_config=vllm_config, prefix=prefix
         )
 
-        self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
+        model_config_path = os.getenv("MINDFORMERS_MODEL_CONFIG")
+        if model_config_path is None:
+            raise RuntimeError(
+                'For the "MindFormers" model backend, the environment variable MINDFORMERS_MODEL_CONFIG must be set!'
+            )
+
+        self.mf_config = MindFormerConfig(model_config_path)
         build_mf_context(self.mf_config)
         build_parallel_config(self.mf_config)
         self.mf_config.model.model_config.parallel_config = (
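With this check, any model that still goes through MfModelBase (e.g. Qwen2, DeepSeek) fails fast with a clear error instead of handing None to MindFormerConfig. The variable just has to be exported before engine startup; a minimal sketch, with a hypothetical YAML path:

    import os

    # Hypothetical path -- point this at the MindFormers predict YAML
    # for the model being served.
    os.environ["MINDFORMERS_MODEL_CONFIG"] = "/path/to/predict_model.yaml"

The rewritten Qwen3 path below no longer reads this variable at all: it derives its MindFormers config from VllmConfig via gen_mf_config.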
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3.py b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
index a5a8b01d6e906f2c9b8e51c7f3d0af288f05137b..a3e021ac07f0da3d7c64f049972a0f2ba0850d9e 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
@@ -16,70 +16,258 @@
 # limitations under the License.
 # ============================================================================
 
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Optional, Tuple, Union
 
-from vllm.config import VllmConfig
-from vllm.config import get_current_vllm_config
-from vllm.logger import init_logger
-
-from mindspore import Tensor, JitConfig
+import mindspore as ms
+import numpy as np
+from mindformers.core.context import build_mf_context
+from mindformers.core.parallel_config import build_parallel_config
+from mindformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from mindformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM as ParallelQwenForCausalLM_MF
+from mindformers.tools.utils import is_pynative
+from mindspore import Tensor, ops
+from mindspore.common.api import _pynative_executor
 from mindspore.nn.utils import no_init_parameters
+from vllm import envs
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import IntermediateTensors
 
-from mindformers.models.llama import LlamaConfig as LlamaConfig_MF
-from research.qwen3.qwen3 import (
-    ParallelQwen3ForCausalLM as ParallelQwenForCausalLM_MF,
-)
-
+from vllm_mindspore import SamplingMetadata
 from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
-from vllm_mindspore.model_executor.models.mf_models.qwen3_weight_processor import Qwen3WeightProcessor
-
+from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
+from vllm_mindspore.model_executor.models.mf_models.config import gen_mf_config, gen_model_config
+from vllm_mindspore.model_executor.models.model_base import (Fake_Attention,
+                                                             Fake_Attention_V1,
+                                                             MsModelBase)
+from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
 
 logger = init_logger(__name__)
 
 
-class Qwen3ForCausalLM(MfModelBase):
+class Qwen3ForCausalLM(MsModelBase):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(Qwen3ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
         self.mf_kvcaches_init = False
+
+        self.vllm_config = vllm_config
+        self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.set_flags = False
+
+        mf_config = gen_mf_config(vllm_config)
+        mf_config.load_checkpoint = self.get_model_path()
+        self.mf_config = mf_config
+
+        build_mf_context(self.mf_config)
+        build_parallel_config(self.mf_config)
+
+        self._generate_model_config()
+        self.causal_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype,
+                                               max_model_len=self.mf_model_config.seq_length)
+        self.network, self.lm_head = self._create_network()
+        affinity_config = self.mf_config.get('context',
+                                             {}).get('affinity_cpu_list', {})
+        if isinstance(affinity_config, dict):
+            ms.runtime.set_cpu_affinity(True, affinity_config)
+
+        self._set_dynamic_inputs()
+
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
-        self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+        if envs.VLLM_USE_V1:
+            self.kv_caches = [Fake_Attention_V1() for _ in range(self.mf_model_config.num_hidden_layers)]
+        else:
+            self.kv_caches = [Fake_Attention() for _ in range(self.mf_model_config.num_hidden_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
-        for i in range(self.mf_model_config.num_layers):
+        for i in range(self.mf_model_config.num_hidden_layers):
             compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
-        self.set_flags = False
+        self.cast = ops.Cast()
+
+    def _set_dynamic_inputs(self):
+        self.network.set_dynamic_inputs()
+        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+        self.lm_head.set_inputs(dynamic_hidden_states)
+
+    def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
+        input_len = input_ids.shape[0]
+        max_seq_len = ms.Tensor(input_len, dtype=ms.int32)
+        seq_lengths = ms.Tensor([input_len], dtype=ms.int32)
+        q_seq_lens = ms.Tensor([input_len], dtype=ms.int32)
+        q_seq_lens_np = np.array([input_len], dtype=np.int32)
+        seq_lens_np = np.array([input_len], dtype=np.int32)
+
+        block_tables = ms.Tensor([[0]], dtype=ms.int32)
+        slot_mapping = [-1 for _ in range(input_len)]
+        slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32)
+        return FlashAttentionMetadata(
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lengths,
+            seq_lens_np=seq_lens_np,
+            block_tables=block_tables,
+            slot_mapping=slot_mapping,
+            q_seq_lens=q_seq_lens,
+            q_seq_lens_np=q_seq_lens_np,
+            context_lens=ms.Tensor([0], dtype=ms.int32),
+            # Force both the prefill and the decode branch to be compiled during
+            # warmup: max_context_lens is 0 for prefill and 1 for decode.
+            max_context_lens=0 if not self.set_flags else 1,
+            query_start_loc=None
+        )
+
+    def prepare_inputs(self, input_ids, positions, attn_metadata):
+        key_cache, value_cache = self.get_kvcache()
+        if not envs.VLLM_USE_V1:
+            seq_lens = attn_metadata.seq_lens
+            max_query_len = attn_metadata.max_query_len
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+            # decodes are scheduled together. In the first step, all the
+            # prefills turn into decodes and max_query_len will be 1.
+            if self.is_multi_step_chunked_prefill and max_query_len == 1:
+                query_lens = [1] * len(seq_lens)
+            else:
+                query_lens = attn_metadata.query_lens
+
+            seq_lens_np = np.array(seq_lens, dtype=np.int32)
+            query_lens_np = np.array(query_lens, dtype=np.int32)
+            kv_cache_lens = seq_lens_np - query_lens_np
+            if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0:
+                is_prefill = True
+            else:
+                is_prefill = False
+
+            q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
+            position_ids = ms.Tensor(positions, dtype=ms.int32)
+            attention_mask = self.causal_mask.gen_attention_mask(is_prefill, position_ids, query_lens)
+
+            model_inputs = {}
+            model_inputs["input_ids"] = input_ids.astype(ms.int32)
+            model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
+            model_inputs["block_tables"] = attn_metadata.block_tables
+            model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+            model_inputs["q_seq_lens"] = q_seq_lens
+            model_inputs["attention_mask"] = attention_mask
+            model_inputs["key_cache"] = key_cache
+            model_inputs["value_cache"] = value_cache
+            model_inputs["positions"] = position_ids
+            model_inputs["context_lens_tensor"] = ms.from_numpy(kv_cache_lens)
+        else:
+            if attn_metadata.max_context_lens == 0:
+                is_prefill = True
+            else:
+                is_prefill = False
+            q_seq_lens = attn_metadata.q_seq_lens
+            query_lens_np = attn_metadata.q_seq_lens_np
+            attention_mask = self.causal_mask.gen_attention_mask(is_prefill, positions, query_lens_np)
+            model_inputs = {}
+            model_inputs["input_ids"] = input_ids.astype(ms.int32)
+            model_inputs["batch_valid_length"] = ms.from_numpy(attn_metadata.seq_lens_np)
+            model_inputs["block_tables"] = attn_metadata.block_tables
+            model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+            model_inputs["q_seq_lens"] = q_seq_lens
+            model_inputs["attention_mask"] = attention_mask
+            model_inputs["key_cache"] = key_cache
+            model_inputs["value_cache"] = value_cache
+            model_inputs["positions"] = positions.to(ms.int32)
+            model_inputs["context_lens_tensor"] = attn_metadata.context_lens
+
+        return model_inputs, is_prefill
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        **kwargs
+    ) -> Union[Tensor, IntermediateTensors]:
+        attn_metadata = get_forward_context().attn_metadata
+        if attn_metadata is None:
+            attn_metadata = self._dummy_attention_metadata(input_ids, positions)
+        model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, attn_metadata)
+        model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+
+        if is_prefill:
+            self.network.phase = "prefill"
+            if not self.set_flags or is_pynative():
+                self.network.add_flags_custom_mcore(is_prefill=True)
+            hidden_states = self.network(**model_inputs)
+            self.network.phase = "increment"
+            if not self.set_flags or is_pynative():
+                self.network.add_flags_custom_mcore(is_prefill=False)
+                self.set_flags = True
+        else:
+            hidden_states = self.network(**model_inputs)
+
+        return hidden_states
 
     def _generate_model_config(self):
-        self.mf_config.load_checkpoint = self.get_model_path()
-        self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
-        if self.mf_config.moe_config:
-            self.mf_model_config.moe_config = self.mf_config.moe_config
-        self.mf_model_config.return_hidden_states = True
-
-        # qwen qkv concat will support in next version
-        self.mf_model_config.qkv_concat = False
-        setattr(self.mf_model_config, 'npu_mem_size', -1)
-        self.mf_config.model.model_config.qkv_concat = False
+        self.mf_model_config = gen_model_config(self.mf_config, Qwen3Config)
+        logger.debug("=====mf_model_config====\n%s", self.mf_model_config)
 
     def _create_network(self):
         # Initial network
         with no_init_parameters():  # Delay initialization
             network = ParallelQwenForCausalLM_MF(self.mf_model_config)
-        return network, network.lm_head
+        return network, network.model.output_layer
 
-    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
-        weight_processor = Qwen3WeightProcessor(self.mf_config, self.network, False)
-        weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint)
+    def update_model_inputs(self, model_inputs, **kwargs):
+        return model_inputs
+
+    def compute_logits(
+        self,
+        hidden_states: Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[Tensor]:
+        if sampling_metadata is not None:
+            selected_token_indices = sampling_metadata.selected_token_indices
+            if selected_token_indices is not None and selected_token_indices.numel() <= 0:
+                logits = ms.mint.zeros((0, self.mf_model_config.vocab_size),
+                                       dtype=self.mf_model_config.compute_dtype)
+            else:
+                hidden_states = hidden_states.reshape((-1, hidden_states.shape[-1]))
+                hidden_states = hidden_states.index_select(0, selected_token_indices)
+                logits = self.lm_head(hidden_states)
+                logits = logits.view(-1, logits.shape[-1])
+        else:
+            logits = self.lm_head(hidden_states)
+            logits = logits.view(-1, logits.shape[-1])
+        return logits
+
+    def sample(
+        self,
+        logits: Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        _pynative_executor.sync()
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]):
+        self.network.load_weights(self.mf_config.load_checkpoint)
         self.network.set_dynamic_inputs()
-        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
-        self.lm_head.set_inputs(dynamic_hidden_states)
         return None
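In compute_logits above, only the rows named by selected_token_indices (typically each sequence's last token) are projected by the LM head. A standalone numpy illustration of that gather (not vllm-mindspore code):

    import numpy as np

    # 6 flattened tokens with hidden size 2; two sequences whose last
    # tokens sit at rows 2 and 5.
    hidden_states = np.arange(12, dtype=np.float32).reshape(6, 2)
    selected_token_indices = np.array([2, 5])

    # Mirrors hidden_states.index_select(0, selected_token_indices).
    selected = hidden_states[selected_token_indices]
    print(selected.shape)  # (2, 2) -- lm_head then yields (2, vocab_size)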
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py
deleted file mode 100644
index 338616cafda4f4864c26b58530f7db8d11481d9e..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2025 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-transform huggingface model to mindspore safetensor.
-"""
-import numpy as np
-
-import mindspore as ms
-
-from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
-
-
-class Qwen3WeightProcessor(Qwen2WeightProcessor):
-    r"""
-    Provide Qwen3 Model weight load and shards.
-    Args:
-        config (Qwen3Config): The config of Qwen3 model.
-        network (InferenceQwen3ForCausalLM): The network of Qwen3.
-
-    """
-
-    def __init__(self, config, network, is_quant):
-        super().__init__(config, network, is_quant)
-
-    def convert_weight_name(self, weight_name: str):
-        """replace weight name"""
-        weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
-        weight_name = weight_name.replace('self_attn.q_proj.', 'attention.wq.')
-        weight_name = weight_name.replace('self_attn.k_proj.', 'attention.wk.')
-        weight_name = weight_name.replace('self_attn.v_proj.', 'attention.wv.')
-        weight_name = weight_name.replace('self_attn.o_proj.', 'attention.wo.')
-        weight_name = weight_name.replace('self_attn.q_norm.', 'attention.q_norm.')
-        weight_name = weight_name.replace('self_attn.k_norm.', 'attention.k_norm.')
-
-        weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1.')
-        weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2.')
-        weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3.')
-        weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
-        weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
-        weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
-        return weight_name
-
-    def infer_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map):
-        """infer process attention weight"""
-        qkv_concat = self.config.model.model_config.qkv_concat
-        # wq
-        wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight"
-        wq_ms_name = self.convert_weight_name(wq_hf_name)
-        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
-                                                       split_axis=0)
-
-        # wk
-        wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight"
-        wk_ms_name = self.convert_weight_name(wk_hf_name)
-        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
-                                                       split_axis=0)
-
-        # wv
-        wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight"
-        wv_ms_name = self.convert_weight_name(wv_hf_name)
-        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
-                                                       split_axis=0)
-
-        # wq_norm
-        q_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_norm.weight"
-        q_norm_ms_name = self.convert_weight_name(q_norm_hf_name)
-        q_norm_ms_param, _ = self.get_safetensor_from_file(q_norm_hf_name, src_hf_dir, hf_weight_map)
-        self.parameter_dict[q_norm_ms_name] = ms.Parameter(ms.Tensor(q_norm_ms_param, ms.bfloat16),
-                                                           name=q_norm_ms_name,
-                                                           requires_grad=False)
-
-        #wk_norm
-        k_norm_hf_name = f"model.layers.{layer_id}.self_attn.k_norm.weight"
-        k_norm_ms_name = self.convert_weight_name(k_norm_hf_name)
-        k_norm_ms_param, _ = self.get_safetensor_from_file(k_norm_hf_name, src_hf_dir, hf_weight_map)
-        self.parameter_dict[k_norm_ms_name] = ms.Parameter(ms.Tensor(k_norm_ms_param, ms.bfloat16),
-                                                           name=k_norm_ms_name,
-                                                           requires_grad=False)
-
-        if qkv_concat:
-            w_qkv_name = f"model.layers.{layer_id}.attention.w_qkv.weight"
-            w_qkv_param = np.concatenate((wq_ms_param, wk_ms_param, wv_ms_param), axis=0)
-            w_qkv_param = ms.from_numpy(w_qkv_param).astype(ms.bfloat16)
-            self.parameter_dict[w_qkv_name] = ms.Parameter(w_qkv_param, name=w_qkv_name, requires_grad=False)
-
-        else:
-            self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(ms.bfloat16),
-                                                           name=wq_ms_name,
-                                                           requires_grad=False)
-            self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(ms.bfloat16),
-                                                           name=wk_ms_name,
-                                                           requires_grad=False)
-            self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(ms.bfloat16),
-                                                           name=wv_ms_name,
-                                                           requires_grad=False)
-
-        # wo
-        wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
-        wo_ms_name = self.convert_weight_name(wo_hf_name)
-        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
-                                                       split_axis=1)
-        self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
-                                                       name=wo_ms_name,
-                                                       requires_grad=False)
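The whole file can go because the mcore Qwen3ForCausalLM consumes HuggingFace safetensors directly (network.load_weights in load_weights above), so no HF-to-MindFormers name remapping or manual sharding is needed anymore. For reference, the deleted convert_weight_name boiled down to a fixed table of substring rewrites; a condensed sketch (table abbreviated):

    RENAMES = {
        'embed_tokens.weight': 'tok_embeddings.embedding_weight',
        'self_attn.q_proj.': 'attention.wq.',
        'self_attn.k_proj.': 'attention.wk.',
        'self_attn.v_proj.': 'attention.wv.',
        'self_attn.o_proj.': 'attention.wo.',
        'mlp.gate_proj.': 'feed_forward.w1.',
    }

    def convert_weight_name(name: str) -> str:
        for old, new in RENAMES.items():
            name = name.replace(old, new)
        return name

    print(convert_weight_name("model.layers.0.self_attn.q_proj.weight"))
    # model.layers.0.attention.wq.weight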
diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py
index bdc43d8bfe04029188dfb0723ea8df50dab2b3e2..dc518c31a26d1194e63436362a62fa6278b1827d 100644
--- a/vllm_mindspore/model_executor/models/registry.py
+++ b/vllm_mindspore/model_executor/models/registry.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 # ============================================================================
 
+import os
 import pickle
 import sys
 from typing import TypeVar
@@ -32,9 +33,9 @@ _MINDSPORE_MODELS = {
 
 _MINDFORMERS_MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
     "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepseekV3MTPForCausalLM"),
-    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
 }
 
 _MINDONE_MODELS = {
@@ -48,8 +49,7 @@ _registry_dict = {}
 if is_mindformers_model_backend():
     _registry_dict = {
         model_arch: _LazyRegisteredModel(
-            module_name=
-            f"vllm_mindspore.model_executor.models.mf_models.{mod_relname}",
+            module_name=f"vllm_mindspore.model_executor.models.mf_models.{mod_relname}",
             class_name=cls_name,
         )
         for model_arch, (mod_relname, cls_name) in _MINDFORMERS_MODELS.items()
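Registration stays lazy: _LazyRegisteredModel records only a dotted module path and a class name, so the heavyweight MindFormers import happens the first time the architecture is resolved. A minimal sketch of that pattern (a simplified stand-in, not the actual vLLM class):

    import importlib
    from dataclasses import dataclass

    @dataclass
    class LazyRegisteredModel:
        module_name: str
        class_name: str

        def load_model_cls(self):
            # Import is deferred until the architecture is requested.
            module = importlib.import_module(self.module_name)
            return getattr(module, self.class_name)

    entry = LazyRegisteredModel(
        module_name="vllm_mindspore.model_executor.models.mf_models.qwen3",
        class_name="Qwen3ForCausalLM",
    )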
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 4bb7831c584c03b61bda9e0d751d32be934db19b..1acd616da3340364986bca61edb047217b114d62 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -19,15 +19,15 @@
 from dataclasses import dataclass, field
 from typing import List, Tuple, Union, Mapping, Optional, Iterable
 
+import mindspore as ms
+from mindspore import mint
+from mindspore import ops
+
 from vllm.sequence import IntermediateTensors
 
 from vllm_mindspore.multimodal.inputs import NestedTensors
 from vllm_mindspore.utils import get_valid_dtype
 
-import mindspore as ms
-from mindspore import mint
-from mindspore import ops
-
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
 
@@ -71,7 +71,6 @@ class WeightsMapper:
         return ((out_name, data) for name, data in weights
                 if (out_name := self._map_name(name)) is not None)
 
-
 class PPMissingLayer(ms.nn.Cell):
     """
     A placeholder layer for missing layers in a pipeline parallel model.
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 60cd4af040fc9e9bda617eb8b6cd5d5130ef765f..6fd3ca83193c5ce6fc719a24072ccd5849bce920 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -19,6 +19,7 @@ import contextlib
 import gc
 import os
 import sys
+from enum import Enum
 from typing import (TYPE_CHECKING, Callable, Generator, List, Optional, Tuple,
                     Union)
 
@@ -152,6 +153,11 @@ STR_DTYPE_TO_MS_DTYPE = {
     "fp8_e5m2": mstype.uint8,
 }
 
+class vllmModelBackendEnum(str, Enum):
+    """Allowed values of the vLLM_MODEL_BACKEND environment variable."""
+    MF = 'MindFormers'
+    MIND_ONE = 'MindONE'
+
 
 def get_dtype_size(dtype: torch.dtype) -> int:
     """Get the size of the data type in bytes."""
@@ -203,15 +209,24 @@ def ascend_is_initialized():
 
 
 def is_mindformers_model_backend():
-    return (os.getenv("vLLM_MODEL_BACKEND")  # noqa: SIM112
-            and
-            os.environ["vLLM_MODEL_BACKEND"] == "MindFormers"  # noqa: SIM112
-            )
+    vllm_model_backend = os.getenv("vLLM_MODEL_BACKEND")
+    if vllm_model_backend:
+        try:
+            vllmModelBackendEnum(vllm_model_backend)
+        except ValueError as exc:
+            allowed_values = [member.value for member in vllmModelBackendEnum]
+            raise ValueError(
+                f"Illegal value of vLLM_MODEL_BACKEND '{vllm_model_backend}',"
+                f" allowed values: {', '.join(allowed_values)}"
+            ) from exc
+        return vllm_model_backend == vllmModelBackendEnum.MF
+    return False
 
 
 def is_mindone_model_backend():
     return (os.getenv("vLLM_MODEL_BACKEND")  # noqa: SIM112
-            and os.environ["vLLM_MODEL_BACKEND"] == "MindONE"  # noqa: SIM112
+            and os.environ["vLLM_MODEL_BACKEND"] == vllmModelBackendEnum.MIND_ONE  # noqa: SIM112
             )
 
@@ -234,15 +249,6 @@ def check_ready():
 
     if is_mindformers_model_backend():
         logger.info("Run with Mindformers backend!")
-        necessary_envs = ("MINDFORMERS_MODEL_CONFIG", )
-        lost_envs = [
-            env_item for env_item in necessary_envs if not os.getenv(env_item)
-        ]
-
-        if lost_envs:
-            raise RuntimeError(
-                f'For "MindFormers" model backend, environments {str(lost_envs)} should be set!'
-            )
     elif is_mindone_model_backend():
         logger.info("Run with MindONE backend!")
     else:
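Because vllmModelBackendEnum subclasses str, members compare equal to their plain string values, and constructing the enum from an unknown value raises ValueError -- exactly what the validation above relies on. Note that the match is case-sensitive:

    from enum import Enum

    class vllmModelBackendEnum(str, Enum):
        """Allowed values of the vLLM_MODEL_BACKEND environment variable."""
        MF = 'MindFormers'
        MIND_ONE = 'MindONE'

    assert vllmModelBackendEnum('MindFormers') is vllmModelBackendEnum.MF
    assert vllmModelBackendEnum.MF == 'MindFormers'  # str subclass

    try:
        vllmModelBackendEnum('mindformers')  # wrong case
    except ValueError:
        print([m.value for m in vllmModelBackendEnum])
    # ['MindFormers', 'MindONE']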