diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
index 6cae4e9bd2e264a9565056a7bf6cbebcb726d060..d39993d51e175491b1fc546bb9438f40d39e09c5 100644
--- a/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
+++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8w8.py
@@ -32,6 +32,7 @@ class A8W8LinearMethod(LinearMethodBase):
         self.bias_add = ops.Add()
         self.is_modelslim = self.quant_config.is_modelslim
         self.is_ms_custom_ops = False
+        # pylint: disable=C0415
         try:
             import ms_custom_ops
             self.is_ms_custom_ops = True
@@ -108,6 +109,7 @@ class A8W8LinearMethod(LinearMethodBase):
         This can be used for example, to transpose weights for computation.
         """
         if self.is_ms_custom_ops:
+            layer.weight.init_data()
             layer.weight = self.ms_custom_ops.trans_data(layer.weight, transdata_type=1)
         if not self.is_modelslim:
             return
diff --git a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
index bb704a6037ed577df653d6eb3d930b1a95ead806..d9714e43d1f3a76072c0688bbc5e2f67d8b3f37e 100644
--- a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
+++ b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py
@@ -28,7 +28,7 @@ import numpy as np
 
 from mindspore import mint, Tensor, dtype, Parameter, ops
 from mindspore.ops import operations as P
-from mindspore.common.initializer import Zero
+from mindspore.common.initializer import Zero, Normal
 from mindspore.ops.operations._infer_ops import QuantV2
 import mindspore as ms
 
@@ -358,11 +358,21 @@ class MLASelfAttention(MultiLatentAttention):
                 eps=self.config.layernorm_epsilon
             )
 
+        self.q_absorb = Tensor(shape=(self.num_attention_heads_per_partition,
+                                      self.config.qk_head_dim, self.config.kv_lora_rank),
+                               dtype=self.config.compute_dtype,
+                               init=Normal(sigma=1.0))
+        self.out_absorb = Tensor(shape=(self.num_attention_heads_per_partition,
+                                        self.config.v_head_dim, self.config.kv_lora_rank),
+                                 dtype=self.config.compute_dtype,
+                                 init=Normal(sigma=1.0))
+
     def process_weights_after_loading(self) -> None:
         """
         Process the weight after loading.
         This can be used for example, to transpose weights for computation.
         """
+        self.linear_kv_up_proj.weight.init_data()
         q_absorb, out_absorb = ops.function.array_func.split_ext(self.linear_kv_up_proj.weight,
                                                                  [self.num_attention_heads_per_partition * self.config.qk_head_dim,
                                                                   self.num_attention_heads_per_partition * self.config.v_head_dim],
                                                                  -2)
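
Note (not part of the patch): a minimal sketch of the weight split that process_weights_after_loading performs. The up-projection weight of linear_kv_up_proj is cut along its second-to-last axis into the q_absorb and out_absorb pieces whose shapes match the Tensor placeholders the patch adds. NumPy stands in for the MindSpore ops here, and the concrete head/rank sizes are assumed values chosen only for illustration; the init_data() calls added in both hunks presumably materialize lazily initialized parameters before they are split or transposed.

    # Hedged illustration only: NumPy in place of MindSpore; all sizes are assumptions.
    import numpy as np

    num_heads = 8          # num_attention_heads_per_partition (assumed value)
    qk_head_dim = 128      # config.qk_head_dim (assumed value)
    v_head_dim = 128       # config.v_head_dim (assumed value)
    kv_lora_rank = 512     # config.kv_lora_rank (assumed value)

    # Stand-in for linear_kv_up_proj.weight: rows cover both the q/k and v head dims.
    kv_up_weight = np.random.normal(
        size=(num_heads * (qk_head_dim + v_head_dim), kv_lora_rank)).astype(np.float16)

    # Same split sizes as the split_ext call in the patch, along axis -2.
    q_absorb_flat, out_absorb_flat = np.split(
        kv_up_weight, [num_heads * qk_head_dim], axis=-2)

    # Per-head layout matching the self.q_absorb / self.out_absorb placeholders.
    q_absorb = q_absorb_flat.reshape(num_heads, qk_head_dim, kv_lora_rank)
    out_absorb = out_absorb_flat.reshape(num_heads, v_head_dim, kv_lora_rank)

    assert q_absorb.shape == (num_heads, qk_head_dim, kv_lora_rank)
    assert out_absorb.shape == (num_heads, v_head_dim, kv_lora_rank)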