diff --git a/vllm_mindspore/model_executor/models/internlm3.py b/vllm_mindspore/model_executor/models/internlm3.py new file mode 100644 index 0000000000000000000000000000000000000000..4293583c4f06dbf0fbbf96eb562e5658ad9021fb --- /dev/null +++ b/vllm_mindspore/model_executor/models/internlm3.py @@ -0,0 +1,532 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/internlm3/modeling_internlm3.py +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024 The InternLM team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The adaptation of internlm3 in vllm-mindspore mainly includes the +# following points: +# 1. Additional model input parameters have been added, such as +# `key_cache`, `block_tables`, etc., to accommodate the +# vllm-mindspore calling convention. +# 2. During model initialization, methods from the NativeModel base +# class, such as `common_preprocess`, are invoked to adapt to the +# vllm-mindspore workflow. +# 3. In the `forward` function, the `exec_model` method is called to +# perform the model's forward computation, aligning with the +# vllm-mindspore execution flow. +# 4. In the `load_weights` function, due to the lack of `skip_prefix` +# functionality, the handling of `tie_word_embeddings` has been +# adapted. 
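+#
+# As a minimal illustration of point 1 (a sketch grounded in the code later in
+# this file, not an additional API): compared with the upstream HuggingFace /
+# vLLM layer signature, each decoder layer's `construct` also receives the
+# paged-KV cache tensors and scheduling metadata explicitly, e.g.
+#
+#     def construct(self, positions, hidden_states, key_cache, value_cache,
+#                   slot_mapping, attn_mask, batch_valid_length,
+#                   q_seq_lens, block_tables, residual):
+#         ...
+#
+# which mirrors `InternLM3DecoderLayer.construct` as defined below.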
+ +"""Inference-only InternLM3 model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from transformers import InternLM3Config +else: + InternLM3Config = None + +from mindspore import Tensor, mint, nn +from vllm.attention import AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.attention import Attention +from vllm_mindspore.model_executor.layers.activation import SiluAndMul +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + default_weight_loader) +from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + +logger = init_logger(__name__) + + +class InternLM3MLP(nn.Cell): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config=None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def construct(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class InternLM3Attention(nn.Cell): + + def __init__( + self, + config: InternLM3Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config=None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config=None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # InternLM3 has an optional head_dim introduced + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + ) -> Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), + -1) + q, k = self.rotary_emb(positions, q, k, batch_valid_length) + attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, + block_tables) + output, _ = self.o_proj(attn_output) + return output + + +class InternLM3DecoderLayer(nn.Cell): + + def __init__( + self, + config: InternLM3Config, + cache_config=None, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support internlm/internlm3-8b with qkv_bias + attention_bias = getattr(config, "qkv_bias", False) + bias_o_proj = getattr(config, "bias", False) + + self.self_attn = InternLM3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = InternLM3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + 
quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class InternLM3Model(nn.Cell): + SUPPORT_LORA = False + SUPPORT_PP = False + + def __init__( + self, + *, + vllm_config, + prefix: str = "", + layer_type: type[InternLM3DecoderLayer] = InternLM3DecoderLayer, + ): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + quant_config = vllm_config.quant_config + self.quant_config = quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config # noqa: F841 + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = \ + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.embed_tokens(input_ids) + + def construct( + self, + input_ids: Optional[Tensor], + positions: Tensor, + key_caches: list[Tensor], + value_caches: list[Tensor], + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + 
key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, Tensor]], params_dict): + loaded_params: set[str] = set() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or \ + "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name in params_dict: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class InternLM3ForCausalLM(NativeModel): + + def __init__(self, vllm_config, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = InternLM3Model(vllm_config=vllm_config) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + # TODO: To support lora + # if self.lora_config: + # self.unpadded_vocab_size += + # self.lora_config.lora_extra_vocab_size + # self.unpadded_vocab_size += config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not self.lora_config else + self.lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size, + logit_scale) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.common_preprocess(vllm_config, prefix) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids, + positions, + intermediate_tensors=None, + inputs_embeds=None, + **kwargs): + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def load_weights(self, weights: 
Iterable[tuple[str, Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + load_params = self.model.load_weights(weights, params_dict) + if self.config.tie_word_embeddings: + load_params.add("lm_head.weight") + return load_params + + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits \ No newline at end of file diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py index 29277ccc51f2d1fbd46882d7dac26367ac1be93f..ad85b0226b08258b25b1cd8f2d414586ec59611d 100644 --- a/vllm_mindspore/model_executor/models/llama.py +++ b/vllm_mindspore/model_executor/models/llama.py @@ -48,7 +48,6 @@ else: from mindspore import Tensor, mint, nn from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -424,7 +423,7 @@ class LlamaModel(nn.Cell): def load_weights(self, weights: Iterable[tuple[str, Tensor]], params_dict): loaded_params: set[str] = set() stacked_params_mapping = [ - # (param_name, shard_name, shard_id) + # shape is (param_name, shard_name, shard_id). (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), @@ -472,11 +471,7 @@ class LlamaForCausalLM(NativeModel, SupportsPP): if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.config.vocab_size - # TODO: To support lora - # if self.lora_config: - # self.unpadded_vocab_size += - # self.lora_config.lora_extra_vocab_size - # self.unpadded_vocab_size += config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( self.unpadded_vocab_size, self.config.hidden_size, @@ -498,7 +493,6 @@ class LlamaForCausalLM(NativeModel, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() @@ -526,11 +520,6 @@ class LlamaForCausalLM(NativeModel, SupportsPP): load_params.add("lm_head.weight") return load_params - def sample(self, logits: Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def compute_logits( self, hidden_states: Tensor, diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index d13e4351c8252134be298c3172b01fe331e13ec9..e38ec2896d11b8fcaf8ba72f7d7150ecb50523d3 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -32,7 +32,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -74,9 +73,9 @@ class 
MfModelBase(MsModelBase): self.mf_config.model.model_config.parallel_config.model_parallel = ( get_tensor_model_parallel_world_size()) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 - self.use_mla_op = \ - bool(vllm_config.additional_config - and vllm_config.additional_config.get('use_mla_op') == 1) + self.use_ringmla = vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + self.is_chunked = False self._generate_model_config() if not hasattr(self, 'mf_model_config'): raise RuntimeError('mf_model_config not initialized') @@ -214,16 +213,5 @@ class MfModelBase(MsModelBase): logits = logits.view(-1, logits.shape[-1]) return logits - def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - if not hasattr(self, 'sampler'): - raise RuntimeError('sampler not initialized') - next_tokens = self.sampler(logits, sampling_metadata) - _pynative_executor.sync() - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: raise NotImplementedError("load_weight not implemented.") diff --git a/vllm_mindspore/model_executor/models/mf_models/mindformers.py b/vllm_mindspore/model_executor/models/mf_models/mindformers.py index 4e64014afc12c919f6643278d57f12c9fcc079d9..a3d2d891196fe4516bd1b4d4dc73d095383c519b 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mindformers.py +++ b/vllm_mindspore/model_executor/models/mf_models/mindformers.py @@ -23,14 +23,12 @@ from mindformers import AutoModel, PreTrainedModel from mindformers.core.context import build_mf_context from mindformers.tools.utils import is_pynative from mindspore import Tensor, mutable, ops -from mindspore.common.api import _pynative_executor from mindspore.nn.utils import no_init_parameters from vllm import envs from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -61,6 +59,9 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): self.mf_config = mf_config self.mla_config = self.mf_config.get('model', None).get( 'model_config', None).get('multi_latent_attention', False) + self.use_ringmla = vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + self.is_chunked = False build_mf_context(self.mf_config) @@ -69,7 +70,6 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): self._set_dynamic_inputs() - self.sampler = get_sampler() self.set_modules({"model": self.network}) num_layers = self.model_config.get_num_layers(self.parallel_config) @@ -113,12 +113,20 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): return super().get_kvcache() key_cache = [] + rope_cache = [] forward_context = get_forward_context() - for i in range(self.config.num_hidden_layers): - k_cache = self.kv_caches[i].kv_cache[ - forward_context.virtual_engine][0] - key_cache.append(k_cache) - return mutable(key_cache), None + key_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] + for i in range(self.config.num_hidden_layers) + ] + if not self.use_ringmla: + return mutable(key_cache), None + # deepseek 
mla op need key cache and rope cache + rope_cache = [ + self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + for i in range(self.config.num_hidden_layers) + ] + return mutable(key_cache), mutable(rope_cache) def _get_padding_index(self, q_seq_len): """ @@ -257,15 +265,16 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): seq_lens_np = np.array(seq_lens, dtype=np.int32) query_lens_np = np.array(query_lens, dtype=np.int32) kv_cache_lens = seq_lens_np - query_lens_np - if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max( - ) == 0: - is_prefill = True - else: - is_prefill = False + is_prefill = kv_cache_lens.max() == 0 + is_chunked = attn_metadata.num_decode_tokens == 0 and \ + bool(kv_cache_lens.max() > 0) context_lens_tensor = ms.from_numpy(kv_cache_lens) else: # V1 is_prefill = attn_metadata.max_context_lens == 0 + is_chunked = not is_prefill and \ + bool((attn_metadata.context_lens - attn_metadata.num_prompt_tokens).min() < 0) + query_lens_np = attn_metadata.q_seq_lens_np seq_lens_np = attn_metadata.seq_lens_np context_lens_tensor = attn_metadata.context_lens @@ -289,7 +298,7 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): model_inputs = (self.update_padding_index_to_inputs( model_inputs, q_seq_lens)) - return model_inputs, is_prefill + return model_inputs, is_prefill, is_chunked def forward(self, input_ids: Tensor, @@ -297,21 +306,23 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, **kwargs) -> Union[Tensor, IntermediateTensors]: - model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) + model_inputs, is_prefill, is_chunked = self.prepare_inputs(input_ids, positions) model_inputs = self.update_model_inputs(model_inputs, **kwargs) if intermediate_tensors is not None: model_inputs["hidden_states"] = \ intermediate_tensors["hidden_states"] - if is_prefill: - self.network.phase = "prefill" + if is_prefill or is_chunked: + self.network.phase = "prefill" if (not self.use_ringmla or not is_chunked) else "chunked" if not self.set_flags or is_pynative(): self.network.add_flags_custom_mcore(is_prefill=True) + self.network.add_flags_chunked(is_chunked=is_chunked) + self.is_chunked |= (self.use_ringmla and is_chunked) hidden_states = self.network(**model_inputs) self.network.phase = "increment" if not self.set_flags or is_pynative(): self.network.add_flags_custom_mcore(is_prefill=False) - self.set_flags = True + self.set_flags = True if not self.use_ringmla else self.is_chunked else: hidden_states = self.network(**model_inputs) @@ -363,15 +374,6 @@ class MindFormersForCausalLM(MsModelBase, SupportsPP): logits = logits.view(-1, logits.shape[-1]) return logits - def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - _pynative_executor.sync() - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, Tensor]]): self.network.load_weights(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index c41eb5c8e9d1e9ff612f24d2072790c85fbc912e..b2692d0fd7eda5dff80d80cd44813b2a70157abe 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -26,7 +26,6 @@ from research.qwen2_5.infer.qwen2_5 import ( # yapf: enable # noqa: ERA001 from 
vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import get_sampler from vllm_mindspore.model_executor.models.mf_models.mf_model_base import ( MfModelBase) @@ -43,7 +42,6 @@ class Qwen2ForCausalLM(MfModelBase): super().__init__(vllm_config=vllm_config, prefix=prefix) self.mf_kvcaches_init = False - self.sampler = get_sampler() self.set_modules({"model": self.network}) self.kv_caches = [ diff --git a/vllm_mindspore/model_executor/models/mindone_models/transformers.py b/vllm_mindspore/model_executor/models/mindone_models/transformers.py index 5cf914b762cbd4232dac4d9eb2771f716e12f1a5..af0824e7f270191ea7ea1a1f73fd086dd2385fa0 100644 --- a/vllm_mindspore/model_executor/models/mindone_models/transformers.py +++ b/vllm_mindspore/model_executor/models/mindone_models/transformers.py @@ -38,7 +38,6 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -381,7 +380,6 @@ class TransformersForCausalLM(MindONEModelBase): if has_bias: self.lm_head.bias.set_dtype(self.model_config.dtype) - self.sampler = get_sampler() self.set_modules({"model": self.model, "lm_head": self.lm_head}) self.logit_scale = getattr(config, "logit_scale", 1.0) @@ -568,14 +566,6 @@ class TransformersForCausalLM(MindONEModelBase): return logits - def sample(self, logits: mindspore.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - - _pynative_executor.sync() - - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, mindspore.Tensor]]): # TODO: support parallel weight loading diff --git a/vllm_mindspore/model_executor/models/mistral3.py b/vllm_mindspore/model_executor/models/mistral3.py new file mode 100644 index 0000000000000000000000000000000000000000..c233e073d6f5653121923fa4e4fc6646d17424ea --- /dev/null +++ b/vllm_mindspore/model_executor/models/mistral3.py @@ -0,0 +1,546 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024 The MistralAI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# The adaptation of mistral3 in vllm-mindspore mainly includes the +# following points: +# 1. Additional model input parameters have been added, such as +# `key_cache`, `block_tables`, etc., to accommodate the +# vllm-mindspore calling convention. +# 2. During model initialization, methods from the NativeModel base +# class, such as `common_preprocess`, are invoked to adapt to the +# vllm-mindspore workflow. +# 3. In the `forward` function, the `exec_model` method is called to +# perform the model's forward computation, aligning with the +# vllm-mindspore execution flow. +# 4. In the `load_weights` function, due to the lack of `skip_prefix` +# functionality, the handling of `tie_word_embeddings` has been +# adapted. + +"""Inference-only Mistral3 model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from transformers import MistralConfig +else: + MistralConfig = None + +from mindspore import Tensor, mint, nn +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_mindspore.attention import Attention +from vllm_mindspore.model_executor.layers.activation import SiluAndMul +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) +from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + default_weight_loader) +from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + +logger = init_logger(__name__) + + +class Mistral3MLP(nn.Cell): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config=None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def construct(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class Mistral3Attention(nn.Cell): + + def __init__( + self, + config: MistralConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 4096 * 32, + quant_config=None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config=None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # Mistral3 has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + # Handle sliding window attention + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} is not supported.") + else: + sliding_window = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + ) -> Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = mint.split(qkv, (self.q_size, 
self.kv_size, self.kv_size), + -1) + q, k = self.rotary_emb(positions, q, k, batch_valid_length) + attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping, + attn_mask, batch_valid_length, q_seq_lens, + block_tables) + output, _ = self.o_proj(attn_output) + return output + + +class Mistral3DecoderLayer(nn.Cell): + + def __init__( + self, + config: MistralConfig, + cache_config=None, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096 * 32) + # Support mistral/mistral3 with attention_bias + attention_bias = getattr(config, "attention_bias", False) + bias_o_proj = getattr(config, "bias", False) + + self.self_attn = Mistral3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Mistral3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def construct( + self, + positions: Tensor, + hidden_states: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Mistral3Model(nn.Cell): + SUPPORT_LORA = False + SUPPORT_PP = False + + def __init__( + self, + *, + vllm_config, + prefix: str = "", + layer_type: type[Mistral3DecoderLayer] = Mistral3DecoderLayer, + ): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + quant_config = vllm_config.quant_config + self.quant_config = quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config # noqa: F841 + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + 
config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = \ + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.embed_tokens(input_ids) + + def construct( + self, + input_ids: Optional[Tensor], + positions: Tensor, + key_caches: list[Tensor], + value_caches: list[Tensor], + slot_mapping: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + ) -> Union[Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + slot_mapping, attn_mask, + batch_valid_length, q_seq_lens, + block_tables, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, Tensor]], params_dict): + loaded_params: set[str] = set() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or \ + "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name in params_dict: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Mistral3ForCausalLM(NativeModel): + + def __init__(self, vllm_config, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Mistral3Model(vllm_config=vllm_config) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + # TODO: To support lora + # if self.lora_config: + # self.unpadded_vocab_size += + # self.lora_config.lora_extra_vocab_size + # self.unpadded_vocab_size += config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not self.lora_config else + self.lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size, + logit_scale) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.common_preprocess(vllm_config, prefix) + + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids, + positions, + intermediate_tensors=None, + inputs_embeds=None, + **kwargs): + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: + params_dict = self.get_params_dict() + load_params = self.model.load_weights(weights, params_dict) + if self.config.tie_word_embeddings: + load_params.add("lm_head.weight") + return load_params + + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits \ No newline at end of file diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4d35d968449d602115bdab34f01f9fcde2fcc091..8786d7b9c4557385764c366d88d1208c41d86329 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ 
-27,7 +27,6 @@ from mindspore.common import dtype as mstype from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -70,10 +69,9 @@ class MLAAttentionWrapper(AttentionWrapper): def __init__(self): super().__init__() vllm_config = get_current_vllm_config() - self.use_mla_op = bool( - vllm_config.additional_config - and vllm_config.additional_config.get('use_mla_op') == 1) - if not self.use_mla_op: + self.use_ringmla = vllm_config.model_config.quantization is not None \ + and vllm_config.parallel_config.tensor_parallel_size < 16 + if not self.use_ringmla: self.kv_cache = [ ( ms.mint.zeros( @@ -88,9 +86,9 @@ class MLAAttentionWrapper(AttentionWrapper): 'qk_rope_head_dim', 0) # k_shape, r_shape used for mla_op k_shape = [*(self.kv_shape[0:-1]), kv_lora_rank - ] if self.use_mla_op else None + ] if self.use_ringmla else None r_shape = [*(self.kv_shape[0:-1]), qk_rope_head_dim - ] if self.use_mla_op else None + ] if self.use_ringmla else None self.kv_cache = [ (ms.mint.zeros(k_shape, dtype=vllm_config.model_config.dtype), ms.mint.zeros(r_shape, dtype=vllm_config.model_config.dtype)) @@ -260,14 +258,6 @@ class MsModelBase: raise NotImplementedError( "Function compute_logits should be Implemented!") - @abstractmethod - def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - raise NotImplementedError("Function sample should be Implemented!") - @abstractmethod def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: raise NotImplementedError( @@ -300,7 +290,8 @@ class MsModelBase: # To enforce prefill and decode are both complied in warmup process. # So set max_context_lens to 0 for prefill and 1 for decode. 
max_context_lens=0 if not self.set_flags else 1, - query_start_loc=None) + query_start_loc=None, + num_prompt_tokens=seq_lengths) def prepare_base_inputs(self, input_ids, positions): attn_metadata = get_forward_context().attn_metadata diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 62d5b020a1882b7a7de98b2b09b35a5f535f1ba7..db8e31c0c4764af35a7b487473e5a9874b786a55 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -42,7 +42,6 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -464,7 +463,6 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(self.config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -485,11 +483,6 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): params_dict = self.get_params_dict() self.model.load_weights(weights, params_dict) - def sample(self, logits: Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def compute_logits( self, hidden_states: Tensor, diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py index 9087339acbd3bb8f328975c2445aec4cd969ea87..02d76dd67f6da8754fa97ee987aa9fdb394dc18b 100644 --- a/vllm_mindspore/model_executor/models/qwen2_5_vl.py +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -55,8 +55,6 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, \ - get_sampler from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -737,13 +735,13 @@ class Qwen2_5_VisionAttention(nn.Cell): return tensor_list def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]: - # [s, 3 * head * head_dim] + # shape is [s, 3 * head * head_dim]. 
seq_len, _ = qkv.shape - # [s, 3 * head * head_dim] -> 3 * [s, head * head_dim] + # shape is [s, 3 * head * head_dim] -> 3 * [s, head * head_dim] q, k, v = mint.chunk(qkv, 3, dim=-1) - # 3 * [s, head * head_dim] -> 3 * [s, head, head_dim] + # shape is 3 * [s, head * head_dim] -> 3 * [s, head, head_dim] new_shape = (seq_len, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) q, k, v = (x.view(*new_shape) for x in (q, k, v)) @@ -1214,7 +1212,6 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): prefix=maybe_prefix( prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() @@ -1252,8 +1249,10 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. - # if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - # return None + ''' + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + ''' return quant_config def _validate_and_reshape_mm_tensor(self, mm_input: object, @@ -1515,7 +1514,6 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): input_ids: ms.Tensor, multimodal_embeddings: Optional[tuple[ms.Tensor, ...]] = None, ) -> ms.Tensor: - # input_ids = input_ids.to(mstype.int64) inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( @@ -1593,11 +1591,6 @@ class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): sampling_metadata) return logits - def sample(self, logits: ms.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights( self, weights: Iterable[tuple[str, ms.Tensor]] ) -> None: # type: ignore[override] diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py index ecdaaf1c7b65dd58f1fd5ed3df06412b0e4bc773..89a98b58e2cddebfe5b135342ea03de92c90c07d 100644 --- a/vllm_mindspore/model_executor/models/qwen3.py +++ b/vllm_mindspore/model_executor/models/qwen3.py @@ -50,7 +50,6 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -296,7 +295,6 @@ class Qwen3ForCausalLM(NativeModel): quant_config=quant_config, prefix=maybe_prefix( prefix, "lm_head")) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() @@ -320,11 +318,6 @@ class Qwen3ForCausalLM(NativeModel): intermediate_tensors, inputs_embeds) return hidden_states - def sample(self, logits: Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def compute_logits( self, hidden_states: Tensor, diff --git a/vllm_mindspore/v1/attention/backends/ms_attn.py b/vllm_mindspore/v1/attention/backends/ms_attn.py index f014381b393154cf81ee466439777df0510f8fb1..194681db730c3c0bebd541f4215578a478b583e5 
100644 --- a/vllm_mindspore/v1/attention/backends/ms_attn.py +++ b/vllm_mindspore/v1/attention/backends/ms_attn.py @@ -137,6 +137,7 @@ class MsAttentionMetadata: #block_table: torch.Tensor slot_mapping: ms.Tensor + num_prompt_tokens: ms.Tensor # For cascade attention. #use_cascade: bool #common_prefix_len: int @@ -201,6 +202,7 @@ class MsAttentionMetadataBuilder: max_context_lens = self.runner.input_batch.num_computed_tokens_cpu[: num_reqs].max( ) + num_prompt_tokens = ms.from_numpy(self.runner.input_batch.num_prompt_tokens[:num_reqs]) slot_mapping = ms.from_numpy( self.block_table.slot_mapping_np[:num_actual_tokens]) seq_lens_np = self.runner.seq_lens_np[:num_reqs] @@ -220,7 +222,8 @@ class MsAttentionMetadataBuilder: max_seq_len=max_seq_len, context_lens=context_lens, max_context_lens=max_context_lens, - query_start_loc=query_start_loc) + query_start_loc=query_start_loc, + num_prompt_tokens=num_prompt_tokens) return attn_metadata diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index b37b39bc1f718473c16cc86a3cbcb28fc73fc70a..d87912ed2678e49dbe3ea87b2fb03fad79847a6b 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -233,10 +233,9 @@ def _allocate_nz_kv_cache_tensors(self, kv_cache_config): for layer_name in group.layer_names } - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) - if use_mla_op: + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 + if use_ringmla: logger.error("For 310p, mla kv cache not supported") raise NotImplementedError @@ -290,9 +289,8 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): dtype = kv_cache_spec.dtype coef = 1 if use_mla else 2 # Determine whether deepseek use mla op - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, @@ -317,7 +315,7 @@ def _allocate_kv_cache_tensors(self, kv_cache_config): """ raw_tensors.extend( [mint.zeros(raw_tensor_shape, dtype=target_dtype)] - if not use_mla_op else [ + if not use_ringmla else [ mint.zeros(int(raw_tensor_shape * kv_lora_rank / (kv_lora_rank + qk_rope_head_dim)), dtype=target_dtype), @@ -354,9 +352,8 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. 
""" # Determine whether deepseek use mla op - use_mla_op = bool( - self.vllm_config.additional_config - and self.vllm_config.additional_config.get('use_mla_op') == 1) + use_ringmla = self.vllm_config.model_config.quantization is not None \ + and self.vllm_config.parallel_config.tensor_parallel_size < 16 kv_lora_rank = getattr(self.vllm_config.model_config.hf_text_config, 'kv_lora_rank', 0) qk_rope_head_dim = getattr(self.vllm_config.model_config.hf_text_config, @@ -371,7 +368,7 @@ def _reshape_kv_cache_tensors( dtype_size = get_dtype_size(target_dtype) num_blocks = \ (raw_tensor[0].numel() - if not use_mla_op else + if not use_ringmla else # deepseek mla op need key cache and rope cache (raw_tensor[0].numel() + raw_tensor[1].numel())) * \ coef * dtype_size // kv_cache_spec.page_size_bytes @@ -400,7 +397,7 @@ def _reshape_kv_cache_tensors( kv_cache_layer = [] for idx, kv_cache_raw_tensor in enumerate( kv_cache_raw_tensors[layer_name]): - if use_mla_op: + if use_ringmla: # deepseek mla op need key cache and rope cache cache_shape = [ *(kv_cache_shape[1:-1]),