diff --git a/tests/st/python/cases_parallel/vllm_llama3.py b/tests/st/python/cases_parallel/vllm_llama3.py
new file mode 100644
index 0000000000000000000000000000000000000000..656c744d960bbe1c497719de341f9ca7e4907db7
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_llama3.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# isort:skip_file
+"""test vllm llama3."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0",
+    "VLLM_USE_V1": "1",
+    "HCCL_IF_BASE_PORT": "60000"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+
+def test_vllm_llama3_8b():
+    """
+    test case llama3.1 8B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "<|start_header_id|>user<|end_header_id|>\n\n将文本分类为中性、负面或正面。 "
+        "\n文本:我认为这次假期还可以。 \n情感:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(
+        model="/home/workspace/mindspore_dataset/weight/Llama-3.1-8B-Instruct",
+        gpu_memory_utilization=0.9,
+        tensor_parallel_size=1,
+        max_model_len=4096)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['中性']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[i]
+
+    # unset env
+    env_manager.unset_all()
+
+
+def test_vllm_llama3_1b():
+    """
+    test case llama3.2 1B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "<|start_header_id|>user<|end_header_id|>\n\n将文本分类为中性、负面或正面。 "
+        "\n文本:我认为这次假期还可以。 \n情感:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(
+        model="/home/workspace/mindspore_dataset/weight/Llama-3.2-1B-Instruct",
+        gpu_memory_utilization=0.9,
+        tensor_parallel_size=1,
+        max_model_len=4096)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['中性']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[i]
+
+    # unset env
+    env_manager.unset_all()
diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py
index f0ef9f1bcdac854443b553307a445981bbeaaf0a..b5f93c6a72099f87b09700671d291f6aeeb4af15 100644
--- a/tests/st/python/test_cases_parallel.py
+++ b/tests/st/python/test_cases_parallel.py
@@ -232,7 +232,17 @@ def test_cases_parallel_part5():
         "export HCCL_IF_BASE_PORT=61002 && "
         "pytest -s -v cases_parallel/vllm_mf_qwen3_8b_v1.py::test_mf_qwen3 "
         "> vllm_mf_qwen3_8b_v1_test_mf_qwen3.log",
-         "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log")
+         "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log"),
+        ("export ASCEND_RT_VISIBLE_DEVICES=4 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+         "export HCCL_IF_BASE_PORT=61004 && "
+         "pytest -s -v cases_parallel/vllm_llama3.py::test_vllm_llama3_8b "
+         "> vllm_llama3_8b_test_vllm_llama3.log",
+         "vllm_llama3_8b_test_vllm_llama3.log"),
+        ("export ASCEND_RT_VISIBLE_DEVICES=5 && export LCAL_COMM_ID=127.0.0.1:10071 && "
+         "export HCCL_IF_BASE_PORT=61006 && "
+         "pytest -s -v cases_parallel/vllm_llama3.py::test_vllm_llama3_1b "
+         "> vllm_llama3_1b_test_vllm_llama3.log",
+         "vllm_llama3_1b_test_vllm_llama3.log"),
     ]

     with Pool(len(commands)) as pool:
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index eb56d6650a5d42c4ae2ed3f55a9112d7f82af718..ff6ea4da22a78e6db7d7f5a31f18b1ea990f9357 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -15,17 +15,17 @@
 # limitations under the License.
 # ============================================================================

+import math
 from typing import Any, Dict, List, Optional, Tuple, Union

-import numpy as np
 import mindspore
-from mindspore import Tensor, mint, ops, nn
+import numpy as np
+from mindspore import Tensor, mint, nn, ops
 from mindspore.common import dtype as mstype
-
 from transformers import PretrainedConfig
-
 from vllm.config import get_current_vllm_config
+

 def _apply_rotary_emb(
     x: Tensor,
     cos: Tensor,
@@ -163,9 +163,10 @@ class InferRotaryEmbedding(nn.Cell):
         Compute the inverse frequency with numpy.
         Numpy process is faster during initialization.
""" - freqs_base = np.arange(0, self.rotary_dim, 2).astype( - np.float32) # (head_dim // 2, ) - freqs = 1.0 / (base**(freqs_base / self.rotary_dim)) # (head_dim // 2, ) + freqs_base = np.arange(0, self.rotary_dim, + 2).astype(np.float32) # (head_dim // 2, ) + freqs = 1.0 / (base**(freqs_base / self.rotary_dim) + ) # (head_dim // 2, ) return freqs def _compute_cos_sin_cache(self) -> Tuple[Tensor, Tensor]: @@ -173,8 +174,8 @@ class InferRotaryEmbedding(nn.Cell): t = np.arange(0, self.max_position_embeddings, 1).astype(np.float32) freqs = np.outer(t, freqs) # (max_position_embedding, head_dim // 2) emb = np.concatenate((freqs, freqs), axis=-1) - freqs_cos = np.cos(emb) # (seq_len, head_dim) - freqs_sin = np.sin(emb) # (seq_len, head_dim) + freqs_cos = np.cos(emb) # (seq_len, head_dim) + freqs_sin = np.sin(emb) # (seq_len, head_dim) freqs_cos = Tensor(freqs_cos, dtype=self.dtype) freqs_sin = Tensor(freqs_sin, dtype=self.dtype) return freqs_cos, freqs_sin @@ -200,6 +201,52 @@ class InferRotaryEmbedding(nn.Cell): batch_valid_length) +class InferLlama3RotaryEmbedding(InferRotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> np.ndarray: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = np.where( + wave_len < high_freq_wavelen, + inv_freqs, + np.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -254,9 +301,9 @@ class MRotaryEmbedding(RotaryEmbedding): cos_l = ops.split(cos, self.mrope_section, axis=-1) sin_l = ops.split(sin, self.mrope_section, axis=-1) cos, sin = (), () - for i in range(len(self.mrope_section)): - cos += (cos_l[i][i],) - sin += (sin_l[i][i],) + for i in range(len(self.mrope_section)): # type: ignore[arg-type] + cos += (cos_l[i][i], ) + sin += (sin_l[i][i], ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) @@ -379,7 +426,8 @@ class MRotaryEmbedding(RotaryEmbedding): st_idx = llm_pos_ids_list[-1].max() + 1 if len( llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + st_idx) + ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + + st_idx) t_index = (ops.arange(llm_grid_t).view(-1, 1).broadcast_to( (-1, llm_grid_h * llm_grid_w)) * video_second_per_grid_t * @@ -388,7 +436,7 @@ class MRotaryEmbedding(RotaryEmbedding): (llm_grid_t, -1, llm_grid_w)).flatten().int() w_index = ops.arange(llm_grid_w).view(1, 1, -1).broadcast_to( (llm_grid_t, llm_grid_h, -1)).flatten().int() - + 
llm_pos_ids_list.append( ops.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w @@ -398,7 +446,8 @@ class MRotaryEmbedding(RotaryEmbedding): llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + st_idx) + ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + + st_idx) llm_positions = ops.cat(llm_pos_ids_list, axis=1).view(3, -1) mrope_position_delta = (llm_positions.max() + 1 - @@ -457,7 +506,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): self.rotary_dim = rotary_dim self.max_position_embeddings = max_position_embeddings self.base = base - self.is_neox_style = is_neox_style + self.is_neox_style = is_neox_style # type: ignore[assignment] self.dtype = dtype super().__init__(head_size, rotary_dim, self.cache_max_position_num, base, is_neox_style, dtype) @@ -466,7 +515,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 - def construct( + def construct( # type: ignore[override] self, positions: mindspore.Tensor, query: mindspore.Tensor, @@ -486,14 +535,16 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): if is_prefill: num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] + cos, sin = cos[..., :self.rotary_dim // + 2], sin[..., :self.rotary_dim // 2] if positions.ndim == 2: cos_l = ops.split(cos, self.mrope_section, axis=-1) sin_l = ops.split(sin, self.mrope_section, axis=-1) cos, sin = (), () - for i in range(len(self.mrope_section)): - cos += (cos_l[i][i],) - sin += (sin_l[i][i],) + for i in range(len( + self.mrope_section)): # type: ignore[arg-type] + cos += (cos_l[i][i], ) + sin += (sin_l[i][i], ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) @@ -501,7 +552,8 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query_rot = _apply_rotary_emb(query_rot, cos, sin, + self.is_neox_style) query = ops.cat((query_rot, query_pass), axis=-1).view(query_shape) key_shape = key.shape @@ -513,16 +565,18 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): return query, key # decode - if positions.ndim == 2 and positions.shape[0] == len(self.mrope_section): + if positions.ndim == 2 and positions.shape[0] == len( + self.mrope_section): # type: ignore[arg-type] num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2] + cos, sin = cos[..., :self.rotary_dim // + 2], sin[..., :self.rotary_dim // 2] cos_l = ops.split(cos, self.mrope_section, axis=-1) sin_l = ops.split(sin, self.mrope_section, axis=-1) cos, sin = (), () - for i in range(len(self.mrope_section)): - cos += (cos_l[i][i],) - sin += (sin_l[i][i],) + for i in range(len(self.mrope_section)): # type: ignore[arg-type] + cos += (cos_l[i][i], ) + sin += (sin_l[i][i], ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) freqs_cos = ops.cat([cos, cos], axis=-1).squeeze(1) @@ -532,10 +586,11 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): freqs_cos = self.freqs_cos.index_select(0, positions) freqs_sin = self.freqs_sin.index_select(0, positions) - return 
self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) + return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, + batch_valid_length) -_ROPE_DICT: Dict[Tuple, InferRotaryEmbedding] = {} +_ROPE_DICT: Dict[Tuple, Union[InferRotaryEmbedding, RotaryEmbedding]] = {} def get_rope( @@ -547,7 +602,7 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[Any] = None, partial_rotary_factor: float = 1.0, -) -> InferRotaryEmbedding: +): if dtype is None: dtype = get_current_vllm_config().model_config.dtype @@ -581,7 +636,15 @@ def get_rope( scaling_type = rope_scaling["rope_type"] if scaling_type == "llama3": - raise NotImplementedError + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = InferLlama3RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + dtype, scaling_factor, low_freq_factor, high_freq_factor, + original_max_position) elif scaling_type == "default": if "mrope_section" in rope_scaling: rotary_emb = InferMRotaryEmbedding( @@ -598,5 +661,5 @@ def get_rope( else: raise NotImplementedError - _ROPE_DICT[key] = rotary_emb + _ROPE_DICT[key] = rotary_emb # type: ignore[assignment] return rotary_emb diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 6e760aa57cb1d1ded3c687e893bca76475ad7f0e..768a8238f4d73bcecb2d7cd1a453d6de81e07e09 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -19,16 +19,15 @@ from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple from mindspore import Parameter, Tensor, mint, nn, ops -from mindspore.common import dtype as mstype from mindspore.common.dtype import typing +from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.quantization.base_config import \ - QuantizationConfig -from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) -from vllm_mindspore.distributed.communication_op import \ - ReduceFromModelParallelRegion +from vllm_mindspore.distributed.communication_op import ( + ReduceFromModelParallelRegion) from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.utils import set_weight_attrs @@ -408,7 +407,6 @@ class ParallelLMHead(VocabParallelEmbedding): }, ) else: - # self.register_parameter("bias", None) self.bias = None def tie_weights(self, embed_tokens: VocabParallelEmbedding): @@ -417,8 +415,7 @@ class ParallelLMHead(VocabParallelEmbedding): if self.quant_config and self.quant_config.get_name() == "gguf": return embed_tokens else: - # self.weight = embed_tokens.weight - self.weight.set_data(embed_tokens.weight) + self.weight = embed_tokens.weight return self def forward(self, input_): diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py index 354bfb37aa1017f81e16ec433a9f87effc9a6ad6..954579f11f59f15ff46f15ad0a1a37049b2142b7 100644 --- a/vllm_mindspore/model_executor/models/llama.py +++ 
b/vllm_mindspore/model_executor/models/llama.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -16,49 +15,36 @@ # limitations under the License. # ============================================================================ -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, + Tuple, Type, Union) if TYPE_CHECKING: from transformers import LlamaConfig else: LlamaConfig = None +from mindspore import Tensor, mint, nn from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors -from vllm_mindspore.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SiluAndMul -from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm_mindspore.model_executor.models.utils import ( - PPMissingLayer, - extract_layer_index, - make_layers, - maybe_prefix, - make_empty_intermediate_tensors_factory, -) -from vllm_mindspore.model_executor.layers.sampler import get_sampler, SamplerOutput from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import ( + LogitsProcessor) from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.sampling_metadata import SamplingMetadata - -from vllm_mindspore.model_executor.models.model_base import MsModelBase - -from vllm.sequence import IntermediateTensors -from vllm.attention import AttentionMetadata -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name - -from mindspore import Tensor, mint, jit, nn -from mindspore import dtype as mstype +from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, + get_sampler) +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.models.model_base import NativeModel +from vllm_mindspore.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) def default_weight_loader(param, loaded_weight) -> None: @@ -66,6 +52,7 @@ def default_weight_loader(param, loaded_weight) -> None: class LlamaMLP(nn.Cell): + def __init__( self, hidden_size: int, @@ -91,13 +78,10 @@ class LlamaMLP(nn.Cell): prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now." - ) + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") self.act_fn = SiluAndMul() - @jit def construct(self, x): x, _ = self.gate_up_proj(x) x = self.act_fn(x) @@ -106,6 +90,7 @@ class LlamaMLP(nn.Cell): class LlamaAttention(nn.Cell): + def __init__( self, config: LlamaConfig, @@ -139,9 +124,8 @@ class LlamaAttention(nn.Cell): assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr( - config, "head_dim", self.hidden_size // self.total_num_heads - ) + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) # Phi models introduced a partial_rotary_factor parameter in the config partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.rotary_dim = int(partial_rotary_factor * self.head_dim) @@ -177,7 +161,7 @@ class LlamaAttention(nn.Cell): self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, - base=rope_theta, + base=rope_theta, # type: ignore[arg-type] rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) @@ -190,7 +174,8 @@ class LlamaAttention(nn.Cell): sw_idx = layer_idx % len(interleaved_sliding_window) sliding_window = interleaved_sliding_window[sw_idx] else: - raise ValueError(f"{type(interleaved_sliding_window)} is not supported.") + raise ValueError( + f"{type(interleaved_sliding_window)} is not supported.") else: sliding_window = None @@ -204,32 +189,33 @@ class LlamaAttention(nn.Cell): per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", ) - self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.float16), 1) * -10000.0 - @jit def construct( self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens) - attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens, - slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask) + q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), + -1) + q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill) + attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, + slot_mapping, attn_mask, batch_valid_length, + q_seq_lens, block_tables) output, _ = self.o_proj(attn_output) return output class LlamaDecoderLayer(nn.Cell): + def __init__( self, config: LlamaConfig, @@ -242,17 +228,15 @@ class LlamaDecoderLayer(nn.Cell): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None - ): + config, "original_max_position_embeddings", None): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings - ) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) # Support abacusai/Smaug-72B-v0.1 
with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False - ) + config, "bias", False) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias if hasattr(config, 'qkv_bias'): @@ -262,9 +246,8 @@ class LlamaDecoderLayer(nn.Cell): config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -282,23 +265,22 @@ class LlamaDecoderLayer(nn.Cell): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) - @jit def construct( self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], ) -> Tuple[Tensor, Tensor]: @@ -307,22 +289,17 @@ class LlamaDecoderLayer(nn.Cell): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - - hidden_states = self.self_attn( - positions, - hidden_states, - kv_cache, - num_prefill_tokens, - num_decode_tokens, - slot_mapping, - batch_valid_length, - context_lens, - block_tables - ) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions, hidden_states, key_cache, + value_cache, is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -339,18 +316,18 @@ class LlamaModel(nn.Cell): layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, ): super().__init__() - config = vllm_config + config = vllm_config.model_config.hf_config self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size - # TODO: Support quant_config cache_config - quant_config = None - cache_config = None + quant_config = vllm_config.quant_config + self.quant_config = quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config # noqa: F841 - if get_pp_group().is_first_rank or ( - config.tie_word_embeddings and get_pp_group().is_last_rank - ): + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -377,24 +354,22 @@ class LlamaModel(nn.Cell): self.norm = PPMissingLayer() 
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) + ["hidden_states", "residual"], config.hidden_size) def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.embed_tokens(input_ids) - @jit def construct( self, input_ids: Optional[Tensor], positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_caches: List[Tensor], + value_caches: List[Tensor], + is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, + attn_mask: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, @@ -410,25 +385,20 @@ class LlamaModel(nn.Cell): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): # PP 并行对层进行切分 + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - num_prefill_tokens, - num_decode_tokens, - slot_mapping, - batch_valid_length, - context_lens, - block_tables, - residual - ) + hidden_states, residual = layer(positions, hidden_states, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + is_prefill, slot_mapping, + attn_mask, batch_valid_length, + q_seq_lens, block_tables, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -465,21 +435,21 @@ class LlamaModel(nn.Cell): else: if name in params_dict: param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class LlamaForCausalLM(MsModelBase, SupportsPP): +class LlamaForCausalLM(NativeModel, SupportsPP): + def __init__(self, vllm_config, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) quant_config = vllm_config.quant_config - self.model = LlamaModel(vllm_config=self.config) + self.model = LlamaModel(vllm_config=vllm_config) if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.config.vocab_size @@ -495,68 +465,47 @@ class LlamaForCausalLM(MsModelBase, SupportsPP): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not self.lora_config - else self.lora_config.lora_vocab_padding_size - ), + if not self.lora_config else + self.lora_config.lora_vocab_padding_size), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) - # if self.config.tie_word_embeddings: - # self.lm_head = self.lm_head.tie_weights( - # self.model.embed_tokens) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size, + logit_scale) 
self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - self.set_modules({"model": self.model, "lm_head": self.lm_head}) + self.model.make_empty_intermediate_tensors) - self.set_model_inputs() + self.common_preprocess(vllm_config, prefix) - def tie_lmhead_weights(self): - self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - - def forward( - self, - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors=None, - inputs_embeds=None, - **kwargs - ): - if attn_metadata.num_prefill_tokens > 0: - input_ids = input_ids.expand_dims(0) - if attn_metadata.num_decode_tokens > 0: - input_ids = input_ids.expand_dims(1) - model_output = self.model(input_ids, - positions, - kv_caches, - **dict(attn_metadata), - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - if attn_metadata.num_prefill_tokens > 0: - model_output = model_output.squeeze(0) - if attn_metadata.num_decode_tokens > 0: - model_output = model_output.squeeze(1) - return model_output + def forward(self, + input_ids, + positions, + intermediate_tensors=None, + inputs_embeds=None, + **kwargs): + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: params_dict = self.get_params_dict() - self.model.load_weights(weights, params_dict) + load_params = self.model.load_weights(weights, params_dict) + if self.config.tie_word_embeddings: + load_params.add("lm_head.weight") + return load_params - def sample( - self, logits: Tensor, sampling_metadata: SamplingMetadata - ) -> Optional[SamplerOutput]: + def sample(self, logits: Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -565,5 +514,6 @@ class LlamaForCausalLM(MsModelBase, SupportsPP): hidden_states: Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - return logits \ No newline at end of file + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits
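
Below the patch, for reference only: a minimal standalone sketch (NumPy only) of the "llama3" rope_scaling computation that InferLlama3RotaryEmbedding._compute_inv_freq introduces above. The default values used here (base=500000.0, rotary_dim=128, factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192) are typical Llama 3.1 settings chosen for illustration; they are assumptions, not values taken from this patch.

import math

import numpy as np


def llama3_scaled_inv_freq(base: float = 500000.0,
                           rotary_dim: int = 128,
                           scaling_factor: float = 8.0,
                           low_freq_factor: float = 1.0,
                           high_freq_factor: float = 4.0,
                           orig_max_position: int = 8192) -> np.ndarray:
    # Base RoPE inverse frequencies, one per pair of rotary dimensions.
    freqs_base = np.arange(0, rotary_dim, 2).astype(np.float32)
    inv_freqs = 1.0 / (base**(freqs_base / rotary_dim))

    # Wavelength thresholds derived from the original context length.
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    wave_len = 2 * math.pi / inv_freqs

    # Linear blend factor for the band between the two thresholds.
    if low_freq_factor != high_freq_factor:
        smooth = (orig_max_position / wave_len -
                  low_freq_factor) / (high_freq_factor - low_freq_factor)
    else:
        smooth = 0

    # Keep high-frequency components, fully scale low-frequency ones,
    # and blend the band in between.
    return np.where(
        wave_len < high_freq_wavelen,
        inv_freqs,
        np.where(wave_len > low_freq_wavelen,
                 inv_freqs / scaling_factor,
                 (1 - smooth) * inv_freqs / scaling_factor +
                 smooth * inv_freqs),
    )


print(llama3_scaled_inv_freq()[:4])

Frequencies whose wavelength is shorter than high_freq_wavelen are left untouched, those longer than low_freq_wavelen are divided by the scaling factor, and the intermediate band is linearly interpolated, which is what the nested np.where in the patch expresses.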