From c7f3d39b4b6fbd744765f774147b6028c13b9dde Mon Sep 17 00:00:00 2001
From: zm
Date: Thu, 16 Oct 2025 21:13:37 +0800
Subject: [PATCH] add weight format adjustment for Llama on 310p

---
 vllm_mindspore/model_executor/models/llama.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index ad85b022..ca14670c 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
 else:
     LlamaConfig = None
 
+import mindspore as ms
 from mindspore import Tensor, mint, nn
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.models.interfaces import SupportsPP
@@ -68,7 +69,7 @@ from vllm_mindspore.model_executor.models.model_base import NativeModel
 from vllm_mindspore.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
-
+from vllm_mindspore.utils import is_310p, FORMAT_TYPE
 
 class LlamaMLP(nn.Cell):
@@ -458,6 +459,30 @@ class LlamaModel(nn.Cell):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
+        def adjust_weight(params_dict):
+            if not is_310p():
+                return
+
+            target_keywords = [
+                "qkv_proj.weight",
+                "o_proj.weight",
+                "gate_up_proj.weight",
+                "down_proj.weight",
+                # "lm_head.weight",
+            ]
+
+            for name, param in params_dict.items():
+                if any(name.endswith(keyword) for keyword in target_keywords):
+                    cast_weight = ms.ops.auto_generate.format_cast(
+                        param, FORMAT_TYPE['nz'])
+                    ms.runtime.synchronize()
+                    param.set_data(cast_weight)
+
+        if is_310p():
+            ms.runtime.synchronize()
+            adjust_weight(params_dict)
+            ms.runtime.synchronize()
+
         return loaded_params
-- 
Gitee
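
Reviewer sketch (not part of the patch): the selection logic in adjust_weight() is a
plain suffix match over parameter names, so it can be sanity-checked without Ascend
hardware. The snippet below uses made-up parameter names for illustration and leaves
out the actual ms.ops.auto_generate.format_cast / FORMAT_TYPE['nz'] cast and the
runtime synchronization, which only run on a 310P device.

    # Hypothetical parameter names; only the suffix-matching step from adjust_weight().
    target_keywords = [
        "qkv_proj.weight",
        "o_proj.weight",
        "gate_up_proj.weight",
        "down_proj.weight",
    ]

    params = [
        "model.layers.0.self_attn.qkv_proj.weight",  # matched -> would be cast to NZ
        "model.layers.0.self_attn.o_proj.weight",    # matched
        "model.layers.0.mlp.gate_up_proj.weight",    # matched
        "model.layers.0.mlp.down_proj.weight",       # matched
        "model.embed_tokens.weight",                 # not matched -> keeps default format
        "lm_head.weight",                            # not matched (kept commented out in the patch)
    ]

    for name in params:
        if any(name.endswith(keyword) for keyword in target_keywords):
            print("cast to NZ:", name)
        else:
            print("keep default format:", name)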