From a06e313c01fa36f08437e91666fc776be880593b Mon Sep 17 00:00:00 2001 From: qinsichun Date: Fri, 5 Dec 2025 15:55:41 +0800 Subject: [PATCH] add_quant --- .../models/qwen3/modeling_qwen3_infer.py | 11 + .../quantization/golden_stick/a8dynw4.py | 22 +- .../test_gpt_weight_consistency.py | 2 +- .../quantization_gemm/__init__.py | 15 + .../quantization_gemm/gpt_model_for_test.py | 280 ++++++++++++++++++ .../quantization_gemm/numpy_quantizer.py | 224 ++++++++++++++ .../quantization_gemm/run_parallel_linear.py | 242 +++++++++++++++ .../quantization_gemm/test_parallel_linear.py | 142 +++++++++ .../test_tools/test_register/test_config.py | 3 +- 9 files changed, 920 insertions(+), 21 deletions(-) create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/__init__.py create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/gpt_model_for_test.py create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py create mode 100644 tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/test_parallel_linear.py diff --git a/mindformers/models/qwen3/modeling_qwen3_infer.py b/mindformers/models/qwen3/modeling_qwen3_infer.py index 3da049690..14eec17a3 100644 --- a/mindformers/models/qwen3/modeling_qwen3_infer.py +++ b/mindformers/models/qwen3/modeling_qwen3_infer.py @@ -118,3 +118,14 @@ class InferenceQwen3ForCausalLM(Qwen3PreTrainedModel, InferModelMixin): value_cache=value_cache ) return logits + + def convert_name(self, weight_name): + r""" + Override convert_name method in inference model, in order to read PTQ weights correctly. + PTQ weights are generated after training, so it should only exist in inference model. 
+ """ + weight_name = super().convert_name(weight_name) + # Do extra conversion for quantization parameters. + if self.config.quantization is not None: + weight_name = weight_name.replace('.weight_scale', '.w_scale') + return weight_name diff --git a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py index d8ab308cb..65533d6ab 100644 --- a/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py +++ b/mindformers/parallel_core/inference/quantization/golden_stick/a8dynw4.py @@ -20,7 +20,7 @@ import numpy as np import mindspore from mindspore import nn, Parameter, ops, mint from mindspore.common.initializer import initializer -from mindspore.ops.auto_generate import WeightQuantBatchMatmul, DynamicQuantExt, GroupedMatmulV4 +from mindspore.ops.auto_generate import DynamicQuantExt, GroupedMatmulV4 from mindformers.parallel_core.inference.weights_utils import set_weight_attrs from mindformers.parallel_core.inference.transformer.moe.experts import GroupedMLP @@ -77,20 +77,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase): layer.insert_param_to_cell("gmm_bias", gmm_bias) else: - self.matmul = WeightQuantBatchMatmul(False, True, group_size) - weight_shape = (self.output_size_per_partition, self.input_size_per_partition) - weight = Parameter(initializer('ones', weight_shape, mindspore.int8), requires_grad=False) - - w_scale_shape = (output_size_per_partition,) - w_scale_dtype = mindspore.bfloat16 if params_dtype == mindspore.bfloat16 else mindspore.float32 - w_scale = Parameter( - initializer('ones', w_scale_shape, w_scale_dtype), name="w_scale", requires_grad=False) - - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - set_weight_attrs(w_scale, {"output_dim": 0}) - - set_weight_attrs(weight, extra_weight_attrs) - set_weight_attrs(w_scale, extra_weight_attrs) + raise ValueError("A8W4DynamicQuant is now only support for MOE") if layer is not None: 
layer.insert_param_to_cell("weight", weight) @@ -143,10 +130,7 @@ class A8W4DynamicLinearMethod(LinearMethodBase): group_type=0, group_list_type=1)[0] else: - w_scale = ops.cast(w_scale, mindspore.float16) - qx = ops.cast(qx, mindspore.float16) - out = self.matmul(qx, weight, w_scale, None, None, None, None) - out = ops.mul(out, qx_scale.unsqueeze(1)) + raise ValueError("A8W4DynamicQuant is now only support for MOE") if bias is not None: out = self.bias_add(out, bias) out = out.reshape(output_shape) diff --git a/tests/st/test_ut/test_models/test_gpt_weight_consistency.py b/tests/st/test_ut/test_models/test_gpt_weight_consistency.py index 12beb5cf7..379144a9f 100644 --- a/tests/st/test_ut/test_models/test_gpt_weight_consistency.py +++ b/tests/st/test_ut/test_models/test_gpt_weight_consistency.py @@ -38,7 +38,7 @@ from mindformers.parallel_core.process_group_config import ModelCommProcessGroup class DummyGoldenStickConfig(GoldenStickConfig): def get_quant_method(self, layer, prefix): - if "experts" in prefix: + if ".experts" in prefix: return A8W4DynamicLinearMethod(self) if "self_attn" in prefix: return A8W8LinearMethod(self) diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/__init__.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/__init__.py new file mode 100644 index 000000000..562053661 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test column parallel linear""" diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/gpt_model_for_test.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/gpt_model_for_test.py new file mode 100644 index 000000000..8da9d705c --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/gpt_model_for_test.py @@ -0,0 +1,280 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""a model designed for test.""" + + +from functools import partial +import numpy as np +import mindspore as ms +from mindformers.parallel_core.inference.tensor_parallel.layers import (ColumnParallelLinear, + RowParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear + ) +from mindformers.parallel_core.inference.tensor_parallel.grouped_layers import ColumnParallelGroupedLinear +from mindformers.parallel_core.transformer_config import TransformerConfig +from mindformers.models.configuration_utils import PretrainedConfig +from mindformers.parallel_core.inference.quantization.utils import get_quant_config +from mindformers.parallel_core.inference.weights_utils import set_weight_attrs +from mindformers.parallel_core.inference.tensor_parallel.grouped_layers import UnquantizedGroupedLinearMethod +from mindformers.parallel_core.inference.quantization.base_config import QuantizeMethodBase + +class LinearSpec: + """Specification for linear layers in the model.""" + + def __init__(self, linear_type, input_size, output_size, has_bias, compute_dtype, quant_type): + if isinstance(compute_dtype, str): + compute_dtype = self.convert_pt_dtype_to_ms(compute_dtype) + if compute_dtype not in [ms.dtype.float32, ms.dtype.float16, ms.dtype.bfloat16]: + raise ValueError(f"Unsupported compute_dtype: {compute_dtype}") + self.linear_type = linear_type + self.input_size = input_size + self.output_size = output_size + self.has_bias = has_bias + self.skip_bias_add = False + self.compute_dtype = compute_dtype + self.transpose_b=True + self.quant_type = quant_type + + def name(self): + return f"{self.linear_type}-has_bias_{self.has_bias}-" \ + f"compute_dtype_{self.compute_dtype}-quant_type_{self.quant_type}" + + @staticmethod + def convert_pt_dtype_to_ms(pt_dtype: str): + """Convert PyTorch dtype to MindSpore dtype.""" + dtype_mapping = { + 'fp32': ms.dtype.float32, + 'fp16': 
ms.dtype.float16, + 'bf16': ms.dtype.bfloat16, + } + mstype = dtype_mapping.get(pt_dtype, None) + if mstype is None: + raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}") + return mstype + +class QKVLinearSpec: + """Specification for linear layers in the model.""" + def __init__(self, linear_type, hidden_size, head_size, total_num_heads, total_num_kv_heads, + has_bias, compute_dtype, quant_type): + if isinstance(compute_dtype, str): + compute_dtype = self.convert_pt_dtype_to_ms(compute_dtype) + if compute_dtype not in [ms.dtype.float32, ms.dtype.float16, ms.dtype.bfloat16]: + raise ValueError(f"Unsupported compute_dtype: {compute_dtype}") + self.linear_type = linear_type + self.input_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + self.output_size = ( + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size + ) + self.output_sizes = [ + self.total_num_heads * self.head_size, # q_proj + self.total_num_kv_heads * self.head_size, # k_proj + self.total_num_kv_heads * self.head_size, # v_proj + ] + self.has_bias = has_bias + self.skip_bias_add = False + self.compute_dtype = compute_dtype + self.transpose_b=True + self.quant_type = quant_type + + def name(self): + return f"{self.linear_type}-has_bias_{self.has_bias}-" \ + f"compute_dtype_{self.compute_dtype}-quant_type_{self.quant_type}" + + @staticmethod + def convert_pt_dtype_to_ms(pt_dtype: str): + """Convert PyTorch dtype to MindSpore dtype.""" + dtype_mapping = { + 'fp32': ms.dtype.float32, + 'fp16': ms.dtype.float16, + 'bf16': ms.dtype.bfloat16, + } + mstype = dtype_mapping.get(pt_dtype, None) + if mstype is None: + raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}") + return mstype + +class GroupLinearSpec: + """Specification for linear layers in the model.""" + + def __init__(self,linear_type, num_local_experts,input_size, output_size, quant_type): + self.linear_type = linear_type + 
self.num_local_experts = num_local_experts + self.input_size = input_size + self.output_size = output_size + self.has_bias = None + self.skip_bias_add = False + + self.quant_type = quant_type + + def name(self): + return f"{self.linear_type}-has_bias_{self.has_bias}-" \ + f"quant_type_{self.quant_type}" + + @staticmethod + def convert_pt_dtype_to_ms(pt_dtype: str): + """Convert PyTorch dtype to MindSpore dtype.""" + dtype_mapping = { + 'fp32': ms.dtype.float32, + 'fp16': ms.dtype.float16, + 'bf16': ms.dtype.bfloat16, + } + mstype = dtype_mapping.get(pt_dtype, None) + if mstype is None: + raise ValueError(f"Unsupported pytorch dtype: {pt_dtype}") + return mstype + +class ModelSpec: + def __init__(self, compute_dtype, param_init_dtype, tensor_parallel, linear_specs): + self.linear_specs = linear_specs + self.compute_dtype = compute_dtype + self.param_init_dtype = param_init_dtype + self.tensor_parallel = tensor_parallel + + +class TestPretrainedConfig(PretrainedConfig): + def __init__(self, quantization, pretrained_model_dir): + super().__init__( + quantization=quantization, + pretrained_model_dir=pretrained_model_dir, + ) + + +class GPTModelForTest(ms.nn.Cell): + """A model designed for testing parallel linear operations.""" + + def __init__(self, model_spec, comm_pgs, quantization: str, quant_model_dir=None): + super().__init__() + self.model_spec = model_spec + if quant_model_dir is None: + quant_config = None + else: + quant_config = get_quant_config(TestPretrainedConfig(quantization, quant_model_dir), []) + transformer_config = TransformerConfig( + tensor_model_parallel_size=model_spec.tensor_parallel, + compute_dtype=model_spec.compute_dtype, + params_dtype=model_spec.param_init_dtype, + num_layers=1, + num_attention_heads=model_spec.tensor_parallel, + ) + self.linears = GPTModelForTest._build_linears(comm_pgs, model_spec, transformer_config, quant_config) + self.num_linears = len(self.linears) + + @staticmethod + def _build_linears(comm_pgs, model_spec, 
transformer_config, quant_config): + """Build a list of linear layers based on the model specifications.""" + linear_map = { + "ColumnParallelLinear": partial(ColumnParallelLinear, gather_output=True), + "ColumnParallelGroupedLinear": partial(ColumnParallelGroupedLinear, gather_output=False), + "MergedColumnParallelLinear": MergedColumnParallelLinear, + "QKVParallelLinear": QKVParallelLinear, + "RowParallelLinear": RowParallelLinear, + "ReplicatedLinear": ReplicatedLinear, + } + linears = [] + for index, linear_spec in enumerate(model_spec.linear_specs): + if linear_spec.linear_type=="QKVParallelLinear": + linear = linear_map[linear_spec.linear_type]( + hidden_size=linear_spec.input_size, + head_size=linear_spec.head_size, + total_num_heads=linear_spec.total_num_heads, + total_num_kv_heads=linear_spec.total_num_kv_heads, + config=transformer_config, + compute_dtype=linear_spec.compute_dtype, + transpose_b=linear_spec.transpose_b, + bias=linear_spec.has_bias, + tp_group=comm_pgs.tp, + quant_config=quant_config, + prefix=f"linears.{index}" + ) + elif linear_spec.linear_type=="ColumnParallelGroupedLinear": + if quant_config is None: + quant_method: Optional[QuantizeMethodBase] = UnquantizedGroupedLinearMethod() + weight = quant_method.create_weights( + layer=None, + num_local_experts=linear_spec.num_local_experts, + input_size_per_partition=linear_spec.input_size, + output_partition_sizes=[linear_spec.output_size], + params_dtype=ms.bfloat16 + ) + else: + quant_method = quant_config.get_quant_method(quant_config, f"linears.{index}") + weight = quant_method.create_weights( + layer=None, + num_local_experts=linear_spec.num_local_experts, + input_size_per_partition=linear_spec.input_size, + output_partition_sizes=[linear_spec.output_size], + params_dtype="bf16" + ) + linear = linear_map[linear_spec.linear_type]( + num_local_experts=linear_spec.num_local_experts, + input_size=linear_spec.input_size, + output_size=linear_spec.output_size, + config=transformer_config, + 
weight=weight, + bias=linear_spec.has_bias, + tp_group=comm_pgs.tp, + quant_config=quant_config, + prefix=f"linears.{index}" + ) + set_weight_attrs(weight, {"weight_loader": linear.weight_loader}) + else: + linear = linear_map[linear_spec.linear_type]( + input_size=linear_spec.input_size, + output_size=linear_spec.output_size, + config=transformer_config, + skip_bias_add=linear_spec.skip_bias_add, + compute_dtype=linear_spec.compute_dtype, + transpose_b=linear_spec.transpose_b, + bias=linear_spec.has_bias, + tp_group=comm_pgs.tp, + quant_config=quant_config, + prefix=f"linears.{index}" + ) + linears.append(linear) + return ms.nn.SequentialCell(linears) + + def forward(self, x): + """Forward pass through the model, processing input through all linear layers.""" + output = self.construct(x).astype(ms.dtype.float32).asnumpy() + bs = output.shape[0] + if bs != self.num_linears: + raise ValueError(f"outputs size must be equal to the number of linears: {bs} != {self.num_linears}") + outputs = np.split(output, bs, axis=0) + output_dict = {} + for index, linear_spec in enumerate(self.model_spec.linear_specs): + name = f"index_{index}-{linear_spec.name()}" + output_dict[name] = outputs[index].squeeze(axis=0) + return output_dict + + def construct(self, x): + """Forward pass through one layer.""" + y = ms.ops.zeros_like(x) + y = y.expand_dims(axis=0) + for index in range(self.num_linears): + linear = self.linears[index] + if isinstance(linear, ColumnParallelGroupedLinear): + group_list = np.random.multinomial(x.shape[0], + np.ones(linear.num_local_experts)/linear.num_local_experts) + group_list = ms.Tensor(group_list) + z = linear(x,group_list=group_list).expand_dims(axis=0) + else: + z = linear(x).expand_dims(axis=0) + y = ms.ops.concat((y, z)) + return y[1:,::] diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py 
b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py new file mode 100644 index 000000000..ab61017d3 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/numpy_quantizer.py @@ -0,0 +1,224 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""NumpyQuantizer for test.""" + + +import json +import os +import numpy as np +from safetensors.numpy import save_file +from gpt_model_for_test import ModelSpec + + +class NumpyQuantizer: + """A class for quantizing model weights using NumPy.""" + + def __init__(self, model_spec: ModelSpec, quant_policy: list[str]): + self.model_spec = model_spec + self.quant_policy = quant_policy + self.description_file_path = None + self.global_group_size = None + + def quant(self, quant_input: np.ndarray, weights, save_dir): + """Quantize the input and weights, save to safetensors and JSON description.""" + quant_weights, quant_desc = self._quant(quant_input, weights) + print(f"quant_weights: {quant_weights.keys()}", flush=True) + print(f"quant_desc: {quant_desc}", flush=True) + save_file(quant_weights, os.path.join(save_dir, 'quant-model-00001-00001.safetensors')) + with open(os.path.join(save_dir, "quantization_description.json"), "w", encoding='utf-8') as f: + json.dump(quant_desc, f, indent=2, 
ensure_ascii=False) + print(f"quantization weights saved to {save_dir}", flush=True) + + def _quant(self, quant_input: np.ndarray, weights): + """Internal method to perform quantization on weights based on policy.""" + quant_weights = {} + quant_desc = {} + for index, (qpolicy, linear_spec) in enumerate(zip(self.quant_policy, self.model_spec.linear_specs)): + if qpolicy == 'a8w8': + weight = weights[f"linears.{index}.weight"] + _, input_scale, input_offset = NumpyQuantizer._act_int8_quant(quant_input) + quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight, transpose_b=linear_spec.transpose_b) + x_zp = input_offset.astype(np.int32) # per-tensor zero-point + quant_bias = -np.sum(x_zp * quant_weight.astype(np.int32), axis=-1).astype(np.int32) + deq_scale = (input_scale.astype(np.float32) * w_scale.astype(np.float32)) + beta = np.zeros(linear_spec.output_size, dtype=np.int32) + quant_weights.update({ + f"linears.{index}.weight": quant_weight, + f"linears.{index}.deq_scale": deq_scale, + f"linears.{index}.input_scale": np.tile(input_scale, linear_spec.output_size), + f"linears.{index}.input_offset": np.tile(input_offset, linear_spec.output_size), + f"linears.{index}.quant_bias": quant_bias, + f"linears.{index}.beta": beta, + }) + quant_desc.update({ + f"linears.{index}.weight": "W8A8", + f"linears.{index}.deq_scale": "W8A8", + f"linears.{index}.input_scale": "W8A8", + f"linears.{index}.input_offset": "W8A8", + f"linears.{index}.quant_bias": "W8A8", + f"linears.{index}.beta": "W8A8", + }) + if linear_spec.has_bias: + quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"] + quant_desc[f"linears.{index}.bias"] = "W8A8" + continue + if qpolicy == 'a8dynw8': + is_grouped = linear_spec.linear_type in ("ColumnParallelGroupedLinear", "RowParallelGroupedLinear") + if not is_grouped: + weight = weights[f"linears.{index}.weight"] + quant_weight, w_scale = NumpyQuantizer._weight_int8_quant(weight, + transpose_b=linear_spec.transpose_b) + 
quant_weights.update({ + f"linears.{index}.weight": quant_weight, + f"linears.{index}.w_scale": w_scale + }) + quant_desc.update({ + f"linears.{index}.weight": "W8A8_DYNAMIC", + f"linears.{index}.w_scale": "W8A8_DYNAMIC", + }) + else: + quant_weight_gate, w_scale_gate = NumpyQuantizer._weight_int8_quant( + weights[f"linears.{index}.gate.weight"], transpose_b=True) + quant_weight_up, w_scale_up = NumpyQuantizer._weight_int8_quant( + weights[f"linears.{index}.up.weight"], transpose_b=True) + quant_weights.update({ + f"linears.{index}.gate.weight": quant_weight_gate, + f"linears.{index}.gate.w_scale": w_scale_gate, + f"linears.{index}.up.weight": quant_weight_up, + f"linears.{index}.up.w_scale": w_scale_up, + }) + quant_desc.update({ + f"linears.{index}.weight": "W8A8_DYNAMIC", + f"linears.{index}.w_scale": "W8A8_DYNAMIC", + }) + if linear_spec.has_bias: + quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"] + quant_desc[f"linears.{index}.bias"] = "W8A8_DYNAMIC" + continue + if qpolicy == 'a8w4': + group_size = 256 + self.global_group_size = group_size + is_grouped = linear_spec.linear_type in ("ColumnParallelGroupedLinear", "RowParallelGroupedLinear") + if not is_grouped: + raise ValueError("a8w4 quantization only support grouped linear") + qweight_packed_gate, w_scale_uint64_gate = NumpyQuantizer._weight_int4_per_group_pack( + weights[f"linears.{index}.gate.weight"], group_size, transpose_b=True) + qweight_packed_up, w_scale_uint64_up = NumpyQuantizer._weight_int4_per_group_pack( + weights[f"linears.{index}.up.weight"], group_size, transpose_b=True) + quant_weights.update({ + f"linears.{index}.gate.weight": qweight_packed_gate, + f"linears.{index}.gate.w_scale": w_scale_uint64_gate, + f"linears.{index}.up.weight": qweight_packed_up, + f"linears.{index}.up.w_scale": w_scale_uint64_up, + }) + quant_desc.update({ + f"linears.{index}.weight": "W4A8_DYNAMIC", + f"linears.{index}.w_scale": "W4A8_DYNAMIC", + }) + if linear_spec.has_bias: + 
quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"] + quant_desc[f"linears.{index}.bias"] = "W4A8_DYNAMIC" + continue + if qpolicy is None: + weight = weights[f"linears.{index}.weight"] + quant_weights.update({ + f"linears.{index}.weight": weight, + }) + quant_desc.update({ + f"linears.{index}.weight": "FLOAT", + }) + if linear_spec.has_bias: + quant_weights[f"linears.{index}.bias"] = weights[f"linears.{index}.bias"] + quant_desc[f"linears.{index}.bias"] = "FLOAT" + continue + raise ValueError(f"Unsupported quant policy: {qpolicy}") + if self.global_group_size is not None: + quant_desc["group_size"] = int(self.global_group_size) + return quant_weights, quant_desc + + @staticmethod + def _get_quant_min_max(num_bits=8, signed=True, narrow_range=False): + """Calculate quantization params for minimum/maximum quantization integer""" + if signed: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + else: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + if narrow_range: + quant_min = quant_min + 1 + return quant_min, quant_max + + @staticmethod + def _act_int8_quant(tensor): + """Quantize activation tensor to int8.""" + bits=8 + quant_min, quant_max = NumpyQuantizer._get_quant_min_max(bits) + + min_val = np.min(tensor) + max_val = np.max(tensor) + + if (max_val == min_val).all(): + scale = np.array([1.0], dtype=np.float32) + zero_point = np.array([0.0], dtype=np.float32) + else: + min_val = min_val.astype(np.float64) + max_val = max_val.astype(np.float64) + scale = (max_val - min_val) / (quant_max - quant_min) + zero_point = quant_min - min_val / scale.astype(np.float32) + scale = scale.astype(np.float32) + + quantized = np.round(tensor / scale + zero_point) + quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8) + + return quantized, scale, zero_point + + @staticmethod + def _weight_int8_quant(tensor, transpose_b=True): + """Quantize weight tensor to int8.""" + bits=8 + quant_min, quant_max = 
NumpyQuantizer._get_quant_min_max(bits) + oc_axis = 0 if transpose_b else 1 + ic_axis = 1 if transpose_b else 0 + oc = tensor.shape[oc_axis] + min_val = np.min(tensor, axis=ic_axis, keepdims=True) + max_val = np.max(tensor, axis=ic_axis, keepdims=True) + if (max_val == min_val).all(): + scale = np.ones((oc,), dtype=np.float32) + else: + min_val = min_val.astype(np.float64) + max_val = max_val.astype(np.float64) + max_val = np.maximum(np.abs(min_val), np.abs(max_val)) + min_val = -max_val + scale = ((max_val - min_val) / (quant_max - quant_min)).astype(np.float32) + + quantized = np.round(tensor / scale) + quantized = np.clip(quantized, quant_min, quant_max).astype(np.int8) + scale = np.squeeze(scale) + return quantized, scale + + @staticmethod + def _weight_int4_per_group_pack(tensor, group_size, transpose_b=True): + """weight_int4_per_group_pack.""" + if transpose_b: + oc, ic = tensor.shape[0], tensor.shape[1] + else: + ic, oc = tensor.shape[0], tensor.shape[1] + q = np.empty((oc//2,ic), dtype=np.int8) + scale = np.empty((oc,ic//group_size), dtype=np.float32) + scale_uint64 = scale.astype(np.float32).view(np.uint32).astype(np.uint64) + scale_uint64 = scale_uint64.reshape(scale.shape) + packed = q + return packed, scale_uint64 diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py new file mode 100644 index 000000000..53de09d96 --- /dev/null +++ b/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/run_parallel_linear.py @@ -0,0 +1,242 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Run ColumnParallelLinear accuracy test with configurable parameters via args""" + + +import argparse +import glob +import os +import tempfile +import numpy as np +from safetensors import safe_open +from safetensors.numpy import save_file + +import mindspore as ms +from mindspore.communication import init +from numpy_quantizer import NumpyQuantizer +from gpt_model_for_test import GPTModelForTest, LinearSpec, ModelSpec, QKVLinearSpec, GroupLinearSpec +from mindformers.parallel_core.inference.parallel_state import initialize_model_parallel +from mindformers.parallel_core.process_group_config import ModelCommProcessGroups + + +class ParallelModelRunner: + """Runner for parallel model testing with quantization support.""" + + def __init__(self, config): + """Initialize the parallel model runner with given arguments.""" + self.config = config + # set up parallel context + rank_id_str = os.environ.get("RANK_ID") + self.rank_id = int(rank_id_str) if rank_id_str is not None else None + self.worker_num = int(os.environ.get("MS_WORKER_NUM", "1")) + self.model_comm_pgs = ModelCommProcessGroups.get_default_model_comm_pgs() + if self.rank_id is not None: + init() + initialize_model_parallel(tensor_model_parallel_size=self.config.tensor_parallel) + self.model_comm_pgs = ModelCommProcessGroups.use_parallel_state_groups(required_groups=['tp']) + + linear_specs = [] + quant_policys = [] + self.quantization = config.quantization + for linear_type in config.linear_types: + for has_bias in 
[True, False]: + for quant_policy in config.quant_policies: + quant_policy = quant_policy if config.quantization == 'golden-stick' else 'float' + if linear_type=="QKVParallelLinear": + linear_specs.append(QKVLinearSpec(linear_type, config.input_size, config.head_size, + config.total_num_heads,config.total_num_kv_heads, + has_bias, config.compute_dtype, quant_policy)) + elif linear_type=="ColumnParallelGroupedLinear": + linear_specs.append(GroupLinearSpec(linear_type, config.num_local_experts,config.input_size, + config.output_size, + quant_policy)) + else: + linear_specs.append(LinearSpec(linear_type, config.input_size, config.output_size, + has_bias, config.compute_dtype, quant_policy)) + quant_policys.append(quant_policy) + + self.model_spec = ModelSpec( + compute_dtype=config.compute_dtype, + param_init_dtype=config.param_init_dtype, + tensor_parallel=config.tensor_parallel, + linear_specs=linear_specs, + ) + self.quant_model_dir = None + if self.quantization == 'golden-stick': + self.quantizer = NumpyQuantizer(self.model_spec, quant_policys) + self.quant_model_dir = tempfile.mkdtemp(prefix="quant_model_for_test_") + + @staticmethod + def _gen_float_weights(model_spec): + """Generate random float weights for model specifications.""" + np.random.seed(42) + weights = {} + for index, linear_spec in enumerate(model_spec.linear_specs): + if linear_spec.linear_type=="QKVParallelLinear": + #qkv + weight_shapes = [(linear_spec.output_sizes[0], linear_spec.input_size), + (linear_spec.output_sizes[1], linear_spec.input_size), + (linear_spec.output_sizes[2], linear_spec.input_size)] + output_size = linear_spec.output_size + qkv_map = {0:"q",1:"k",2:"v"} + for shared_id,weight_shape in enumerate(weight_shapes): + weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32) + weights[f"linears.{index}.{qkv_map[shared_id]}.weight"] = weight + if linear_spec.has_bias: + for shared_id,weight_shape in enumerate(weight_shapes): + bias = 0.01 * 
np.random.randn(weight_shape[0]).astype(np.float32) + weights[f"linears.{index}.{qkv_map[shared_id]}.bias"]= bias + elif linear_spec.linear_type=="ColumnParallelGroupedLinear": + # gate,up + weight_shapes = [(linear_spec.output_size//2,linear_spec.input_size), + (linear_spec.output_size//2,linear_spec.input_size)] + output_size = linear_spec.output_size + gate_up_map = {0:"gate",1:"up"} + for shared_id,weight_shape in enumerate(weight_shapes): + weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32) + weights[f"linears.{index}.{gate_up_map[shared_id]}.weight"]=weight + else: + weight_shape = (linear_spec.output_size, linear_spec.input_size) + output_size = linear_spec.output_size + weight = 0.01 * np.random.randn(*weight_shape).astype(np.float32) + weights[f"linears.{index}.weight"] = weight + if linear_spec.has_bias: + bias = 0.01 * np.random.randn(output_size).astype(np.float32) + weights[f"linears.{index}.bias"] = bias + return weights + + @staticmethod + def _gen_input(model_spec): + """Generate random input data for model specifications.""" + np.random.seed(42) + return 0.01 * np.random.randn(2 * 2, model_spec.linear_specs[0].input_size).astype(np.float32) + + def _create_network(self): + """Create the network model for testing.""" + return GPTModelForTest(self.model_spec, self.model_comm_pgs, self.quantization, self.quant_model_dir) + + def _load_quant_weights(self): + """Load quantized weights from the model directory.""" + if not os.path.isdir(self.quant_model_dir): + raise ValueError(f"Invalid quant_model_dir: {self.quant_model_dir}") + safetensor_files = glob.glob(os.path.join(self.quant_model_dir, "*.safetensors")) + if len(safetensor_files) == 1: + safetensor_file = safetensor_files[0] + elif len(safetensor_files) > 1: + raise FileNotFoundError(f"Found multiple safetensor files in {self.quant_model_dir}") + else: + raise FileNotFoundError(f"Found no safetensor file in {self.quant_model_dir}") + if not os.path.exists(safetensor_file): + raise 
FileNotFoundError(f"File {safetensor_file} not found.") + with safe_open(safetensor_file, framework="np", device="cpu") as f: + weights = {} + for key in f.keys(): + weights[key] = f.get_slice(key) + return weights + + @staticmethod + def load_weights_into_network(network, weights): + """Load weights into the network parameters.""" + params = network.parameters_dict() + print(params) + loaded = [] + for k, v in weights.items(): + shard_id = None + expert_id = None + original_key = k + if ".gate" in k or ".q." in k: + k = k.replace(".gate","") + k = k.replace(".q","") + expert_id = 0 + shard_id = "w1" # For ColumnParallelGroupedLinear, use "w1" for gate weights + if ".up" in k or ".k." in k: + k = k.replace(".up","") + k = k.replace(".k","") + shard_id = "w3" # For ColumnParallelGroupedLinear, use "w3" for up weights + if expert_id is None: + expert_id = 0 + if ".v." in k: + k = k.replace(".v","") + shard_id = 2 + expert_id = None + param = params.get(k) + if param is None: + continue + loaded.append(original_key) # Track original key, not transformed key + if shard_id is not None: + if expert_id is not None: + param.weight_loader(param, v,shard_id,expert_id) + else: + param.weight_loader(param, v,shard_id) + else: + param.weight_loader(param, v) + + + print(f"weights not use: {set(weights.keys()) - set(loaded)}", flush=True) + print(f"params not load: {set(params.keys()) - set(loaded)}", flush=True) + + def run(self): + """Run the parallel model test.""" + input_data = ParallelModelRunner._gen_input(self.model_spec) + weights = ParallelModelRunner._gen_float_weights(self.model_spec) + if self.quantization == 'golden-stick': + self.quantizer.quant(input_data, weights, self.quant_model_dir) + weights = self._load_quant_weights() + network = self._create_network() + first_value = next(iter(weights.values())) + # Moe must input safetensors + if isinstance(first_value, np.ndarray): + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, 
"model.safetensors") + save_file(weights, path) + weights.clear() + with safe_open(path, framework="np", device="cpu") as f: + for key in f.keys(): + weights[key] = f.get_slice(key) + ParallelModelRunner.load_weights_into_network(network, weights) + net_input = ms.Tensor(input_data, dtype=LinearSpec.convert_pt_dtype_to_ms(self.model_spec.compute_dtype)) + output_dict = network.forward(net_input) + + if self.rank_id is None or int(self.rank_id) == 0: + np.savez(self.config.output_path, **output_dict) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run ColumnParallelLinear test") + parser.add_argument("--linear_types", type=str, action='append', default=None, + help="List of linear types, e.g., --linear_types ColumnParallelLinear "\ + "--linear_types RowParallelLinear") + parser.add_argument("--tensor_parallel", type=int, default=1) + parser.add_argument("--head_size", type=int, default=10) + parser.add_argument("--total_num_heads", type=int, default=2) + parser.add_argument("--total_num_kv_heads", type=int, default=2) + parser.add_argument("--compute_dtype", type=str, default='bf16') + parser.add_argument("--param_init_dtype", type=str, default='bf16') + parser.add_argument("--num_local_experts", type=int, default=1) + parser.add_argument("--output_path", type=str, default="output.npz") + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--quant_policies", type=str, action='append', default=None, + help="List of quantization policies, e.g., --quant_policies a8w8 --quant_policies a8dynw8") + args = parser.parse_args() + args.input_size = 2048 + args.output_size = 2048 + + ms.set_context(device_target="Ascend", + mode=ms.GRAPH_MODE, + jit_config={"jit_level": "O0", "infer_boost": "on"}, + deterministic="ON") + + quant_runner = ParallelModelRunner(args) + quant_runner.run() diff --git a/tests/st/test_ut/test_parallel_core/test_inference/test_tensor_parallel/quantization_gemm/test_parallel_linear.py 
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test ColumnParallelLinear with various configurations"""


from typing import Optional
from pathlib import Path
import subprocess
import pytest
import numpy as np
from tests.utils.precision_utils import PrecisionChecker
from mindformers.tools.logger import logger


def build_msrun_command_list(linear_types, log_dir, run_script_path, output_path_param, tensor_parallel,
                             port, quantization, quant_policies: Optional[list] = None):
    """Build the msrun (multi-card) or plain python (single-card) command list.

    Raises:
        RuntimeError: if quantization is enabled but no quant_policies given.
    """
    # Validate before any use of quant_policies; iterating None would raise a
    # confusing TypeError instead of this explicit error.
    if quantization is not None and quant_policies is None:
        raise RuntimeError("quant_policies must be provided when quantization is enabled.")

    if tensor_parallel == 1:
        cmd_list = ["python"]
    else:
        cmd_list = [
            "msrun",
            f"--worker_num={tensor_parallel}",
            f"--local_worker_num={tensor_parallel}",
            f"--master_port={port}",  # Ensure port is unique per test run if parallelized at pytest level
            f"--log_dir={log_dir}",
            "--join=True",
        ]

    cmd_list += [
        str(run_script_path),
        f"--output_path={output_path_param}",
        f"--tensor_parallel={tensor_parallel}",
    ]
    for linear_type in linear_types:
        cmd_list.append(f"--linear_types={linear_type}")
    for quant_policy in quant_policies or []:
        cmd_list.append(f"--quant_policies={quant_policy}")
    if quantization is not None:
        cmd_list.append(f"--quantization={quantization}")

    logger.info(f"Equivalent shell command for debugging (approximate): {' '.join(cmd_list)}")
    return cmd_list


class TestParallelLinear:
    """Test class for ParallelLinear with different configurations"""

    def setup_method(self):
        """Prepare test environment: resolve script path and log directory."""
        self.sh_path = Path(__file__).parent.resolve()
        self.run_script_path = self.sh_path / "run_parallel_linear.py"
        self.log_file_path = self.sh_path / 'test_output' / 'logs'
        self.log_file_path.mkdir(parents=True, exist_ok=True)

    def infer(self, linear_types, log_dir_path, output_file_path, tensor_parallel, port, quantization,
              quant_policies=None):
        """Run inference in a subprocess and assert it produced the output file."""
        cmd_list = build_msrun_command_list(
            linear_types=linear_types,
            log_dir=log_dir_path,
            run_script_path=self.run_script_path,
            output_path_param=output_file_path,
            tensor_parallel=tensor_parallel,
            port=port,
            quantization=quantization,
            quant_policies=quant_policies,
        )

        result = subprocess.run(
            cmd_list, shell=False, capture_output=True, text=True, check=False)

        assert result.returncode == 0, (
            f"Test script failed with non-zero exit code: "
            f"{result.returncode}.\nStdout:\n{result.stdout}\nStderr:\n{result.stderr}"
        )
        assert output_file_path.exists(), (
            f"Output file {output_file_path} was not created."
        )

    def run_test(self, linear_types, quant_policies, tmp_path, tensor_parallel=1, port=8118):
        """Run quantized and float inference, then compare outputs key-by-key."""
        output_file_path = tmp_path / 'quant-output.npz'
        self.infer(
            linear_types=linear_types,
            log_dir_path=self.log_file_path,
            output_file_path=output_file_path,
            tensor_parallel=tensor_parallel,
            port=port,
            quantization='golden-stick',
            quant_policies=quant_policies,
        )
        quant_output = np.load(output_file_path)

        output_file_path = tmp_path / 'float-output.npz'
        self.infer(
            linear_types=linear_types,
            log_dir_path=self.log_file_path,
            output_file_path=output_file_path,
            tensor_parallel=tensor_parallel,
            port=port + 1,  # Avoid clashing with the quantized run's master port
            quantization=None,
            quant_policies=quant_policies,
        )
        float_output = np.load(output_file_path)
        checker = PrecisionChecker()
        succeed = True
        for key in quant_output:
            # Map the quantized key onto its float counterpart.
            fkey = key[:key.rfind('-')] + '-quant_type_float'
            if fkey not in float_output:
                raise ValueError(f"Diff key in quant_output but not in float_output: {key}")
            try:
                checker.check_precision(float_output[fkey], quant_output[key])
                print(f"Check precision for {key} succeed", flush=True)
            except AssertionError as e:
                print(f"Check precision for {key} failed: {e}", flush=True)
                succeed = False
        # BUG FIX: the original re-assigned succeed = True here, which silently
        # discarded every recorded precision failure before the assert.
        assert succeed, "Some precision check failed"

    @pytest.mark.level1
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    def test_single_card_moe_configurations(self, tmp_path):
        """Test single card with various configurations."""
        linear_types = ["ColumnParallelGroupedLinear"]
        quant_policies = ["a8w4", "a8dynw8"]
        self.run_test(linear_types=linear_types, quant_policies=quant_policies,
                      tmp_path=tmp_path, port=8888)
b/tests/st/test_ut/test_tools/test_register/test_config.py index d32a55472..8e0ef4434 100644 --- a/tests/st/test_ut/test_tools/test_register/test_config.py +++ b/tests/st/test_ut/test_tools/test_register/test_config.py @@ -17,6 +17,7 @@ import argparse import sys from collections import OrderedDict +import copy import pytest @@ -56,7 +57,7 @@ class TestConfig: def test_dict_config_deepcopy_isolated(self): """Deep copy should create independent nested objects.""" cfg = DictConfig(nested=DictConfig(value=[1, 2])) - copied = cfg + copied = copy.deepcopy(cfg) copied.nested.value.append(3) assert cfg.nested.value == [1, 2] assert copied.nested.value == [1, 2, 3] -- Gitee