diff --git a/docs/map_from_buildin_to_custom.md b/docs/map_from_buildin_to_custom.md index 1f47895cf40adede3b4eafc735a41be574e2276e..40cbe07ed6586afbb8185d5890188d8e41c7edc8 100644 --- a/docs/map_from_buildin_to_custom.md +++ b/docs/map_from_buildin_to_custom.md @@ -1,14 +1,16 @@ -|build-in接口|ms_custom_ops接口|变化说明| -|------------------------------------|-------------------------------------|-----------------------------------| -| ops.auto_generate.format_cast | [ms_custom_ops.trans_data](../ops/c_api/trans_data/trans_data.md) | 两者均进行ND和NZ的相互转换。format_cast依赖ms内置白名单;trans_data不使用白名单但有使用约束,详见trans_data文档。同一权重只能选用其中一种进行转换,建议网络中统一采用同一种算子,两者不兼容。 | -| ops.auto_generate.mla | [ms_custom_ops.mla](../ops/c_api/mla/mla_doc.md) | 新增了input_format参数,用于指定输入参数的format | -| ops.auto_generate.reshape_and_cache| [ms_custom_ops.reshape_and_cache](../ops/c_api/reshape_and_cache/reshape_and_cache.md) | 新增cache_mode参数,用于指定Atlas 训练系列cache的format是ND还是NZ; 新增head_num,cache_mode为NZ的时候必须提供,辅助计算。| -| ops.moe_init_routing_v2 | [ms_custom_ops.moe_init_routing_v2](../ops/c_api/moe_init_routing_v2/moe_init_routing_v2.md) | 接口一致,仅支持 Atlas 推理系列 | -|ops.auto_generate.moe_gating_group_topk|[ms_custom_ops.moe_gating_group_topk](../ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md) |接口一致| -| ops.auto_generate.group_topk | [ms_custom_ops.group_topk](../ops/c_api/group_topk/group_topk_doc.md) | 副作用接口,将不再支持左值输出 | -| ops.auto_generate.mla_preprocess | [ms_custom_ops.mla_preprocess](../ops/c_api/mla_preprocess/mla_preprocess_doc.md) | 接口一致 | -| ops.auto_generate.fused_add_topk_div | [ms_custom_ops.fused_add_topk_div](../ops/c_api/fused_add_topk_div/fused_add_topk_div_doc.md) | 接口一致 | -| ops.auto_generate.paged_cache_load | [ms_custom_ops.paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) | 新增支持key、value支持不同dtype;取消inplace更新的输出key、value,直接改为输出 | -| ops.auto_generate.quant_batch_matmul | [ms_custom_ops.quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) | 新增了x2_format参数,用于指定x2的format; 入参名称`pertokenScaleOptional`修改为`pertoken_scale`; 入参名称`dtype`修改为`output_dtype` | -| ops.auto_generate.apply_rotary_pos_emb | [ms_custom_ops.apply_rotary_pos_emb_atb](../ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | 新增atb的apply_rotary_pos_emb_atb算子,代替ops.auto_generate.apply_rotary_pos_emb,注意rotary_coeff和cos_format有变化,详见[API](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | -| ops.moe_token_unpermute | [ms_custom_ops.moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) | 接口参数一致, 但需注意:ops接口只支持A2训练芯片,ms_custom_ops场景下只支持Atlas推理系列产品, 并且ms_custom_ops场景下当前仅支持:`padded_mode = false, restore_shape = None`, topK 支持 1、2,、4、8, hidden_size 支持 2048、5120、7168。 | +|build-in接口|ms_custom_ops接口|变化说明| +|------------------------------------|-------------------------------------|-----------------------------------| +| ops.auto_generate.format_cast | [ms_custom_ops.trans_data](../ops/c_api/trans_data/trans_data.md) | 两者均进行ND和NZ的相互转换。format_cast依赖ms内置白名单;trans_data不使用白名单但有使用约束,详见trans_data文档。同一权重只能选用其中一种进行转换,建议网络中统一采用同一种算子,两者不兼容。 | +| ops.auto_generate.mla | [ms_custom_ops.mla](../ops/c_api/mla/mla_doc.md) | 新增了input_format参数,用于指定输入参数的format | +| ops.auto_generate.reshape_and_cache| [ms_custom_ops.reshape_and_cache](../ops/c_api/reshape_and_cache/reshape_and_cache.md) | 新增cache_mode参数,用于指定Atlas 训练系列cache的format是ND还是NZ; 新增head_num,cache_mode为NZ的时候必须提供,辅助计算。| +| ops.moe_init_routing_v2 | 
[ms_custom_ops.moe_init_routing_v2](../ops/c_api/moe_init_routing_v2/moe_init_routing_v2.md) | 接口一致,仅支持 Atlas 推理系列 | +|ops.auto_generate.moe_gating_group_topk|[ms_custom_ops.moe_gating_group_topk](../ops/c_api/moe_gating_group_topk/moe_gating_group_topk.md) |接口一致| +| ops.auto_generate.group_topk | [ms_custom_ops.group_topk](../ops/c_api/group_topk/group_topk_doc.md) | 副作用接口,将不再支持左值输出 | +| ops.auto_generate.mla_preprocess | [ms_custom_ops.mla_preprocess](../ops/c_api/mla_preprocess/mla_preprocess_doc.md) | 接口一致 | +| ops.auto_generate.fused_add_topk_div | [ms_custom_ops.fused_add_topk_div](../ops/c_api/fused_add_topk_div/fused_add_topk_div_doc.md) | 接口一致 | +| ops.auto_generate.paged_cache_load | [ms_custom_ops.paged_cache_load](../ops/c_api/paged_cache_load/paged_cache_load_doc.md) | 新增支持key、value支持不同dtype;取消inplace更新的输出key、value,直接改为输出 | +| ops.auto_generate.quant_batch_matmul | [ms_custom_ops.quant_batch_matmul](../ops/c_api/quant_batch_matmul/quant_batch_matmul.md) | 新增了x2_format参数,用于指定x2的format; 入参名称`pertokenScaleOptional`修改为`pertoken_scale`; 入参名称`dtype`修改为`output_dtype` | +| ops.auto_generate.apply_rotary_pos_emb | [ms_custom_ops.apply_rotary_pos_emb_atb](../ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | 新增atb的apply_rotary_pos_emb_atb算子,代替ops.auto_generate.apply_rotary_pos_emb,注意rotary_coeff和cos_format有变化,详见[API](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | +| ops.moe_token_unpermute | [ms_custom_ops.moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) | 接口参数一致, 但需注意:ops接口只支持A2训练芯片,ms_custom_ops场景下只支持Atlas推理系列产品, 并且ms_custom_ops场景下当前仅支持:`padded_mode = false, restore_shape = None`, topK 支持 1、2,、4、8, hidden_size 支持 2048、5120、7168。 | +| ops.auto_generate.GroupedMatmul | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,移除了offset、antiquant_offset、split_item、group_type等参数,返回从tuple[tensor]改为tensor,基于Internal框架实现,仅支持 Atlas 推理系列 | +| ops.auto_generate.GroupedMatmulV4 | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,per_token_scale改为per_token_scale,移除了offset、antiquant_offset、activation相关参数、split_item、group_type、group_list_type、act_type、output_dtype等参数,返回从tuple[tensor]改为tensor,增加transpose_a、transpose_b参数,基于Internal框架实现,仅支持 Atlas 推理系列 | \ No newline at end of file diff --git a/docs/op_list.md b/docs/op_list.md index 9d8241d931a40b80833cc44c1c325ed1a6b30428..a7a8ed889a0ce6dd41f062b9b8da95c785f4b360 100644 --- a/docs/op_list.md +++ b/docs/op_list.md @@ -8,6 +8,8 @@ 1. [flash_attention_encoder](../ops/c_api/flash_attention_encoder/flash_attention_encoder.md) 1. [fused_add_topk_div](../ops/c_api/fused_add_topk_div/fused_add_topk_div.md) 1. [group_topk](../ops/c_api/group_topk/group_topk_doc.md) +1. [grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul_doc.md) +1. [grouped_matmul_w4](../ops/c_api/grouped_matmul_w4/grouped_matmul_w4_doc.md) 1. [kv_rmsnorm_rope_cache](../ops/c_api/kv_rmsnorm_rope_cache/kv_rmsnorm_rope_cache.md) 1. [lightning_indexer](../ops/c_api/lightning_indexer/lightning_indexer_doc.md) 1. 
[matmul](../ops/c_api/matmul/matmul.md) diff --git a/ops/c_api/grouped_matmul/grouped_matmul.cc b/ops/c_api/grouped_matmul/grouped_matmul.cc new file mode 100644 index 0000000000000000000000000000000000000000..4da57c65ece1a71343831590a54baced9f51b5a0 --- /dev/null +++ b/ops/c_api/grouped_matmul/grouped_matmul.cc @@ -0,0 +1,332 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include +#include +#include +#include +#include +#include + +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { +enum class GroupedMatmulInputIndex : size_t { + kGmmXIndex = 0, + kGmmWeightIndex, + kGmmGroupListIndex, + kGmmBiasIndex, + kGmmScaleIndex, + kGmmPerTokenScaleIndex, + kGmmAntiquantScaleIndex, + kGmmTransposeAIndex, + kGmmTransposeBIndex, + kGmmXFormatIndex, + kGmmInputsNum, +}; +enum class GroupedMatmulOutputIndex : size_t { + kGmmOutputIndex = 0, + kGmmOutputsNum, +}; + +// ============================================================================= +// UTILITY FUNCTIONS +// ============================================================================= +void CheckInputRank(const ShapeVector &x_shape, const ShapeVector &weight_shape, const ShapeVector &group_list_shape) { + auto x_shape_size = x_shape.size(); + auto weight_shape_size = weight_shape.size(); + auto group_list_shape_size = group_list_shape.size(); + if (MS_UNLIKELY(x_shape_size != kDim2)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'x' must be 2D, but got:" << x_shape_size; + } + if (MS_UNLIKELY(weight_shape_size != kDim3)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'weight' must be 3D, but got:" << weight_shape_size; + } + if (MS_UNLIKELY(group_list_shape_size != kDim1)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'group_list' must be 1D, but got:" << group_list_shape_size; + } +} + +ShapeVector InferGroupedMatmulOutputShape(const ShapeVector &x_shape, const ShapeVector &weight_shape, + const ShapeVector &group_list_shape, bool transpose_a, bool transpose_b, + TypeId x_dtype) { + CheckInputRank(x_shape, weight_shape, group_list_shape); + // get input dimensions + auto x_m_dim = transpose_a ? x_shape[kIndex1] : x_shape[kIndex0]; + auto x_k_dim = transpose_a ? x_shape[kIndex0] : x_shape[kIndex1]; + auto weight_e_dim = weight_shape[kIndex0]; + auto weight_n_dim = transpose_b ? weight_shape[kIndex1] : weight_shape[kIndex2]; + auto weight_k_dim = transpose_b ? 
weight_shape[kIndex2] : weight_shape[kIndex1]; + + // infer output shape when weight_n_dim is dynamic + if (weight_n_dim == abstract::Shape::kShapeDimAny) { + return ShapeVector{x_m_dim, weight_n_dim}; + } + + if (MS_UNLIKELY(x_k_dim != abstract::Shape::kShapeDimAny && weight_k_dim != abstract::Shape::kShapeDimAny && + x_k_dim != weight_k_dim)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'x' and 'weight' must have the same dimension of 'k', but got:" + << x_k_dim << " and " << weight_k_dim; + } + // Check shape alignment requirements + // Check K dimension alignment + if (x_k_dim != abstract::Shape::kShapeDimAny) { + const auto kFloat16KAlign16 = 16; + const auto kInt8KAlign32 = 32; + int64_t k_align = (x_dtype == kNumberTypeFloat16) ? kFloat16KAlign16 : kInt8KAlign32; + if (x_k_dim % k_align != 0) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'x' K dimension must be aligned to " << k_align << " for " + << (x_dtype == kNumberTypeFloat16 ? "float16" : "int8") << " input, but got: " << x_k_dim; + } + } + + // Check N dimension alignment + if (MS_UNLIKELY(weight_n_dim != abstract::Shape::kShapeDimAny && weight_n_dim % 16 != 0)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'weight' N dimension must be aligned to 16, but got: " + << weight_n_dim; + } + + // Check E dimension alignment + if (MS_UNLIKELY(weight_e_dim != abstract::Shape::kShapeDimAny && weight_e_dim % 8 != 0)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'weight' E dimension (group count) must be aligned to 8, but got: " + << weight_e_dim; + } + + return ShapeVector{x_m_dim, weight_n_dim}; +} + +class OPS_API CustomGroupedMatmulOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + // dynamic rank + if (MS_UNLIKELY(input_infos[static_cast(GroupedMatmulInputIndex::kGmmXIndex)]->IsDynamicRank() || + input_infos[static_cast(GroupedMatmulInputIndex::kGmmWeightIndex)]->IsDynamicRank())) { + auto output_shape = ShapeVector{abstract::Shape::kShapeRankAny}; + return {output_shape}; + } + // check input size + if (MS_UNLIKELY(input_infos.size() != static_cast(GroupedMatmulInputIndex::kGmmInputsNum))) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input size must be " + << static_cast(GroupedMatmulInputIndex::kGmmInputsNum) << ", but got " + << input_infos.size(); + } + + // get input shapes + auto x_shape = input_infos[static_cast(GroupedMatmulInputIndex::kGmmXIndex)]->GetShape(); + auto weight_shape = input_infos[static_cast(GroupedMatmulInputIndex::kGmmWeightIndex)]->GetShape(); + auto group_list_shape = input_infos[static_cast(GroupedMatmulInputIndex::kGmmGroupListIndex)]->GetShape(); + auto transpose_a = + input_infos[static_cast(GroupedMatmulInputIndex::kGmmTransposeAIndex)]->GetScalarValueWithCheck(); + auto transpose_b = + input_infos[static_cast(GroupedMatmulInputIndex::kGmmTransposeBIndex)]->GetScalarValueWithCheck(); + auto x_dtype = input_infos[static_cast(GroupedMatmulInputIndex::kGmmXIndex)]->GetType(); + auto output_shape = + InferGroupedMatmulOutputShape(x_shape, weight_shape, group_list_shape, transpose_a, transpose_b, x_dtype); + return {output_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + MS_EXCEPTION_IF_NULL(primitive); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + const auto &device_target = ms_context->ascend_soc_version(); + if (device_target != kAscendVersion310p) { + 
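+      // Only the Ascend 310P SoC (Atlas inference series) backend is implemented for this kernel;
+      // any other soc_version is rejected below.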
MS_LOG(EXCEPTION) << "For GroupedMatmul, only support Ascend device, but got:" << device_target; + } + auto x_dtype = input_infos[static_cast(GroupedMatmulInputIndex::kGmmXIndex)]->GetType(); + auto weight_dtype = input_infos[static_cast(GroupedMatmulInputIndex::kGmmWeightIndex)]->GetType(); + auto group_list_dtype = input_infos[static_cast(GroupedMatmulInputIndex::kGmmGroupListIndex)]->GetType(); + // check input types + if (MS_UNLIKELY(x_dtype != kNumberTypeFloat16 && x_dtype != kNumberTypeInt8)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'x' must be float16 or int8, but got:" << x_dtype; + } + if (MS_UNLIKELY(weight_dtype != x_dtype)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'weight' must be the same type as 'x', but got:" << weight_dtype; + } + if (MS_UNLIKELY(group_list_dtype != kNumberTypeInt32)) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input 'group_list' must be int32, but got:" << group_list_dtype; + } + auto check_optional_input_valid_dtypes = [](const InferInfoPtr &input_info, std::vector expected_types, + const std::string &input_name) { + if (input_info->IsNone()) { + return; + } + if (std::find(expected_types.begin(), expected_types.end(), input_info->GetType()) == expected_types.end()) { + MS_LOG(EXCEPTION) << "For GroupedMatmul, input '" << input_name << "' must be one of " << expected_types + << ", but got:" << input_info->GetType(); + } + }; + check_optional_input_valid_dtypes(input_infos[static_cast(GroupedMatmulInputIndex::kGmmBiasIndex)], + {kNumberTypeFloat16, kNumberTypeInt32}, "bias"); + check_optional_input_valid_dtypes(input_infos[static_cast(GroupedMatmulInputIndex::kGmmScaleIndex)], + {kNumberTypeInt64, kNumberTypeUInt64, kNumberTypeFloat32}, "scale"); + check_optional_input_valid_dtypes(input_infos[static_cast(GroupedMatmulInputIndex::kGmmPerTokenScaleIndex)], + {kNumberTypeFloat32}, "per_token_scale"); + check_optional_input_valid_dtypes( + input_infos[static_cast(GroupedMatmulInputIndex::kGmmAntiquantScaleIndex)], + {kNumberTypeFloat16, kNumberTypeFloat32}, "antiquant_scale"); + + // infer output types + auto output_dtype = kNumberTypeFloat16; + return {output_dtype}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class CustomGroupedMatmul : public InternalKernelMod { + public: + CustomGroupedMatmul() : InternalKernelMod() {} + ~CustomGroupedMatmul() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = {static_cast(GroupedMatmulInputIndex::kGmmXIndex), + static_cast(GroupedMatmulInputIndex::kGmmWeightIndex), + static_cast(GroupedMatmulInputIndex::kGmmBiasIndex), + static_cast(GroupedMatmulInputIndex::kGmmScaleIndex), + static_cast(GroupedMatmulInputIndex::kGmmGroupListIndex), + static_cast(GroupedMatmulInputIndex::kGmmPerTokenScaleIndex), + static_cast(GroupedMatmulInputIndex::kGmmAntiquantScaleIndex)}; + kernel_outputs_index_ = {static_cast(GroupedMatmulOutputIndex::kGmmOutputIndex)}; + } + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + param_.transpose_a = + ms_inputs.at(static_cast(GroupedMatmulInputIndex::kGmmTransposeAIndex))->GetValueWithCheck(); + param_.transpose_b = + ms_inputs.at(static_cast(GroupedMatmulInputIndex::kGmmTransposeBIndex))->GetValueWithCheck(); + param_.with_bias = + !(ms_inputs.at(static_cast(GroupedMatmulInputIndex::kGmmBiasIndex))->GetType()->isa()); + 
param_.enable_shuffle = false; // the real definition is in internal + auto inputs_clone = inputs; + auto x_format = static_cast( + ms_inputs.at(static_cast(GroupedMatmulInputIndex::kGmmXFormatIndex))->GetValueWithCheck()); + if (x_format == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(GroupedMatmulInputIndex::kGmmXIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + inputs_clone[static_cast(GroupedMatmulInputIndex::kGmmWeightIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + return internal_v2::CreateGroupedMatmulOp(inputs_clone, outputs, param_, internal_v2::kInternalGroupedMatmulOpName); + } + + uint64_t GenerateTilingKey(const std::vector &inputs) override { + return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_); + } + + private: + internal_v2::MatmulParam param_; +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(grouped_matmul, ms_custom_ops::CustomGroupedMatmulOpFuncImpl, ms_custom_ops::CustomGroupedMatmul); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +class GroupedMatmulRunner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + + void SetParam(const bool &transpose_a, const bool &transpose_b, const bool &with_bias) { + param_.transpose_a = transpose_a; + param_.transpose_b = transpose_b; + param_.with_bias = with_bias; + param_.enable_shuffle = false; // the real definition is in internal + } + void SetXFormat(const DataFormat &x_format) { this->x_format_ = x_format; } + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs) override { + auto inputs_clone = inputs; + if (x_format_ == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(GroupedMatmulInputIndex::kGmmXIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + inputs_clone[static_cast(GroupedMatmulInputIndex::kGmmWeightIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + return internal_v2::CreateGroupedMatmulOp(inputs_clone, outputs, param_, internal_v2::kInternalGroupedMatmulOpName); + } + + private: + internal_v2::MatmulParam param_; + DataFormat x_format_{DataFormat::ND}; +}; + +std::vector npu_grouped_matmul(const ms::Tensor &x, const ms::Tensor &weight, const ms::Tensor &group_list, + const std::optional &bias, + const std::optional &scale, + const std::optional &per_token_scale, + const std::optional &antiquant_scale, const bool &transpose_a, + const bool &transpose_b, const int &x_format) { + auto op_name = "GroupedMatmul"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + // Set param for internal kernel + runner->SetParam(transpose_a, transpose_b, bias.has_value()); + runner->SetXFormat(static_cast(x_format)); + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, x, weight, group_list, bias, scale, per_token_scale, antiquant_scale, transpose_a, + transpose_b); + + // if you need infer shape and type, you can use this + std::vector inputs = {x, + weight, + GetTensorOrEmpty(bias), + GetTensorOrEmpty(scale), + group_list, + GetTensorOrEmpty(per_token_scale), + GetTensorOrEmpty(antiquant_scale)}; + auto out_shape = 
InferGroupedMatmulOutputShape(x.shape(), weight.shape(), group_list.shape(), transpose_a, + transpose_b, x.data_type()); + std::vector outputs = {ms::Tensor(TypeId::kNumberTypeFloat16, out_shape)}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} +} // namespace ms_custom_ops + +auto pyboost_grouped_matmul(const ms::Tensor &x, const ms::Tensor &weight, const ms::Tensor &group_list, + const std::optional &bias, const std::optional &scale, + const std::optional &per_token_scale, + const std::optional &antiquant_scale, const bool &transpose_a, + const bool &transpose_b, const int &x_format) { + return ms::pynative::PyboostRunner::Call( + ms_custom_ops::npu_grouped_matmul, x, weight, group_list, bias, scale, per_token_scale, antiquant_scale, + transpose_a, transpose_b, x_format); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("grouped_matmul", &pyboost_grouped_matmul, "GroupedMatmul", pybind11::arg("x"), pybind11::arg("weight"), + pybind11::arg("group_list"), pybind11::arg("bias") = std::nullopt, pybind11::arg("scale") = std::nullopt, + pybind11::arg("per_token_scale") = std::nullopt, pybind11::arg("antiquant_scale") = std::nullopt, + pybind11::arg("transpose_a") = false, pybind11::arg("transpose_b") = false, pybind11::arg("x_format") = 0); +} diff --git a/ops/c_api/grouped_matmul/grouped_matmul_doc.md b/ops/c_api/grouped_matmul/grouped_matmul_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..2dd2bafb8c9d3fd0adcd8abcd8e0df252785f197 --- /dev/null +++ b/ops/c_api/grouped_matmul/grouped_matmul_doc.md @@ -0,0 +1,108 @@ +# grouped_matmul 算子 + +## 描述 + +`grouped_matmul`(分组矩阵乘法)算子针对输入张量 `x` 与权重张量 `weight` 按照分组信息 `group_list` 逐组执行矩阵乘法操作,可选地支持bias(偏置)、scale(缩放因子)、per_token_scale(token级缩放)和antiquant_scale(反量化缩放)等参数。每组输入和权重做独立矩阵乘法,并拼接形成整体输出。该算子适用于高效实现分组全连接、Mixture-of-Experts (MoE) 等需要基于动态分组的场景,并针对Ascend芯片做性能优化。 + +### 计算公式 + +- **非量化场景(FP16)**: + + ```python + y = x * weight[i] + bias[i] + ``` + +- **量化场景(per-channel scale)**: + + ```python + y = (x * weight[i] + bias[i]) * scale[i] + ``` + +- **量化场景(per-token scale)**: + + ```python + y = (x * weight[i] + bias[i]) * scale[i] * per_token_scale + ``` + +其中 `i` 表示第`i`组,`x` 为该组输入片段,`weight[i]`、`bias[i]`、`scale[i]` 为每组权重、偏置、缩放因子,`per_token_scale` 通常 shape 为 [M] 或 [M, 1]。 + +--- + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|-----------------|--------------------------------------------|--------------------------|-------------------------|--------|-------------|--------------------------------------------------------------------------------------------------------------| +| x | Tensor(float16/int8) | [M, K] 或 [K, M] | 否 | 否 | ND/FRACTAL_NZ | 输入特征矩阵,支持 float16、int8。 | +| weight | Tensor(float16/int8) | [E, K, N] 或 [E, N, K]* | 否 | 否 | FRACTAL_NZ | 分组权重张量,支持 float16、int8,E为分组数,必须为NZ格式。 | +| group_list | Tensor(int32) | [E] | 否 | 否 | ND | 分组边界列表,升序,指定每组在x中的范围,必填。 | +| bias | Tensor(float16/int32) | [E, N] | 是,默认None | 否 | ND | 每组偏置项,可选,支持 float16 或 int32。 | +| scale | Tensor(float32/int64/uint64) | [E, N] | 是,默认None | 否 | ND | 每组dequant缩放因子,量化推理时必需,支持 float32、int64 或 uint64。 | +| per_token_scale | Tensor(float32) | [M] 或 [M, 1] | 是,默认None | 否 | ND | 每token缩放系数(动态量化/补偿时可选),仅支持 float32。 | +| antiquant_scale | Tensor(float16/float32) | [E, K, 1] | 是,默认None | 否 | ND | 权重量化反量化缩放因子,int8量化推理配合使用。 | +| transpose_a | bool | 标量 | 是,默认False | 否 | - | x是否转置,默认为False。 | +| transpose_b | bool | 标量 | 是,默认False | 否 | - | weight是否转置,默认为False。 | +| x_format | int 
| 标量 | 是,默认0 | 否 | - | 输入x的数据格式:0表示ND,1表示FRACTAL_NZ。大shape场景下支持将x转换为FRACTAL_NZ格式时高性能推理。 | + +> - x 的 shape[0] 必须等于 group_list[-1](即最终分组后总token数量一致)。 +> - K轴:在fp16输入场景下必须为16的整数倍,在int8输入场景下必须为32的整数倍。 +> - N轴:必须为16的整数倍。 +> - E轴(分组数):必须为8的整数倍. +> - weight 应转换为 FRACTAL_NZ 格式,推荐用 `ms_custom_ops.trans_data(weight, transdata_type=1)`。 +> - bias、scale、antiquant_scale、per_token_scale等参数如不需要可不提供,不同量化方案可据实际场景选用。 +> - 权重量化场景下,需设置transpose_b=True,此时权重weight的shape支持[E, N, K]。 +> - 支持PyNative/Graph模式下使用。 +> - 对于 int4 量化场景,请使用 `grouped_matmul_w4` 算子。 + +--- + +## 输出参数 + +| Name | DType | Shape | Format | Description | +|--------|----------------------------|-----------|--------|----------------------------| +| out | float16 | [M, N] | ND | 输出拼接后的分组MatMul结果 | + +--- + +## 支持平台 + +- 仅支持昇腾Atlas推理系列芯片。 + +--- + +## 常见问题 + +1. group_list设置不当会导致shape mismatch错误,请确保 `sum(group样本数) == M`。 +2. 权重格式须为NZ,否则算子将报错。 +3. 若使用int8量化推理,请正确提供scale, antiquant_scale相关参数与数据类型。 + +--- + +## 适用场景 + +- MoE(Mixture-of-Experts)动态专家分组全连接 +- 稀疏或分组全连接、主干-分支等模型高效推理 +- 多路分组稀疏矩阵乘法(多分支/专家并行) + +--- + +## 使用示例 + +```python +import mindspore as ms +from mindspore import Tensor +import numpy as np +import ms_custom_ops + +# 假设分4组: 总M=8, K=16, N=32, E=4 +M, K, N, E = 8, 16, 32, 4 +x = Tensor(np.random.randn(M, K).astype(np.float16)) +weight = Tensor(np.random.randn(E, K, N).astype(np.float16)) +group_list = Tensor(np.array([2, 4, 6, 8]).astype(np.int32)) # 每组样本数量累加 +bias = Tensor(np.random.randn(E, N).astype(np.float16)) +scale = Tensor(np.ones((E, N)).astype(np.float32)) + +out = ms_custom_ops.grouped_matmul( + x, weight, group_list, bias, scale +) +print(out.shape) # (8, 32) +``` diff --git a/ops/c_api/grouped_matmul/grouped_matmul_op.yaml b/ops/c_api/grouped_matmul/grouped_matmul_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89a92abaf4ab3541c5fce06c2dfd64442e4bcd74 --- /dev/null +++ b/ops/c_api/grouped_matmul/grouped_matmul_op.yaml @@ -0,0 +1,33 @@ +#operator grouped_matmul +grouped_matmul: + args: + x: + dtype: tensor + weight: + dtype: tensor + group_list: + dtype: tensor + bias: + dtype: tensor + default: None + scale: + dtype: tensor + default: None + per_token_scale: + dtype: tensor + default: None + antiquant_scale: + dtype: tensor + default: None + transpose_a: + dtype: bool + default: False + transpose_b: + dtype: bool + default: False + x_format: + dtype: int + default: 0 # 0: ND, 1: FRACTAL_NZ + returns: + out: + dtype: tensor diff --git a/ops/c_api/grouped_matmul_w4/grouped_matmul_w4.cc b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9ee1d9ec113a4a109193554b8bd392d9e07f5cf --- /dev/null +++ b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4.cc @@ -0,0 +1,256 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include +#include +#include +#include +#include +#include + +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { +enum class GroupedMatmulW4InputIndex : size_t { + kGmmW4XIndex = 0, + kGmmW4WeightIndex, + kGmmW4GroupListIndex, + kGmmW4BiasIndex, + kGmmW4XScaleIndex, + kGmmW4WeightScaleIndex, + kGmmW4InputsNum, + kGmmW4UnusedScaleIndex, +}; +enum class GroupedMatmulW4OutputIndex : size_t { + kGmmW4OutputIndex = 0, + kGmmW4OutputsNum, +}; + +// ============================================================================= +// UTILITY FUNCTIONS +// ============================================================================= +ShapeVector InferGroupedMatmulW4OutputShape(const ShapeVector &x_shape, const ShapeVector &weight_shape, + const ShapeVector &group_list_shape) { + auto x_shape_size = x_shape.size(); + auto weight_shape_size = weight_shape.size(); + auto group_list_shape_size = group_list_shape.size(); + // check input rank + if (MS_UNLIKELY(x_shape_size != kDim2)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'x' must be 2D, but got:" << x_shape_size; + } + if (MS_UNLIKELY(weight_shape_size != kDim3)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'weight' must be 3D, but got:" << weight_shape_size; + } + if (MS_UNLIKELY(group_list_shape_size != kDim1)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'group_list' must be 1D, but got:" << group_list_shape_size; + } + // get input dimensions + auto x_m_dim = x_shape[kIndex0]; + auto x_k_dim = x_shape[kIndex1]; + auto weight_n_dim = weight_shape[kIndex1]; + auto weight_k_dim = weight_shape[kIndex2]; + + // infer output shape when weight_n_dim is dynamic + if (weight_n_dim == abstract::Shape::kShapeDimAny) { + return ShapeVector{x_m_dim, weight_n_dim}; + } + + // For qint4x2 weight type, adjust the logical dimensions based on storage layout + auto weight_real_k_dim = weight_k_dim; + // Stored as [e, n, k/2], so restore k dimension + weight_real_k_dim <<= 1; + + if (MS_UNLIKELY(x_k_dim != abstract::Shape::kShapeDimAny && weight_k_dim != abstract::Shape::kShapeDimAny && + x_k_dim != weight_real_k_dim)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'x' and 'weight' must have the same dimension of 'k', but got:" + << x_k_dim << " and " << weight_real_k_dim; + } + + return ShapeVector{x_m_dim, weight_n_dim}; +} + +class OPS_API CustomGroupedMatmulW4OpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + // dynamic rank + if (MS_UNLIKELY(input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4XIndex)]->IsDynamicRank() || + input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex)]->IsDynamicRank())) { + auto output_shape = ShapeVector{abstract::Shape::kShapeRankAny}; + return {output_shape}; + } + // check input size + if (MS_UNLIKELY(input_infos.size() != static_cast(GroupedMatmulW4InputIndex::kGmmW4InputsNum))) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input size must be " << GroupedMatmulW4InputIndex::kGmmW4InputsNum + << ", but got " << input_infos.size(); + } + + // get input shapes + auto x_shape = input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4XIndex)]->GetShape(); + auto weight_shape = 
input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex)]->GetShape(); + auto group_list_shape = + input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4GroupListIndex)]->GetShape(); + auto output_shape = InferGroupedMatmulW4OutputShape(x_shape, weight_shape, group_list_shape); + return {output_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override { + MS_EXCEPTION_IF_NULL(primitive); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + const auto &device_target = ms_context->ascend_soc_version(); + if (device_target != kAscendVersion310p) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, only support Ascend device, but got:" << device_target; + } + auto x_dtype = input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4XIndex)]->GetType(); + auto weight_dtype = input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex)]->GetType(); + auto group_list_dtype = + input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4GroupListIndex)]->GetType(); + auto bias_dtype = input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4BiasIndex)]->GetType(); + auto x_scale_dtype = input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4XScaleIndex)]->GetType(); + auto weight_scale_dtype = + input_infos[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightScaleIndex)]->GetType(); + // check input types + if (MS_UNLIKELY(x_dtype != kNumberTypeInt8)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'x' must be int8, but got:" << x_dtype; + } + if (MS_UNLIKELY(weight_dtype != kNumberTypeInt4)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'weight' must be qint4x2, but got:" << weight_dtype; + } + if (MS_UNLIKELY(group_list_dtype != kNumberTypeInt32)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'group_list' must be int32, but got:" << group_list_dtype; + } + if (MS_UNLIKELY(bias_dtype != kNumberTypeFloat32)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'bias' must be float32, but got:" << bias_dtype; + } + if (MS_UNLIKELY(x_scale_dtype != kNumberTypeFloat32)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'x_scale' must be float32, but got:" << x_scale_dtype; + } + if (MS_UNLIKELY(weight_scale_dtype != kNumberTypeFloat32)) { + MS_LOG(EXCEPTION) << "For GroupedMatmulW4, input 'weight_scale' must be float32, but got:" << weight_scale_dtype; + } + // infer output types + auto output_dtype = kNumberTypeFloat16; + return {output_dtype}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class CustomGroupedMatmulW4 : public InternalKernelMod { + public: + CustomGroupedMatmulW4() : InternalKernelMod() {} + ~CustomGroupedMatmulW4() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = {static_cast(GroupedMatmulW4InputIndex::kGmmW4XIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4UnusedScaleIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4BiasIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4GroupListIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4XScaleIndex), + static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightScaleIndex)}; + kernel_outputs_index_ = {static_cast(GroupedMatmulW4OutputIndex::kGmmW4OutputIndex)}; + } + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const 
std::vector &ms_outputs) override { + internal_v2::MatmulParam param; + param.transpose_a = false; + param.transpose_b = true; // For w4, weight is always transposed + param.with_bias = true; + param.enable_shuffle = false; // the real definition is in internal + auto inputs_clone = inputs; + inputs_clone[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + return internal_v2::CreateGroupedMatmulOp(inputs_clone, outputs, param, internal_v2::kInternalGroupedMatmulOpName); + } +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(grouped_matmul_w4, ms_custom_ops::CustomGroupedMatmulW4OpFuncImpl, + ms_custom_ops::CustomGroupedMatmulW4); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +class GroupedMatmulW4Runner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + + void SetParam() { + param_.transpose_a = false; + param_.transpose_b = true; // For w4, weight is always transposed + param_.with_bias = true; + param_.enable_shuffle = false; // the real definition is in internal + } + + protected: + internal_v2::InternalOpPtr CreateKernel(const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs) override { + auto inputs_clone = inputs; + inputs_clone[static_cast(GroupedMatmulW4InputIndex::kGmmW4WeightIndex)].SetFormat( + internal_v2::TensorFormat::kFormatFRACTAL_NZ); + return internal_v2::CreateGroupedMatmulOp(inputs_clone, outputs, param_, internal_v2::kInternalGroupedMatmulOpName); + } + + private: + internal_v2::MatmulParam param_; +}; + +std::vector npu_grouped_matmul_w4(const ms::Tensor &x, const ms::Tensor &weight, + const ms::Tensor &group_list, const ms::Tensor &bias, + const ms::Tensor &x_scale, const ms::Tensor &weight_scale) { + auto op_name = "GroupedMatmulW4"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + // Set param for internal kernel + runner->SetParam(); + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, x, weight, group_list, bias, x_scale, weight_scale); + + // if you need infer shape and type, you can use this + std::vector inputs = {x, weight, ms::Tensor(), bias, group_list, x_scale, weight_scale}; + auto out_shape = InferGroupedMatmulW4OutputShape(x.shape(), weight.shape(), group_list.shape()); + std::vector outputs = {ms::Tensor(TypeId::kNumberTypeFloat16, out_shape)}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} +} // namespace ms_custom_ops + +auto pyboost_grouped_matmul_w4(const ms::Tensor &x, const ms::Tensor &weight, const ms::Tensor &group_list, + const ms::Tensor &bias, const ms::Tensor &x_scale, const ms::Tensor &weight_scale) { + return ms::pynative::PyboostRunner::Call(ms_custom_ops::npu_grouped_matmul_w4, x, weight, + group_list, bias, x_scale, weight_scale); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("grouped_matmul_w4", &pyboost_grouped_matmul_w4, "GroupedMatmulW4", pybind11::arg("x"), pybind11::arg("weight"), + pybind11::arg("group_list"), pybind11::arg("bias"), pybind11::arg("x_scale"), pybind11::arg("weight_scale")); +} diff --git a/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_doc.md 
b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..46ac113869d9b5fdf896ab3d1899ce3cf5bfe5c6 --- /dev/null +++ b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_doc.md @@ -0,0 +1,101 @@ +# grouped_matmul_w4 算子 + +## 描述 + +`grouped_matmul_w4`(分组矩阵乘法,int4量化版本)算子针对输入张量 `x` 与权重张量 `weight`(int4量化)按照分组信息 `group_list` 逐组执行矩阵乘法操作,支持bias(偏置)、x_scale(输入缩放因子)和weight_scale(权重量化缩放因子)等参数。每组输入和权重做独立矩阵乘法,并拼接形成整体输出。该算子适用于高效实现分组全连接、Mixture-of-Experts (MoE) 等需要基于动态分组的场景,并针对Ascend芯片做性能优化。 + +### 计算公式 + +- **int4量化场景**: + + ```python + y = (x * weight[i] + bias[i]) * x_scale * weight_scale[i] + ``` + +其中 `i` 表示第`i`组,`x` 为该组输入片段,`weight[i]`、`bias[i]`、`weight_scale[i]` 为每组权重、偏置、权重量化缩放因子,`x_scale` 为输入缩放因子。 + +--- + +## 输入参数 + +| Name | DType | Shape | Optional | Format | Description | +|--------------|--------------------------|--------------------------|-------------------------|-------------|--------------------------------------------------------------------------------------------------------------| +| x | Tensor(int8) | [M, K] | No | ND | 输入特征矩阵,支持 int8。 | +| weight | Tensor(qint4x2) | [E, N, K/2] | No | FRACTAL_NZ | 分组权重张量,支持 qint4x2(int4量化),E为分组数,必须为NZ格式。 | +| group_list | Tensor(int32) | [E] | No | ND | 分组边界列表,升序,指定每组在x中的范围,必填。 | +| bias | Tensor(float32) | [E, N] | No | ND | 每组偏置项,仅支持 float32。 | +| x_scale | Tensor(float32) | [M] 或 [M, 1] | No | ND | 输入缩放因子,仅支持 float32。 | +| weight_scale | Tensor(float32) | [E, K//g, N] 或类似 | No | ND | 权重量化缩放因子,int4量化推理配合使用,仅支持 float32。 | + +> - x 的 shape[0] 必须等于 group_list[-1](即最终分组后总token数量一致)。 +> - M轴:支持动态shape。 +> - K、N轴:仅支持以下组合:[256, 7168]、[512, 7168]、[7168, 512]、[7168, 1024]。 +> - E轴(分组数):仅支持 256。 +> - g值(权重量化分组大小):仅支持 256。 +> - weight 为 int4 量化格式(qint4x2),存储为 [E, N, K/2] 形状,其中 K/2 是因为两个4bit值打包到一个8bit中。 +> - weight 必须转换为FRACTAL_NZ格式,推荐用 `ms_custom_ops.trans_data(weight, transdata_type=1)`。 +> - bias、x_scale、weight_scale 为必需参数,必须提供。 +> - 支持PyNative/Graph模式下使用。 +> - 对于 float16 和 int8 量化场景,请使用 `grouped_matmul` 算子。 + +--- + +## 输出参数 + +| 名称 | DType | Shape | Format | 说明 | +|--------|----------|--------|--------|----------------------------| +| out | float16 | [M, N] | ND | 输出拼接后的分组MatMul结果 | + +--- + +## 支持平台 + +- 仅支持昇腾Atlas推理系列芯片。 + +--- + +## 常见问题 + +1. group_list设置不当会导致shape mismatch错误,请确保 `sum(group样本数) == M`。 +2. 权重格式须为NZ,否则算子将报错。 +3. 
若使用int4量化推理,请正确提供bias、x_scale、weight_scale相关参数与数据类型。 + +--- + +## 适用场景 + +- MoE(Mixture-of-Experts)动态专家分组全连接(int4量化) +- 稀疏或分组全连接、主干-分支等模型高效推理(int4量化) +- 多路分组稀疏矩阵乘法(多分支/专家并行,int4量化) + +--- + +## 使用示例 + +```python +import mindspore as ms +from mindspore import Tensor +import numpy as np +import ms_custom_ops + +# 假设分4组: 总M=8, K=16, N=32, E=4 +M, K, N, E = 8, 16, 32, 4 +x = Tensor(np.random.randint(-128, 127, (M, K)).astype(np.int8)) +# weight 为 int4 量化格式,需要转换为 qint4x2 +weight = Tensor(np.random.randint(0, 255, (E, N, K // 2)).astype(np.uint8)) +weight = ms_custom_ops.type_cast(weight, ms.qint4x2) +group_list = Tensor(np.array([2, 4, 6, 8]).astype(np.int32)) +bias = Tensor(np.random.randn(E, N).astype(np.float16)) +x_scale = Tensor(np.random.randn(M).astype(np.float32)) +weight_scale = Tensor(np.random.randn(E, K // 256, N).astype(np.float16)) + +# weight 必须转换为 NZ 格式 +w_i8 = ms_custom_ops.type_cast(weight, ms.int8) +w_i8_nz = ms_custom_ops.trans_data(w_i8, transdata_type=1) +weight_nz = ms_custom_ops.type_cast(w_i8_nz, ms.qint4x2) + +out = ms_custom_ops.grouped_matmul_w4( + x, weight_nz, group_list, bias, x_scale, weight_scale +) +print(out.shape) # (8, 32) +``` diff --git a/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_op.yaml b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b97a56aaebb068ce5194ec4bef2f17f98767ea4d --- /dev/null +++ b/ops/c_api/grouped_matmul_w4/grouped_matmul_w4_op.yaml @@ -0,0 +1,18 @@ +#operator grouped_matmul_w4 +grouped_matmul_w4: + args: + x: + dtype: tensor + weight: + dtype: tensor + group_list: + dtype: tensor + bias: + dtype: tensor + x_scale: + dtype: tensor + weight_scale: + dtype: tensor + returns: + out: + dtype: tensor diff --git a/ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.cc b/ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.cc index b7dbddecdae9e917a9bdd48b11fd03d2a19bce2b..dc6fc61129ed380c9ffe2ecdc2f36ab81a0bd86d 100644 --- a/ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.cc +++ b/ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.cc @@ -44,14 +44,7 @@ bool InternalKernelMod::Init(const std::vector &inputs, const st } for (size_t i = 0; i < inputs.size(); i++) { - bool is_include = false; - for (auto idx : kernel_inputs_index_) { - if (i == idx) { - is_include = true; - break; - } - } - if (!is_include) { + if (!(std::any_of(kernel_inputs_index_.begin(), kernel_inputs_index_.end(), [=](auto idx) { return idx == i; }))) { recreate_cared_indices_.emplace_back(i); } } @@ -197,12 +190,20 @@ void InternalKernelMod::GetInternalKernel(const std::vector &inp internal_v2::InputsImmutableInfoList inputs_ii; internal_v2::OutputsImmutableInfoList outputs_ii; for (auto i : kernel_inputs_index_) { + if (i >= inputs.size()) { + inputs_ii.emplace_back(internal_v2::DataType::kTypeNone, internal_v2::TensorFormat::kFormatND); + continue; + } auto dtype = TransInternalDataType(inputs[i]->dtype_id()); auto format = TransInternalFormat(inputs[i]->format()); inputs_ii.emplace_back(dtype, format); } for (auto i : kernel_outputs_index_) { + if (i >= outputs.size()) { + outputs_ii.emplace_back(internal_v2::DataType::kTypeNone, internal_v2::TensorFormat::kFormatND); + continue; + } auto dtype = TransInternalDataType(outputs[i]->dtype_id()); auto format = TransInternalFormat(outputs[i]->format()); outputs_ii.emplace_back(dtype, format); @@ -245,6 +246,10 @@ int InternalKernelMod::Resize(const std::vector &inputs, const s size_t idx = 0; for 
(auto i : kernel_inputs_index_) { + if (i >= inputs.size()) { + internal_inputs_shape_[idx++] = std::move(internal_v2::ShapeInfo{}); + continue; + } auto shape = TransInternalShape(inputs[i]->GetShapeVector()); if (inputs[i]->dtype_id() == kMetaTypeNone) { shape = {}; @@ -254,6 +259,10 @@ int InternalKernelMod::Resize(const std::vector &inputs, const s idx = 0; for (auto i : kernel_outputs_index_) { + if (i >= outputs.size()) { + internal_outputs_shape_[idx++] = std::move(internal_v2::ShapeInfo{}); + continue; + } auto shape = TransInternalShape(outputs[i]->GetShapeVector()); if (outputs[i]->dtype_id() == kMetaTypeNone) { shape = {}; @@ -279,10 +288,18 @@ void InternalKernelMod::UpdateAddr(const std::vector &inputs, const std::vector &workspace) { size_t idx = 0; for (auto i : kernel_inputs_index_) { + if (i >= inputs.size()) { + internal_inputs_addr_[idx++] = nullptr; + continue; + } internal_inputs_addr_[idx++] = inputs[i]->device_ptr(); } idx = 0; for (auto i : kernel_outputs_index_) { + if (i >= outputs.size()) { + internal_outputs_addr_[idx++] = nullptr; + continue; + } internal_outputs_addr_[idx++] = outputs[i]->device_ptr(); } diff --git a/ops/framework/ms_kernels_internal/internal_helper.cc b/ops/framework/ms_kernels_internal/internal_helper.cc index a161baf3496032d25c43ffa2954238fcd1d20afb..716917e47b05a3c9c13a02184422b237f4a9990a 100644 --- a/ops/framework/ms_kernels_internal/internal_helper.cc +++ b/ops/framework/ms_kernels_internal/internal_helper.cc @@ -34,6 +34,7 @@ internal_v2::DataType TransInternalDataType(TypeId ms_type) { {kNumberTypeInt8, internal_v2::DataType::kTypeInt8}, {kNumberTypeUInt8, internal_v2::DataType::kTypeUint8}, {kNumberTypeInt64, internal_v2::DataType::kTypeInt64}, + {kNumberTypeInt4, internal_v2::DataType::kTypeInt4}, {kNumberTypeUInt64, internal_v2::DataType::kTypeUint64}, {kNumberTypeComplex64, internal_v2::DataType::kTypeComplex64}, {kNumberTypeComplex128, internal_v2::DataType::kTypeComplex128}, diff --git a/tests/st/test_grouped_matmul.py b/tests/st/test_grouped_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..21eb3e22e5d53a3385d7325b5021070bebb93a31 --- /dev/null +++ b/tests/st/test_grouped_matmul.py @@ -0,0 +1,345 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
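+# ST cases for ms_custom_ops.grouped_matmul on the Atlas inference series (310P): each case
+# builds a per-group NumPy reference (float16 and int8-quantized paths, with/without bias and
+# per_token_scale) and compares it against the fused kernel output via custom_compare.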
+import os +import random +import numpy as np +import pytest +from functools import wraps + +from st_utils import custom_compare + +import mindspore as ms +from mindspore import context, Tensor, ops, Profiler +from mindspore.nn import Cell +import ms_custom_ops + +np.set_printoptions(precision=2, suppress=True, linewidth=200) + + +def jit(func): + @wraps(func) + def decorator(*args, **kwargs): + if context.get_context("mode") == context.PYNATIVE_MODE: + return func(*args, **kwargs) + return ms.jit(func, jit_level="O0", infer_boost="on")(*args, **kwargs) + return decorator + + +def process_deq_scale(deq_scale) -> np.ndarray: + new_deq_scale = np.frombuffer(deq_scale.tobytes(), dtype=np.uint32) + return new_deq_scale.astype(np.int64) + + +def split_x(x, group_list): + x_split = [] + for i in range(len(group_list)): + if i == 0: + x_split.append(x[0: group_list[i],]) + else: + x_split.append(x[group_list[i - 1]: group_list[i],]) + return x_split + + +def split_w(w): + tmp_split = np.split(w, w.shape[0], axis=0) + w_split = [] + for t in tmp_split: + w_split.append(np.squeeze(t, 0)) + return w_split + + +def np_qbmm_compute(a, b, tmp_scale, bias=None, tmp_pertoken_scale=None): + c = np.dot(a.astype(np.float32), b.astype(np.float32)).astype(np.int32) + if bias is not None: + c = c + bias + c = c.astype(np.float32) * tmp_scale + if tmp_pertoken_scale is not None: + per_token_scale = tmp_pertoken_scale[:, np.newaxis] + c = c * per_token_scale + c = c.astype(np.float16) + return c + + +class Net(Cell): + def __init__(self): + super().__init__() + + @jit + def construct(self, x, weight, group_list, bias=None, scale=None, per_token_scale=None, antiquant_scale=None, transpose_a=False, transpose_b=False): + out = ms_custom_ops.grouped_matmul( + x, weight, group_list, bias, scale, per_token_scale, antiquant_scale, transpose_a, transpose_b) + return out + + +def custom_grouped_matmul(m, k, n, e, trans_a=False, trans_b=False, profiling=False, with_bias=True, with_pertoken_scale=False, scale_dtype=ms.int64): + os.environ['INTERNAL_PRINT_TILING'] = "on" + + # numpy calculate + np_x_all = np.random.uniform(-20, 20, size=[m, k]).astype(np.int8) + np_w_all = np.random.uniform(-20, 20, size=[e, k, n]).astype(np.int8) + np_b_all = np.random.randint(-10, 10, (e, n)).astype(np.int32) + np_s_all = np.random.rand(e, n).astype(np.float32) / 1000 + scale = process_deq_scale(np_s_all) + np_ps_all = np.random.rand(m).astype(np.float32) + group_list_np = np.array(generate_random_numbers(m, e)).astype(np.int32) + + def compute_numpy_result(np_x_all, np_w_all, np_b_all, np_s_all, np_ps_all, group_list_np, with_bias, with_pertoken_scale): + # use group_list split x. [(G0, k), (G1, k)....(GN, k)] + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) # [(k, n), (k, n)....(k, n)] + np_b = split_w(np_b_all) # [(n), (n)....(n)] + np_s = split_w(np_s_all) # [(n), (n)....(n)] + # use group_list split per_token_scale. 
[(G0,), (G1,)....(GN,)] + np_ps = split_x(np_ps_all, group_list_np) + + if not with_bias and not with_pertoken_scale: + res_np = [np_qbmm_compute(x0, w0, s0) + for x0, w0, s0 in zip(np_x, np_w, np_s)] + elif with_bias and not with_pertoken_scale: + res_np = [np_qbmm_compute(x0, w0, s0, b0) + for x0, w0, s0, b0 in zip(np_x, np_w, np_s, np_b)] + if not with_bias and with_pertoken_scale: + res_np = [np_qbmm_compute(x0, w0, s0, None, ps0) + for x0, w0, s0, ps0 in zip(np_x, np_w, np_s, np_ps)] + elif with_bias and with_pertoken_scale: + res_np = [np_qbmm_compute(x0, w0, s0, b0, ps0) for x0, w0, s0, b0, ps0 in zip( + np_x, np_w, np_s, np_b, np_ps)] + + expect_np = np.concatenate(res_np, axis=0) + return expect_np + + expect_np = compute_numpy_result( + np_x_all, np_w_all, np_b_all, np_s_all, np_ps_all, group_list_np, with_bias, with_pertoken_scale) + + # ms calculate + if trans_b: + np_w_all = np.transpose(np_w_all, (0, 2, 1)) + x = Tensor(np_x_all) # [m, k] + w = Tensor(np_w_all) # [e, k, n] + + # weight must be NZ format so do transdata before + w_nz = ms_custom_ops.trans_data(w, transdata_type=1) + weight = ms.Parameter(w_nz, requires_grad=False) + + if scale_dtype == ms.float32: + s = Tensor(np_s_all, ms.float32) # [e, n] + elif scale_dtype == ms.int64: + s = Tensor(scale, ms.int64) # [e, n] + else: + raise ValueError( + f"scale_dtype must be float32 or int64, but got {scale_dtype}") + scale = ms.Parameter(s, requires_grad=False) if s is not None else None + + ps = ms.Parameter(Tensor(np_ps_all, ms.float32), + requires_grad=False) if with_pertoken_scale else None + b = ms.Parameter(Tensor(np_b_all, ms.int32), + requires_grad=False) if with_bias else None + + group_list = Tensor(group_list_np, dtype=ms.int32) + gmm_net = Net() + + for _ in range(50 if profiling else 1): + output = gmm_net(x, weight, group_list, b, scale, ps, + transpose_a=trans_a, transpose_b=trans_b) + if profiling: + return + res = custom_compare(output.astype( + ms.float32).asnumpy(), expect_np, ms.float16) + assert res, "matmul compare fail." 
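+# Builds an ascending group_list of length e for m tokens: e-1 random split points in [1, m]
+# plus m itself, sorted, so that group_list[-1] == m as the operator requires.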
+ + +def generate_random_numbers(m, e): + # 生成e-1个互不相同的随机数,范围是1到n,但不包括m + random_numbers = random.choices([i for i in range(1, m+1)], k=e-1) + # 将m添加到列表的末尾 + random_numbers.append(m) + # 将列表从小到大排序 + random_numbers.sort() + return random_numbers + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('m', [32, 1000]) +@pytest.mark.parametrize('with_pertoken_scale', [True, False]) +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_gmm_v4_with_scale(m, with_pertoken_scale, exec_mode): + """ + Feature: test matmul operator in graph mode or pynative mode + Description: test matmul with per_token_scale and without per_token_scale + Expectation: the result is correct + """ + ms.set_device("Ascend") + ms.set_context(mode=exec_mode) + custom_grouped_matmul(m, 64, 128, 8, trans_b=True, + with_pertoken_scale=with_pertoken_scale, scale_dtype=ms.float32) + + +@pytest.mark.level1 +@pytest.mark.platform_ascend310p +@pytest.mark.env_onecard +@pytest.mark.parametrize('inputs_shape', [[40, 512, 7168], [1024, 2048, 7168], [1024, 7168, 4096]]) +def test_gmm_v4_with_scale_ds(inputs_shape): + """ + Feature: test grouped matmul with scale + Description: test grouped matmul with scale + Expectation: the result is correct + """ + ms.set_device("Ascend") + ms.set_context(mode=context.GRAPH_MODE) + e = 256 + custom_grouped_matmul(*inputs_shape, e, trans_b=True, + with_pertoken_scale=True, scale_dtype=ms.float32) + + +def grouped_quant_matmul(m, k, n, e, trans_a=False, trans_b=False, profiling=False, with_bias=False): + os.environ['INTERNAL_PRINT_TILING'] = "on" + + # numpy calculate + np_x_all = np.random.uniform(-20, 20, size=[m, k]).astype(np.int8) + np_w_all = np.random.uniform(-20, 20, size=[e, k, n]).astype(np.int8) + np_b_all = np.random.randint(-10, 10, (e, n)).astype(np.int32) + np_s_all = np.random.rand(e, n).astype(np.float32) / 1000 + + scale = process_deq_scale(np_s_all) + group_list_np = np.array(generate_random_numbers(m, e)).astype(np.int32) + + # use group_list split x. [(G0, n), (G1, n)....(GN, n)] + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) # [(k, n), (k, n)....(k, n)] + np_b = split_w(np_b_all) # [(n), (n)....(n)] + np_s = split_w(np_s_all) # [(n), (n)....(n)] + if not with_bias: + res_np = [np_qbmm_compute(x0, w0, s0) + for x0, w0, s0 in zip(np_x, np_w, np_s)] + else: + res_np = [np_qbmm_compute(x0, w0, s0, b0) + for x0, w0, s0, b0 in zip(np_x, np_w, np_s, np_b)] + expect_np = np.concatenate(res_np, axis=0) + + # ms calculate + if trans_b: + np_w_all = np.transpose(np_w_all, (0, 2, 1)) + x = ms.Tensor(np_x_all) # [m, k] + w = ms.Tensor(np_w_all) # [e, k, n] + s = ms.Tensor(scale, ms.int64) # [e, n] + + # [e, n] + b = ms.Parameter(Tensor(np_b_all, ms.int32), + requires_grad=False) if with_bias else None + + group_list = Tensor(group_list_np, dtype=ms.int32) + w_nz = ms_custom_ops.trans_data(w, transdata_type=1) + weight = ms.Parameter(w_nz, requires_grad=False) + + gmm_net = Net() + + for _ in range(50 if profiling else 1): + output = gmm_net(x, weight, group_list, bias=b, scale=s, per_token_scale=None, antiquant_scale=None, + transpose_a=trans_a, transpose_b=trans_b) + if profiling: + return + res = custom_compare(output.astype( + ms.float32).asnumpy(), expect_np, ms.float16) + assert res, "matmul compare fail." 
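+# Non-quantized float16 path: a per-group np.matmul reference is compared against the fused
+# kernel, with the weight converted to FRACTAL_NZ via ms_custom_ops.trans_data(transdata_type=1).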
+ + +def grouped_matmul(m, k, n, e, trans_a=False, trans_b=False, profiling=False): + os.environ['INTERNAL_PRINT_TILING'] = "on" + + # numpy calculate + np_x_all = np.random.uniform(0.1, 2, size=[m, k]).astype(np.float16) + np_w_all = np.random.uniform(0.1, 1, size=[e, k, n]).astype(np.float16) + group_list_np = np.array(generate_random_numbers(m, e)).astype(np.int32) + + # use group_list split x. [(G0, n), (G1, n)....(GN, n)] + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) # [(k, n), (k, n)....(k, n)] + res_np = [np.matmul(x0, w0) for x0, w0 in zip(np_x, np_w)] + expect_np = np.concatenate(res_np, axis=0) + + # ms calculate + if trans_b: + np_w_all = np.transpose(np_w_all, (0, 2, 1)) + x = ms.Tensor(np_x_all) # [m, k] + w = ms.Tensor(np_w_all) # [e, k, n] + + group_list = ms.Tensor(group_list_np, dtype=ms.int32) + w_nz = ms_custom_ops.trans_data(w, transdata_type=1) + weight = ms.Parameter(w_nz, requires_grad=False) + gmm_net = Net() + for _ in range(50 if profiling else 1): + output = gmm_net(x, weight, group_list, bias=None, scale=None, per_token_scale=None, antiquant_scale=None, + transpose_a=trans_a, transpose_b=trans_b) + if profiling: + return + res = custom_compare(output.astype( + ms.float32).asnumpy(), expect_np, ms.float16) + assert res, "matmul compare fail." + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('input_shape', [[32, 32, 64, 8], [1000, 256, 512, 16]]) +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_gmm_increment(input_shape, exec_mode): + """ + Feature: test matmul operator in graph mode + Description: test matmul. + Expectation: the result is correct + """ + ms.set_device("Ascend") + ms.set_context(mode=exec_mode) + grouped_matmul(*input_shape, trans_b=True) + grouped_quant_matmul(*input_shape, trans_b=True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('input_shape', [[32, 32, 64, 8], [1000, 256, 512, 16]]) +@pytest.mark.env_onecard +def test_gmm_quant_with_bias(input_shape): + """ + Feature: test matmul operator in graph mode + Description: test matmul. + Expectation: the result is correct + """ + ms.set_device("Ascend") + ms.set_context(mode=context.GRAPH_MODE) + grouped_quant_matmul(*input_shape, trans_b=True, with_bias=True) + + +@pytest.mark.level1 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('m', [1024]) +@pytest.mark.parametrize('e', [8]) +@pytest.mark.parametrize("prof_flag_str", [0]) +@pytest.mark.env_onecard +def test_moe_real_shape(m, e, prof_flag_str): + """ + Feature: test matmul operator in graph mode + Description: test matmul. 
+    Expectation: the result is correct
+    """
+    ms.set_device("Ascend")
+    ms.set_context(mode=context.GRAPH_MODE)
+    prof_flag = bool(int(prof_flag_str))
+    profiler = Profiler(start_profile=False, output_path="profiler")
+    profiler.start()
+    grouped_matmul(m, 5120, 2688, e, profiling=prof_flag, trans_b=True)
+    grouped_matmul(m, 5120, 1344, e, profiling=prof_flag, trans_b=True)
+    profiler.stop()
+    profiler.analyse()
diff --git a/tests/st/test_grouped_matmul_w4.py b/tests/st/test_grouped_matmul_w4.py
new file mode 100644
index 0000000000000000000000000000000000000000..955e89bf924000deb3764280e09f43226a926e92
--- /dev/null
+++ b/tests/st/test_grouped_matmul_w4.py
@@ -0,0 +1,173 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import random
+import numpy as np
+import pytest
+from functools import wraps
+
+from st_utils import custom_compare
+
+import mindspore as ms
+from mindspore import context, Tensor
+from mindspore.nn import Cell
+import ms_custom_ops
+
+np.set_printoptions(precision=2, suppress=True, linewidth=200)
+
+
+def jit(func):  # JIT-compile in graph mode, run eagerly in pynative mode
+    @wraps(func)
+    def decorator(*args, **kwargs):
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            return func(*args, **kwargs)
+        return ms.jit(func, jit_level="O0", infer_boost="on")(*args, **kwargs)
+    return decorator
+
+
+def split_x(x, group_list):  # split rows of x at the cumulative group_list boundaries
+    x_split = []
+    for i in range(len(group_list)):
+        if i == 0:
+            x_split.append(x[0: group_list[i],])
+        else:
+            x_split.append(x[group_list[i - 1]: group_list[i],])
+    return x_split
+
+
+def split_w(w):  # split the stacked per-expert weights [e, ...] into a list of e arrays
+    tmp_split = np.split(w, w.shape[0], axis=0)
+    w_split = []
+    for t in tmp_split:
+        w_split.append(np.squeeze(t, 0))
+    return w_split
+
+
+def generate_random_numbers(m, e):
+    # sample e-1 random values (with replacement) from the range [1, m]
+    random_numbers = random.choices([i for i in range(1, m+1)], k=e-1)
+    # append m so the cumulative group_list ends at the last row
+    random_numbers.append(m)
+    # sort in ascending order
+    random_numbers.sort()
+    return random_numbers
+
+
+class Net(Cell):
+    def __init__(self):
+        super().__init__()
+
+    @jit
+    def construct(self, x, weight, group_list, bias, x_scale, weight_scale):
+        out = ms_custom_ops.grouped_matmul_w4(
+            x, weight, group_list, bias, x_scale, weight_scale)
+        return out
+
+
+def grouped_matmul_int4_swft(m, k, n, e):
+    os.environ['INTERNAL_PRINT_TILING'] = "on"
+    group_list_np = np.array(generate_random_numbers(m, e)).astype(np.int32)
+    g = 256
+
+    def dyn_quant(x_fp16):  # per-token symmetric int8 quantization of the activations
+        x_abs = np.abs(x_fp16)
+        x_max = np.max(x_abs, axis=-1, keepdims=True)
+        anti_scale = x_max.astype(np.float32) / 127.0
+        x_int8 = np.round(x_fp16.astype(np.float32) /
+                          anti_scale).astype(np.int8)
+        return x_int8, anti_scale
+
+    def quant(y_fp16):  # per-group (group size g along k) symmetric int4 quantization of the weights
+        y_fp = y_fp16.reshape((e, k // g, g, n))
+        y_max = np.max(np.abs(y_fp), keepdims=True, axis=-2)
+        scale = (y_max.astype(np.float32) / 7.0)
+        y_int8 = np.round(y_fp.astype(np.float32) /
+                          scale).astype(np.int8).reshape((e, k, n))
+        return y_int8, scale
+
+    np_x_fp = np.random.uniform(-0.3, 0.3, [m, k]).astype(np.float16)
+    np_x_all, np_x_scale = dyn_quant(np_x_fp)
+    np_w_fp = np.random.uniform(-0.3, 0.3, [e, k, n]).astype(np.float16)
+    np_w_all, np_y_scale = quant(np_w_fp)
+    bias = np.ones([e, 1, k]).astype(np.float16) * 8
+    np_w_fp16 = np_w_all.reshape(e, k//g, g, n).astype(np.float32) * np_y_scale  # dequantized reference weights
+    np_w_fp16 = np_w_fp16.reshape(e, k, n)
+    np_bias = np.matmul(bias, np_w_fp16).astype(np.float32)
+
+    np_x = split_x(np_x_all, group_list_np)
+    np_w = split_w(np_w_fp16)
+    np_p = split_x(np_x_scale, group_list_np)
+    res_np = [(np.matmul(x0.astype(np.float16), w0) * p0)
+              for x0, w0, p0 in zip(np_x, np_w, np_p)]
+    expect_np = np.concatenate(res_np, axis=0)
+
+    def i8toi4(y_int8):  # pack two int4 values (two's-complement nibbles) into one uint8, low nibble first
+        input_x = ((y_int8 + 16) % 16).astype(np.uint8).reshape(-1)
+        input_y = (input_x[1::2] << 4) | input_x[::2]
+        return input_y
+    np_w_all_int4 = i8toi4(np_w_all.transpose(0, 2, 1)).reshape(e, n, k // 2)
+
+    x = Tensor(np_x_all)
+    w = Tensor(np_w_all_int4, dtype=ms.qint4x2)
+    weight_scale = Tensor(np_y_scale.reshape(e, k // g, n))
+    bias_tensor = Tensor(np_bias.reshape(e, n))
+    x_scale = Tensor(np_x_scale.reshape(m,))
+
+    group_list = Tensor(group_list_np, dtype=ms.int32)
+
+    # weight must be in NZ format, so cast to int8, apply trans_data, then cast back to qint4x2
+    w_i8 = ms_custom_ops.type_cast(w, ms.int8)
+    w_i8_nz = ms_custom_ops.trans_data(w_i8, transdata_type=1)
+    w_i4_nz = ms_custom_ops.type_cast(w_i8_nz, ms.qint4x2)
+    weight = ms.Parameter(w_i4_nz, requires_grad=False)
+
+    gmm_net = Net()
+    output = gmm_net(x, weight, group_list, bias_tensor, x_scale, weight_scale)
+    res = custom_compare(output.astype(
+        ms.float32).asnumpy(), expect_np, ms.float16)
+    assert res, "matmul compare fail."
+
+
+@pytest.mark.level0
+@pytest.mark.platform_ascend310p
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('batch_size', [5, 17])
+@pytest.mark.parametrize('inputs_shape', [[256, 7168, 256]])
+@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_gmm_w4_swft_0(batch_size, inputs_shape, exec_mode):
+    """
+    Feature: test grouped matmul w4 with scale
+    Description: test grouped matmul w4 with scale
+    Expectation: the result is correct
+    """
+    ms.set_device("Ascend")
+    ms.set_context(mode=exec_mode)
+    grouped_matmul_int4_swft(batch_size, *inputs_shape)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_ascend310p
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('batch_size', [8, 32])
+@pytest.mark.parametrize('inputs_shape', [[7168, 512, 256], [512, 7168, 256], [7168, 1024, 256]])
+def test_gmm_w4_swft_1(batch_size, inputs_shape):
+    """
+    Feature: test grouped matmul w4 operator in graph mode
+    Description: test grouped matmul with int4 weights
+    Expectation: the result is correct
+    """
+    ms.set_device("Ascend")
+    ms.set_context(mode=context.GRAPH_MODE)
+    grouped_matmul_int4_swft(batch_size, *inputs_shape)
+