From 29fb8eda832297242db5a73929aeec4642a0c8bf Mon Sep 17 00:00:00 2001
From: Yule100
Date: Wed, 24 Dec 2025 10:23:01 +0800
Subject: [PATCH] bugfix: fix the error raised when ds does not load weights
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../inference/quantization/golden_stick/a8w8.py      |  2 ++
 .../inference/transformer/multi_latent_attention.py  | 12 +++++++++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
index 6cae4e9bd..d39993d51 100644
--- a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
+++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
@@ -32,6 +32,7 @@ class A8W8LinearMethod(LinearMethodBase):
         self.bias_add = ops.Add()
         self.is_modelslim = self.quant_config.is_modelslim
         self.is_ms_custom_ops = False
+        # pylint: disable=C0415
         try:
             import ms_custom_ops
             self.is_ms_custom_ops = True
@@ -108,6 +109,7 @@ class A8W8LinearMethod(LinearMethodBase):
         This can be used for example, to transpose weights for computation.
         """
         if self.is_ms_custom_ops:
+            layer.weight.init_data()
             layer.weight = self.ms_custom_ops.trans_data(layer.weight, transdata_type=1)
         if not self.is_modelslim:
             return
diff --git a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
index bb704a603..d9714e43d 100644
--- a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
+++ b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
@@ -28,7 +28,7 @@ import numpy as np
 
 from mindspore import mint, Tensor, dtype, Parameter, ops
 from mindspore.ops import operations as P
-from mindspore.common.initializer import Zero
+from mindspore.common.initializer import Zero, Normal
 from mindspore.ops.operations._infer_ops import QuantV2
 import mindspore as ms
 
@@ -358,11 +358,21 @@ class MLASelfAttention(MultiLatentAttention):
             eps=self.config.layernorm_epsilon
         )
 
+        self.q_absorb = Tensor(shape=(self.num_attention_heads_per_partition,
+                                      self.config.qk_head_dim, self.config.kv_lora_rank),
+                               dtype=self.config.compute_dtype,
+                               init=Normal(sigma=1.0))
+        self.out_absorb = Tensor(shape=(self.num_attention_heads_per_partition,
+                                        self.config.v_head_dim, self.config.kv_lora_rank),
+                                 dtype=self.config.compute_dtype,
+                                 init=Normal(sigma=1.0))
+
     def process_weights_after_loading(self) -> None:
         """
         Process the weight after loading.
         This can be used for example, to transpose weights for computation.
         """
+        self.linear_kv_up_proj.weight.init_data()
         q_absorb, out_absorb = ops.function.array_func.split_ext(self.linear_kv_up_proj.weight,
                                                                  [self.num_attention_heads_per_partition * self.config.qk_head_dim,
                                                                   self.num_attention_heads_per_partition * self.config.v_head_dim], -2)
--
Gitee
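
Illustrative sketch of the pattern both hunks rely on (not part of the patch; names and shapes below are hypothetical): in MindSpore, a Parameter created from an initializer is materialized lazily, so when no checkpoint populates it (the "ds does not load weights" scenario in the subject) its data may not exist yet and transforming it can fail. Calling init_data() first materializes the tensor, which is what the patch does before trans_data and split_ext.

    # Standalone sketch, assuming only that MindSpore is installed.
    # The parameter name and shape are made up for illustration.
    import mindspore as ms
    from mindspore import Parameter, ops
    from mindspore.common.initializer import initializer, Normal

    # A lazily-initialized weight, standing in for layer.weight in the patch.
    weight = Parameter(initializer(Normal(sigma=1.0), (4, 8), ms.float16), name="weight")

    weight.init_data()                           # materialize the data before touching it
    transposed = ops.transpose(weight, (1, 0))   # now safe to transform the weight
    print(transposed.shape)                      # (8, 4)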