diff --git a/docs/map_from_buildin_to_custom.md b/docs/map_from_buildin_to_custom.md
index 40cbe07ed6586afbb8185d5890188d8e41c7edc8..23a80ff4800e1e507255149ec7376febe59bbca6 100644
--- a/docs/map_from_buildin_to_custom.md
+++ b/docs/map_from_buildin_to_custom.md
@@ -13,4 +13,5 @@
 | ops.auto_generate.apply_rotary_pos_emb | [ms_custom_ops.apply_rotary_pos_emb_atb](../ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) | 新增atb的apply_rotary_pos_emb_atb算子,代替ops.auto_generate.apply_rotary_pos_emb,注意rotary_coeff和cos_format有变化,详见[API](https://gitee.com/mindspore/ms_custom_ops/blob/master/ops/c_api/apply_rotary_pos_emb_atb/apply_rotary_pos_emb_atb.md) |
 | ops.moe_token_unpermute | [ms_custom_ops.moe_token_unpermute](../ops/c_api/moe_token_unpermute/moe_token_unpermute.md) | 接口参数一致, 但需注意:ops接口只支持A2训练芯片,ms_custom_ops场景下只支持Atlas推理系列产品, 并且ms_custom_ops场景下当前仅支持:`padded_mode = false, restore_shape = None`, topK 支持 1、2,、4、8, hidden_size 支持 2048、5120、7168。 |
 | ops.auto_generate.GroupedMatmul | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,移除了offset、antiquant_offset、split_item、group_type等参数,返回从tuple[tensor]改为tensor,基于Internal框架实现,仅支持 Atlas 推理系列 |
-| ops.auto_generate.GroupedMatmulV4 | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,per_token_scale改为per_token_scale,移除了offset、antiquant_offset、activation相关参数、split_item、group_type、group_list_type、act_type、output_dtype等参数,返回从tuple[tensor]改为tensor,增加transpose_a、transpose_b参数,基于Internal框架实现,仅支持 Atlas 推理系列 |
\ No newline at end of file
+| ops.auto_generate.GroupedMatmulV4 | [ms_custom_ops.grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul.md) | 接口变更:输入从tuple[tensor]改为tensor,group_list前移到weight之后并改为必传参数,per_token_scale改为pre_token_scale,移除了offset、antiquant_offset、activation相关参数、split_item、group_type、group_list_type、act_type、output_dtype等参数,返回从tuple[tensor]改为tensor,增加transpose_a、transpose_b参数,基于Internal框架实现,仅支持 Atlas 推理系列 |
+| ops.auto_generate.grouped_matmul_v4 | [ms_custom_ops.grouped_matmul_v4_cops](../ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops.md) | 新增weight_format参数,用于指定权重的格式,可选"ND"或"FRACTAL_NZ",默认为"ND",基于aclnnGroupedMatmulV4实现 |
\ No newline at end of file
diff --git a/docs/op_list.md b/docs/op_list.md
index a7a8ed889a0ce6dd41f062b9b8da95c785f4b360..70f243087a199e00ded4b8b29a9ad9e077ac95b9 100644
--- a/docs/op_list.md
+++ b/docs/op_list.md
@@ -7,6 +7,7 @@
 1. [fa_update](../ops/c_api/fa_update/fa_update_doc.md)
 1. [flash_attention_encoder](../ops/c_api/flash_attention_encoder/flash_attention_encoder.md)
 1. [fused_add_topk_div](../ops/c_api/fused_add_topk_div/fused_add_topk_div.md)
 1. [group_topk](../ops/c_api/group_topk/group_topk_doc.md)
 1. [grouped_matmul](../ops/c_api/grouped_matmul/grouped_matmul_doc.md)
+1. [grouped_matmul_v4_cops](../ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops.md)
 1. 
[grouped_matmul_w4](../ops/c_api/grouped_matmul_w4/grouped_matmul_w4_doc.md) diff --git a/docs/pass_list.md b/docs/pass_list.md new file mode 100644 index 0000000000000000000000000000000000000000..a81cf011a7df789e8c15e34d67ad4956937fa3df --- /dev/null +++ b/docs/pass_list.md @@ -0,0 +1,6 @@ +# MsCustomOps Pass 列表 + +| Pass名称 | 功能描述 | 输入模式 | 输出模式 | 默认使能 | +|---------|---------|---------|---------|---------| +| AddRmsNormFusionPass | 将Add和RmsNorm操作融合为单个AddRmsNorm操作,减少计算开销 | RmsNorm(Add(x1, x2), gamma, eps) | AddRmsNorm(x1, x2, gamma, eps) | 否 | +| ConvertTupleInputToDynamicInput | 将算子tuple/list类型输入展开,转换为动态输入 | op_func(var, list[tensor]) | op_func(var, tensor1, tensor2, ...) | 是 | diff --git a/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops.cc b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b1f28ff15a408e08e4c51b7893916bc7927666e --- /dev/null +++ b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops.cc @@ -0,0 +1,405 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include +#include +#include +#include "ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.h" +#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" +#include "ops/framework/utils.h" + +constexpr auto kGroupedMatmulV4CopsName = "grouped_matmul_v4_cops"; +constexpr size_t kListTensorInputNums = 12; + +namespace ms_custom_ops { +std::vector> DealWithListTensors(const std::vector &group_info, + const std::vector &start_idxs, + const std::vector &inputs) { + std::vector> list_inputs{}; + for (size_t i = 0; i < kListTensorInputNums; i++) { + std::vector input_i{}; + if (group_info[i] > 0) { + input_i.assign(inputs.begin() + start_idxs[i], inputs.begin() + start_idxs[i + 1]); + } + (void)list_inputs.emplace_back(std::move(input_i)); + } + return list_inputs; +} + +std::vector ComputeStartIdxsFromGroupInfo(const std::vector &group_info) { + std::vector start_idxs{0}; + int64_t cur_end_idx = 0; + for (size_t i = 0; i < group_info.size(); ++i) { + cur_end_idx += (group_info[i] == 0 ? 
1 : group_info[i]); + start_idxs.push_back(cur_end_idx); + } + return start_idxs; +} + +static inline void UnifyWeightShape(const std::vector &ori_weights, + std::vector> *new_weights_shared_ptr, + std::vector *new_weights_raw_ptr) { + for (const auto &w : ori_weights) { + if (w->dtype_id() == kNumberTypeInt4) { + const auto &storage_info = w->tensor_storage_info(); + if (storage_info != nullptr && !storage_info->is_contiguous) { + MS_LOG(EXCEPTION) << "Currently, " << kGroupedMatmulV4CopsName + << " does not support noncontiguous input tensor for int4 quant, " + << "but got noncontiguous input tensor: " << w->ToString() + << ", storage info: " << storage_info->ToString(); + } + auto new_w = w->CloneKernelTensor(); + auto w_shape = w->GetShapeVector(); + w_shape.back() *= kNumber2; + new_w->SetShapeVector(w_shape); + new_weights_shared_ptr->emplace_back(new_w); + new_weights_raw_ptr->emplace_back(new_w.get()); + } else { + new_weights_raw_ptr->emplace_back(w); + } + } +} + +class GroupedMatmulV4CopsAscend : public AclnnCustomKernelMod { + public: + GroupedMatmulV4CopsAscend() : AclnnCustomKernelMod("aclnnGroupedMatmulV4") {} + ~GroupedMatmulV4CopsAscend() = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto list_inputs = DealWithListTensors(group_info_, start_idxs_, inputs); + const auto &group_list = inputs[start_idxs_[group_list_idx_]]; + + std::vector> new_weights; + std::vector new_weights_raw; + UnifyWeightShape(list_inputs[kIndex1], &new_weights, &new_weights_raw); + + RunOp(stream_ptr, workspace, list_inputs[kIndex0], new_weights_raw, list_inputs[kIndex2], list_inputs[kIndex3], + list_inputs[kIndex4], list_inputs[kIndex5], list_inputs[kIndex6], list_inputs[kIndex7], group_list, + list_inputs[kIndex9], list_inputs[kIndex10], list_inputs[kIndex11], split_item_, group_type_, + group_list_type_, act_type_, outputs, activation_feature_out_, dyn_quant_scale_out_); + return true; + } + + void GetWorkSpaceInfo(const std::vector &inputs, + const std::vector &outputs) override { + group_info_ = GetValue>(primitive_->GetAttr(kGroupInfo)); + start_idxs_ = ComputeStartIdxsFromGroupInfo(group_info_); + + auto list_inputs = DealWithListTensors(group_info_, start_idxs_, inputs); + const auto &group_list = inputs[start_idxs_[group_list_idx_]]; + const auto split_item_idx = start_idxs_.back(); + split_item_ = inputs.at(split_item_idx)->GetValueWithCheck(); + group_type_ = inputs.at(split_item_idx + kIndex1)->GetValueWithCheck(); + group_list_type_ = inputs.at(split_item_idx + kIndex2)->GetValueWithCheck(); + act_type_ = inputs.at(split_item_idx + kIndex3)->GetValueWithCheck(); + weight_format_ = inputs.at(split_item_idx + kIndex4)->GetValueWithCheck(); + + std::vector> new_weights; + std::vector new_weights_raw; + UnifyWeightShape(list_inputs[kIndex1], &new_weights, &new_weights_raw); + + if (weight_format_ == kFractalNzFormat) { + for (auto &w_tensor : new_weights_raw) { + w_tensor->set_format(mindspore::Format::FRACTAL_NZ); + if (w_tensor->tensor_storage_info() != nullptr) { + MS_LOG(EXCEPTION) << "For " << kGroupedMatmulV4CopsName + << ", FRACTAL_NZ is not support when storage_info is not nullptr"; + } + auto nd_shape = w_tensor->GetShapeVector(); + auto storage_info = GetNZFormatStorageInfo(nd_shape, w_tensor->dtype_id()); + w_tensor->set_tensor_storage_info(storage_info); + } + } + + GetWorkspaceForResize(list_inputs[kIndex0], new_weights_raw, list_inputs[kIndex2], list_inputs[kIndex3], + 
list_inputs[kIndex4], list_inputs[kIndex5], list_inputs[kIndex6], list_inputs[kIndex7], + group_list, list_inputs[kIndex9], list_inputs[kIndex10], list_inputs[kIndex11], split_item_, + group_type_, group_list_type_, act_type_, outputs, activation_feature_out_, + dyn_quant_scale_out_); + } + + private: + DEFINE_GET_WORKSPACE_FOR_RESIZE(); + const size_t group_list_idx_{8}; + int64_t act_type_{0}; + int64_t group_list_type_{0}; + int64_t group_type_{0}; + int64_t split_item_{0}; + std::string weight_format_{"ND"}; + std::vector group_info_{}; + std::vector start_idxs_{}; + const std::vector activation_feature_out_{}; + const std::vector dyn_quant_scale_out_{}; +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(grouped_matmul_v4_cops, ms_custom_ops::GroupedMatmulV4CopsFuncImpl, + ms_custom_ops::GroupedMatmulV4CopsAscend); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +namespace ms_custom_ops { +TypeIdList InferOutputTypes(const std::vector &x, const std::vector &weight, + const std::optional> &scale, + const std::optional &output_dtype) { + TypeId x_type = x[0].data_type(); + TypeId w_type = weight[0].data_type(); + TypeIdList output_types; + if (x_type == kNumberTypeInt8 && w_type == kNumberTypeInt4) { + TypeId output_type = kNumberTypeBFloat16; + if (output_dtype.has_value()) { + output_type = static_cast(output_dtype.value()); + static std::set valid_dtype_set = {kNumberTypeFloat16, kNumberTypeBFloat16}; + if (valid_dtype_set.find(output_type) == valid_dtype_set.end()) { + MS_EXCEPTION(ValueError) << "For " << kGroupedMatmulV4CopsName + << " with A8W4, the output type must be in [Float16, BFloat16], but got " + << TypeIdToString(output_type); + } + } + std::transform(x.begin(), x.end(), std::back_inserter(output_types), + [output_type](const ms::Tensor &info) { return output_type; }); + } else if (!scale.has_value()) { + auto out_type = x_type == kNumberTypeInt8 ? 
kNumberTypeInt32 : x_type; + std::transform(x.begin(), x.end(), std::back_inserter(output_types), + [=](const ms::Tensor &info) { return out_type; }); + } else { + const auto &scale_tensors = scale.value(); + TypeId scale_type = scale_tensors[0].data_type(); + if (scale_type == kNumberTypeUInt64) { + std::transform(x.begin(), x.end(), std::back_inserter(output_types), + [](const ms::Tensor &info) { return kNumberTypeInt8; }); + } else if (scale_type == kNumberTypeBFloat16) { + std::transform(x.begin(), x.end(), std::back_inserter(output_types), + [](const ms::Tensor &info) { return kNumberTypeBFloat16; }); + } else if (scale_type == kNumberTypeFloat32) { + std::transform(x.begin(), x.end(), std::back_inserter(output_types), + [](const ms::Tensor &info) { return kNumberTypeFloat16; }); + } else { + MS_EXCEPTION(ValueError) << "For " << kGroupedMatmulV4CopsName + << ", the scale only support Uint16, BFloat16 and Float32, but got " << scale_type; + } + } + return output_types; +} + +ShapeArray InferOutputShapes(const std::vector &x, const std::vector &weight, + const std::optional &group_list, const int64_t &group_type) { + ShapeArray output_shapes; + if (group_type == -1) { + if (MS_UNLIKELY(x.size() != weight.size())) { + MS_EXCEPTION(ValueError) << "For " << kGroupedMatmulV4CopsName + << ", when group_type is -1 and split_item is 0, x's size " + << "should be equal to weight, but got " << x.size() << " and " << weight.size(); + } + for (size_t i = 0; i < x.size(); i++) { + const auto &x_shape = x[i].shape(); + const auto &w_shape = weight[i].shape(); + auto res_shape = x_shape; + res_shape.back() = w_shape.back(); + (void)output_shapes.emplace_back(std::move(res_shape)); + } + } else { + if (MS_UNLIKELY(x.size() != kDim1 || weight.size() != kDim1)) { + MS_EXCEPTION(ValueError) << "For " << kGroupedMatmulV4CopsName + << ", when split_item is 3. the size of x and weight should " + << "both be 1, but got x's size " << x.size() << ", and weight's size " << weight.size(); + } + const auto &x_shape = x[0].shape(); + const auto &w_shape = weight[0].shape(); + auto m = x_shape[x_shape.size() - kIndex2]; + auto n = w_shape.back(); + bool is_int4 = weight[0].data_type() == kNumberTypeInt4; + if (is_int4) { + n = n << 1; + } + int64_t group_list_size = group_list.has_value() ? group_list.value().shape()[0] : 1; + std::vector res_shape; + if (group_type == 0) { + res_shape = std::vector{m, n}; + } else if (group_type == 1) { + res_shape = std::vector{group_list_size, m, n}; + } + (void)output_shapes.emplace_back(std::move(res_shape)); + } + return output_shapes; +} + +std::vector CreateEmptyOutputTensors(TypeIdList &output_types, const ShapeArray &output_shapes) { + if (output_types.size() != output_shapes.size()) { + MS_EXCEPTION(ValueError) << "For " << kGroupedMatmulV4CopsName + << ", the number of output types must be equal to the number of output shapes, " + << "but got output types size: " << output_types.size() + << ", output shapes size: " << output_shapes.size(); + } + std::vector empty_tensors; + for (size_t i = 0; i < output_types.size(); i++) { + auto out_shape = output_shapes[i]; + TypeId out_type = static_cast(output_types[i]); + auto out_tensor = ms::Tensor(out_type, out_shape); + (void)empty_tensors.emplace_back(out_tensor); + } + return empty_tensors; +} + +std::vector GetOptionalTensorList(const std::optional> &tensor_list) { + return tensor_list.has_value() ? 
tensor_list.value() : std::vector(); +} + +void GetFlattenInputs(const std::vector &tensor_list, std::vector &inputs) { + for (auto &tensor : tensor_list) { + (void)inputs.emplace_back(tensor); + } +} + +void UnifyWeightShape(const std::vector &ori_weights, std::vector *new_weights) { + for (const auto &ori_weight : ori_weights) { + if (ori_weight.data_type() == kNumberTypeInt4) { + MS_EXCEPTION_IF_NULL(ori_weight.tensor()); + const auto &storage_info = ori_weight.tensor()->storage_info(); + if (storage_info != nullptr && !storage_info->is_contiguous) { + MS_LOG(EXCEPTION) << "Currently, " << kGroupedMatmulV4CopsName + << " does not support noncontiguous input tensor for int4 quant, " + << "but got noncontiguous input tensor: " << ori_weight.tensor()->ToString() + << ", storage info: " << storage_info->ToString(); + } + auto new_weight = ms::Tensor(ori_weight.data_type(), ori_weight.shape()); + new_weight.AssignTensor(ori_weight); + auto ori_weight_shape = ori_weight.shape(); + ori_weight_shape.back() *= kNumber2; + new_weight.tensor()->set_shape(ori_weight_shape); + (void)new_weights->emplace_back(std::move(new_weight)); + } else { + (void)new_weights->emplace_back(ori_weight); + } + } +} + +std::vector grouped_matmul_v4_cops_custom( + const std::vector &x, const std::vector &weight, + const std::optional> &bias, const std::optional> &scale, + const std::optional> &offset, const std::optional> &antiquant_scale, + const std::optional> &antiquant_offset, + const std::optional> &pre_token_scale, const std::optional &group_list, + const std::optional> &activation_input, + const std::optional> &activation_quant_scale, + const std::optional> &activation_quant_offset, const int64_t &split_item, + const int64_t &group_type, const int64_t &group_list_type, const int64_t &act_type, const std::string &weight_format, + const std::optional &output_dtype) { + auto output_types = InferOutputTypes(x, weight, scale, output_dtype); + auto output_shapes = InferOutputShapes(x, weight, group_list, group_type); + auto outputs = CreateEmptyOutputTensors(output_types, output_shapes); + + // Because copying the tensor is required under UnifyWeightShape, the address needs to be allocated in advance. 
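+  // UnifyWeightShape below copies each int4 weight into a new tensor whose last dimension is
+  // doubled (two int4 values are packed in one int8 element); the copy reads the original weight
+  // on device, so its device address has to exist before UnifyWeightShape runs.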
+ auto device_type = mindspore::DeviceManagerConf::GetInstance()->device_type(); + auto device_context = mindspore::runtime::OpRunner::GetDeviceContext(device_type); + auto stream_id = static_cast(mindspore::CurrentStream::id()); + for (size_t i = 0; i < weight.size(); i++) { + mindspore::runtime::DeviceAddressUtils::CreateInputTensorAddress(device_context, stream_id, kNumber1, + weight[i].tensor()); + } + std::vector new_weight; + UnifyWeightShape(weight, &new_weight); + if (weight_format == kFractalNzFormat) { + for (auto &w_tensor : new_weight) { + w_tensor.set_format(weight_format); + auto storage_info = GetNZFormatStorageInfo(w_tensor.shape(), w_tensor.data_type()); + MS_EXCEPTION_IF_NULL(w_tensor.tensor()); + MS_EXCEPTION_IF_NULL(w_tensor.tensor()->device_address()); + w_tensor.tensor()->device_address()->set_tensor_storage_info(storage_info); + } + } + + auto bias_tensor = GetOptionalTensorList(bias); + auto scale_tensor = GetOptionalTensorList(scale); + auto offset_tensor = GetOptionalTensorList(offset); + auto antiquant_scale_tensor = GetOptionalTensorList(antiquant_scale); + auto antiquant_offset_tensor = GetOptionalTensorList(antiquant_offset); + auto pre_token_scale_tensor = GetOptionalTensorList(pre_token_scale); + auto activation_input_tensor = GetOptionalTensorList(activation_input); + auto activation_quant_scale_tensor = GetOptionalTensorList(activation_quant_scale); + auto activation_quant_offset_tensor = GetOptionalTensorList(activation_quant_offset); + + std::vector inputs; + GetFlattenInputs(x, inputs); + GetFlattenInputs(weight, inputs); + GetFlattenInputs(bias_tensor, inputs); + GetFlattenInputs(scale_tensor, inputs); + GetFlattenInputs(offset_tensor, inputs); + GetFlattenInputs(antiquant_scale_tensor, inputs); + GetFlattenInputs(antiquant_offset_tensor, inputs); + GetFlattenInputs(pre_token_scale_tensor, inputs); + (void)inputs.emplace_back(GetTensorOrEmpty(group_list)); + GetFlattenInputs(activation_input_tensor, inputs); + GetFlattenInputs(activation_quant_scale_tensor, inputs); + GetFlattenInputs(activation_quant_offset_tensor, inputs); + + std::vector activation_feature_out; + std::vector dyn_quant_scale_out; + auto runner = std::make_shared("GroupedMatmulV4"); + runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnGroupedMatmulV4, x, new_weight, bias_tensor, scale_tensor, offset_tensor, + antiquant_scale_tensor, antiquant_offset_tensor, pre_token_scale_tensor, + group_list, activation_input_tensor, activation_quant_scale_tensor, + activation_quant_offset_tensor, split_item, group_type, group_list_type, + act_type, outputs, activation_feature_out, dyn_quant_scale_out)); + runner->Run(inputs, outputs); + return outputs; +} +} // namespace ms_custom_ops + +auto pyboost_grouped_matmul_v4_cops( + const std::vector &x, const std::vector &weight, + const std::optional> &bias, const std::optional> &scale, + const std::optional> &offset, const std::optional> &antiquant_scale, + const std::optional> &antiquant_offset, + const std::optional> &pre_token_scale, const std::optional &group_list, + const std::optional> &activation_input, + const std::optional> &activation_quant_scale, + const std::optional> &activation_quant_offset, const int64_t &split_item, + const int64_t &group_type, const int64_t &group_list_type, const int64_t &act_type, const std::string &weight_format, + const std::optional &output_dtype) { + { + // grouped_matmul_v4 does not support asynchronous execution due to an uncertain number of outputs. 
To ensure proper + // sequencing, it is necessary to wait for the frontend queue to clear first. + ms_custom_ops::GilReleaseWithCheck no_gil; + mindspore::runtime::Pipeline::Get().frontend_stage()->Wait(); + } + ms::inner::ConvertStubNodeToTensor(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, + group_list, activation_input, activation_quant_scale, activation_quant_offset); + std::vector outputs = ms_custom_ops::grouped_matmul_v4_cops_custom( + x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, pre_token_scale, group_list, activation_input, + activation_quant_scale, activation_quant_offset, split_item, group_type, group_list_type, act_type, weight_format, + output_dtype); + return outputs; +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("grouped_matmul_v4_cops", &pyboost_grouped_matmul_v4_cops, "GroupedMatmulV4", pybind11::arg("x"), + pybind11::arg("weight"), pybind11::arg("bias") = std::nullopt, pybind11::arg("scale") = std::nullopt, + pybind11::arg("offset") = std::nullopt, pybind11::arg("antiquant_scale") = std::nullopt, + pybind11::arg("antiquant_offset") = std::nullopt, pybind11::arg("pre_token_scale") = std::nullopt, + pybind11::arg("group_list") = std::nullopt, pybind11::arg("activation_input") = std::nullopt, + pybind11::arg("activation_quant_scale") = std::nullopt, pybind11::arg("activation_quant_offset") = std::nullopt, + pybind11::arg("split_item") = 0, pybind11::arg("group_type") = -1, pybind11::arg("group_list_type") = 0, + pybind11::arg("act_type") = 0, pybind11::arg("weight_format") = std::string("ND"), + pybind11::arg("output_dtype") = std::nullopt); +} diff --git a/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_doc.md b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..fad28e00f408c7e9ff18adfc5774c0e333b9ae2b --- /dev/null +++ b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_doc.md @@ -0,0 +1,126 @@ +# grouped_matmul_v4算子 + +## 描述 + +grouped_matmul_v4算子用于执行分组矩阵乘法操作,支持多种量化格式(如int8、int4),并提供丰富的配置选项,包括偏置、缩放、偏移、激活函数等。该算子适用于高效处理大规模矩阵乘法计算,尤其在深度学习模型中。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|------|-------|-------|----------|---------|--------|-------------| +| x | List[Tensor] | 多维张量列表 | No | No | ND | 输入特征张量列表 | +| weight | List[Tensor] | 多维张量列表 | No | No | ND/FRACTAL_NZ | 权重张量列表,支持int4量化 | +| bias | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 偏置张量列表 | +| scale | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 缩放因子列表,用于量化/反量化过程 | +| offset | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 偏移量列表 | +| antiquant_scale | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 反量化缩放因子列表 | +| antiquant_offset | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 反量化偏移量列表 | +| pre_token_scale | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 逐token缩放因子列表 | +| group_list | Tensor | 一维张量 | Yes | - | ND | 分组信息列表 | +| activation_input | List[Tensor] | 多维张量列表 | Yes | No | ND | 激活函数输入张量列表 | +| activation_quant_scale | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 激活函数量化缩放因子列表 | +| activation_quant_offset | List[Tensor] | 适当广播的形状列表 | Yes | No | ND | 激活函数量化偏移量列表 | +| split_item | int64 | - | Yes | - | - | 分割项配置 | +| group_type | int64 | - | Yes | - | - | 分组类型配置,0表示常规分组,-1表示多输出 | +| group_list_type | int64 | - | Yes | - | - | 分组列表类型配置 | +| act_type | int64 | - | Yes | - | - | 激活函数类型配置 | +| weight_format | str | - | Yes | - | - | 权重格式,可选值为"ND"或"FRACTAL_NZ",默认为"ND" | +| output_dtype | int64 | - | Yes | - | - | 输出数据类型,支持float16、bfloat16 | + 
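+
+How `split_item`、`group_type` and `group_list_type` combine is easiest to see in the test cases added alongside this operator; the sketch below (values taken from those tests, for illustration only) shows the two `group_list` conventions used with `split_item=3, group_type=0`:
+
+```python
+# group_list_type = 1: group_list holds the token count of each group
+counts = [1, 2, 7, 4, 4, 4, 2, 8]
+
+# group_list_type = 0: group_list holds the cumulative end index of each group's tokens,
+# i.e. the running sum of the counts above
+cumulative = [sum(counts[:i + 1]) for i in range(len(counts))]  # [1, 3, 10, 14, 18, 22, 24, 32]
+```
+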
+## 输出参数 + +| Name | DType | Shape | Description | +|------|-------|-------|-------------| +| output | Tensor/List[Tensor] | 符合分组矩阵乘法规则的形状 | 分组矩阵乘法的计算结果,当group_type为-1时返回多个输出 | + +更多详细信息请参考:[aclnnGroupedMatmulV4](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/API/aolapi/context/aclnnGroupedMatmulV4.md) + +## 支持平台 + +- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 +- Atlas 推理系列产品 +- Atlas A3 训练系列产品/Atlas A3 推理系列产品 + +## 使用示例 + +### 基本使用示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_device("Ascend") + +@ms.jit +def grouped_matmul_v4_func(x, weight, bias=None, scale=None, + offset=None, antiquant_scale=None, + antiquant_offset=None, pre_token_scale=None, + group_list=None, activation_input=None, + activation_quant_scale=None, + activation_quant_offset=None, split_item=0, + group_type=0, group_list_type=0, act_type=0, + weight_format="ND", output_dtype=ms.float16): + return ms_custom_ops.grouped_matmul_v4_cops( + x, weight, bias, scale, offset, + antiquant_scale, antiquant_offset, pre_token_scale, + group_list, activation_input, activation_quant_scale, + activation_quant_offset, split_item, group_type, + group_list_type, act_type, weight_format, output_dtype) + +# 准备输入数据 +expert_num = 4 +seq_len = 128 +hidden_size = 768 + +# 创建输入张量列表 +x = [ms.Tensor(np.random.randint(-128, 127, size=(seq_len, hidden_size)), dtype=ms.int8)] +# 创建权重张量列表 +weight = [ms.Tensor(np.random.randint(-8, 7, size=(expert_num, hidden_size, hidden_size)), dtype=ms.int8)] +# 创建分组列表 +group_list = ms.Tensor([0, 1, 2, 3], dtype=ms.int64) + +# 执行分组矩阵乘法 +output = grouped_matmul_v4_func( + x, weight, group_list=group_list, + split_item=3, group_type=0, group_list_type=0, act_type=0, + weight_format="ND", output_dtype=ms.float16 +) +print("Output shape:", output[0].shape) +``` + +### 多输出模式示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_device("Ascend") + +# 准备多个输入和权重 +seq_len = 64 +input_size = 512 +output_size = 256 + +# 创建多个输入和权重 +x = [ + ms.Tensor(np.random.randint(-128, 127, size=(seq_len, input_size)), dtype=ms.int8), + ms.Tensor(np.random.randint(-128, 127, size=(seq_len, input_size)), dtype=ms.int8) +] +weight = [ + ms.Tensor(np.random.randint(-8, 7, size=(input_size, output_size)), dtype=ms.int8), + ms.Tensor(np.random.randint(-8, 7, size=(input_size, output_size)), dtype=ms.int8) +] + +# 多输出模式 (group_type = -1) +outputs = ms_custom_ops.grouped_matmul_v4_cops( + x, weight, + split_item=0, group_type=-1, group_list_type=0, act_type=0, + weight_format="ND" +) + +print("Number of outputs:", len(outputs)) +for i, out in enumerate(outputs): + print(f"Output {i+1} shape:", out.shape) +``` \ No newline at end of file diff --git a/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.cc b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d9bb20d705bb5d4208723ec331ddb3d12bd6173 --- /dev/null +++ b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.cc @@ -0,0 +1,239 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.h" +#include +#include +#include +#include +#include +#include "ops/framework/utils.h" + +namespace ms_custom_ops { +constexpr size_t kListTensorInputNums = 12; +constexpr int64_t kGroupTypeMultiOutput = -1; +constexpr int64_t kGroupTypeSingleOutput = 0; + +void GroupedMatmulV4CopsFuncImpl::FetchGroupInfo(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + // for tensortuple(input arg) in backend split. (AscendConvertTupleInputToDynamicInput pass) + std::vector dyn_input_sizes; + for (size_t i = 0; i < kListTensorInputNums; i++) { + const auto &tensors = input_infos[i]; + if (i == LongToSize(kGroupListIndex)) { + dyn_input_sizes.push_back(1); + continue; + } + if (tensors->IsNone()) { + dyn_input_sizes.push_back(0); + continue; + } + if (MS_UNLIKELY(tensors->IsDynamicSequence())) { + MS_EXCEPTION(RuntimeError) + << "For '" << primitive->name() + << "', all inputs which is list[tensor] should not be dynamic sequence, which is not supported."; + } + const auto &elements = tensors->GetSequenceElements(); + dyn_input_sizes.push_back(SizeToLong(elements.size())); + } + primitive->set_attr(kGroupInfo, MakeValue(dyn_input_sizes)); // len of tuple input +} + +int64_t GroupedMatmulV4CopsFuncImpl::FetchGroupListIndex(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + if (MS_LIKELY(input_infos[kXIndex]->IsSequence())) { + return kGroupListIndex; + } + // Runtime phase: the element in input_args is KernelTensor. (tuple is expanded) + const auto &group_info = GetValue>(primitive->GetAttr(kGroupInfo)); + std::vector start_idxes{0}; + int64_t cur_end_idx = 0; + for (size_t i = 0; i < group_info.size(); ++i) { + cur_end_idx += (group_info[i] == 0 ? 1 : group_info[i]); + start_idxes.push_back(cur_end_idx); + } + return start_idxes.at(kGroupListIndex); +} + +ShapeArray GroupedMatmulV4CopsFuncImpl::InferShapeForSingleOutput(const PrimitivePtr &primitive, + const ShapeArray &x_shapes, + const ShapeArray &w_shapes, int64_t group_list_size, + int64_t group_type, bool is_int4) const { + if (MS_UNLIKELY(x_shapes.size() != kDim1 || w_shapes.size() != kDim1)) { + MS_EXCEPTION(ValueError) << "For '" << primitive->name() + << "', when split_item is 3. the size of x and weight should both be 1, but got x's size " + << x_shapes.size() << ", and weight's size " << w_shapes.size(); + } + + const auto &x_shape = x_shapes[0]; + const auto &w_shape = w_shapes[0]; + auto is_x_dyn_rank = IsDynamicRank(x_shape); + auto is_w_dyn_rank = IsDynamicRank(w_shape); + auto m = is_x_dyn_rank ? 
abstract::Shape::kShapeDimAny : x_shape[x_shape.size() - 2]; + auto n = abstract::Shape::kShapeDimAny; + if (!is_w_dyn_rank) { + n = w_shape.back(); + if (is_int4) { + n = n << 1; + } + } + + std::vector res_shape; + if (group_type == kGroupTypeSingleOutput) { + // x.shape [m, k], w.shape [e, k, n], y.shape [m, n] + res_shape = std::vector{m, n}; + } else { + // x.shape [m, k], w.shape [k, n], y.shape [b, m, n] + res_shape = std::vector{group_list_size, m, n}; + } + return {std::move(res_shape)}; +} + +ShapeArray GroupedMatmulV4CopsFuncImpl::InferShapeForMultiOutput(const PrimitivePtr &primitive, + const ShapeArray &x_shapes, + const ShapeArray &w_shapes) const { + if (MS_UNLIKELY(x_shapes.size() != w_shapes.size())) { + MS_EXCEPTION(ValueError) + << "For '" << primitive->name() + << "', when group_type is -1 and split_item is 0, x's size should be equal to weight, but got ." + << x_shapes.size() << " and " << w_shapes.size(); + } + + ShapeArray output_shapes; + for (size_t i = 0; i < x_shapes.size(); i++) { + const auto &x_shape = x_shapes[i]; + const auto &w_shape = w_shapes[i]; + if (MS_UNLIKELY(IsDynamicRank(x_shape))) { + (void)output_shapes.emplace_back(ShapeVector{abstract::TensorShape::kShapeRankAny}); + } else { + auto res_shape = x_shape; + res_shape.back() = IsDynamicRank(w_shape) ? abstract::Shape::kShapeDimAny : w_shape.back(); + (void)output_shapes.emplace_back(std::move(res_shape)); + } + } + return output_shapes; +} + +int64_t GroupedMatmulV4CopsFuncImpl::FetchGroupListSize(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + const auto group_list_idx = FetchGroupListIndex(primitive, input_infos); + const auto &group_list_shape = input_infos.at(group_list_idx)->GetShape(); + MS_CHECK_VALUE(group_list_shape.size() == kDim1, + CheckAndConvertUtils::FormatCheckIntegerMsg("group_list's rank", group_list_shape.size(), kEqual, + kDim1, primitive)); + return input_infos[group_list_idx]->IsDynamic() ? abstract::Shape::kShapeDimAny : group_list_shape[kIndex0]; +} + +std::pair GroupedMatmulV4CopsFuncImpl::FetchInputAndWeightShapes( + const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const { + ShapeArray x_shapes; + ShapeArray w_shapes; + if (MS_LIKELY(input_infos[kXIndex]->IsSequence())) { + FetchGroupInfo(primitive, input_infos); + auto FetchTupleTensorShapeFunc = [](const InferInfoPtr &tensors) { + const auto &elements = tensors->GetSequenceElements(); + ShapeArray shapes; + std::transform(elements.begin(), elements.end(), std::back_inserter(shapes), + [](const InferInfoPtr &info) { return info->GetShape(); }); + return shapes; + }; + // get tuple_x_shape in compile phase + x_shapes = FetchTupleTensorShapeFunc(input_infos[kXIndex]); + // get tuple_w_shape in compile phase + w_shapes = FetchTupleTensorShapeFunc(input_infos[kWeightIndex]); + } else { + // Runtime phase: the element in input_args is KernelTensor. 
(tuple is expanded) + auto tuple_len = GetValue>(primitive->GetAttr(kGroupInfo)); + size_t x_idx_end = LongToSize(tuple_len[kXIndex]); + size_t w_idx_end = LongToSize(tuple_len[kXIndex] + tuple_len[kWeightIndex]); + std::transform(input_infos.begin(), input_infos.begin() + x_idx_end, std::back_inserter(x_shapes), + [](const InferInfoPtr &info) { return info->GetShape(); }); + std::transform(input_infos.begin() + x_idx_end, input_infos.begin() + w_idx_end, std::back_inserter(w_shapes), + [](const InferInfoPtr &info) { return info->GetShape(); }); + } + return std::make_pair(std::move(x_shapes), std::move(w_shapes)); +} + +ShapeArray GroupedMatmulV4CopsFuncImpl::InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + auto [x_shapes, w_shapes] = FetchInputAndWeightShapes(primitive, input_infos); + const auto input_num = SizeToLong(input_infos.size()); + auto group_type_index = input_num + kGroupTypeOffset; + auto group_type_opt = input_infos[group_type_index]->GetScalarValue(); + MS_CHECK_VALUE(group_type_opt.has_value(), "For 'grouped_matmul_v4_cops', 'group_type' must be provided."); + auto group_type = group_type_opt.value(); + if (group_type == kGroupTypeMultiOutput) { + return InferShapeForMultiOutput(primitive, x_shapes, w_shapes); + } + + auto group_list_size = FetchGroupListSize(primitive, input_infos); + bool is_int4 = false; + if (MS_LIKELY(input_infos[kWeightIndex]->IsSequence())) { + const auto &w_tensors = input_infos[kWeightIndex]->GetSequenceElements(); + MS_CHECK_VALUE(w_tensors.size() > 0, "For 'grouped_matmul_v4_cops', 'weight' must be provided."); + is_int4 = w_tensors[0]->GetType() == kNumberTypeInt4; + } else { + is_int4 = input_infos[kWeightIndex]->GetType() == kNumberTypeInt4; + } + + return InferShapeForSingleOutput(primitive, x_shapes, w_shapes, group_list_size, group_type, is_int4); +} + +TypeIdList GroupedMatmulV4CopsFuncImpl::InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + const auto &x_tensors = input_infos[kXIndex]->GetSequenceElements(); + const auto &w_tensors = input_infos[kWeightIndex]->GetSequenceElements(); + const auto &scale_infos = input_infos[kScaleIndex]; + TypeId x_type = x_tensors[0]->GetType(); + TypeId w_type = w_tensors[0]->GetType(); + TypeIdList output_types; + if (x_type == kNumberTypeInt8 && w_type == kNumberTypeInt4) { + TypeId output_type = kNumberTypeBFloat16; + if (!input_infos[kOutDtypeIndex]->IsNone()) { + auto dtype_ptr = input_infos[kOutDtypeIndex]->GetScalarValueWithCheck(); + output_type = static_cast(dtype_ptr); + static std::set valid_dtype_set = {kNumberTypeFloat16, kNumberTypeBFloat16}; + MS_CHECK_VALUE( + valid_dtype_set.find(output_type) != valid_dtype_set.end(), + "For 'grouped_matmul_v4_cops' with A8W4, the output type must be in [Float16, BFloat16], but got " + + TypeIdToString(output_type)); + } + std::transform(x_tensors.begin(), x_tensors.end(), std::back_inserter(output_types), + [output_type](const InferInfoPtr &info) { return output_type; }); + } else if (scale_infos->IsNone()) { + auto out_type = x_type == kNumberTypeInt8 ? 
kNumberTypeInt32 : x_type; + std::transform(x_tensors.begin(), x_tensors.end(), std::back_inserter(output_types), + [=](const InferInfoPtr &info) { return out_type; }); + } else { + const auto &scale_tensors = scale_infos->GetSequenceElements(); + TypeId scale_type = scale_tensors[0]->GetType(); + if (scale_type == kNumberTypeUInt64) { + std::transform(x_tensors.begin(), x_tensors.end(), std::back_inserter(output_types), + [](const InferInfoPtr &info) { return kNumberTypeInt8; }); + } else if (scale_type == kNumberTypeBFloat16) { + std::transform(x_tensors.begin(), x_tensors.end(), std::back_inserter(output_types), + [](const InferInfoPtr &info) { return kNumberTypeBFloat16; }); + } else if (scale_type == kNumberTypeFloat32) { + std::transform(x_tensors.begin(), x_tensors.end(), std::back_inserter(output_types), + [](const InferInfoPtr &info) { return kNumberTypeFloat16; }); + } else { + MS_EXCEPTION(ValueError) << "For '" << primitive->name() + << "', the scale only support Uint16, BFloat16 and Float32, but got " << scale_type; + } + } + return output_types; +} +} // namespace ms_custom_ops diff --git a/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.h b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..e004c7531ca558523949b296ea88d19bb2e28d28 --- /dev/null +++ b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_func_impl.h @@ -0,0 +1,63 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MS_CUSTOM_OPS_OPS_C_API_GROUPED_MATMUL_V4_COPS_FUNC_IMPL_H_ +#define MS_CUSTOM_OPS_OPS_C_API_GROUPED_MATMUL_V4_COPS_FUNC_IMPL_H_ + +#include +#include "mindspore/include/custom_op_api.h" +#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { +constexpr auto kGroupInfo = "group_info"; +constexpr size_t kXIndex = 0; +constexpr size_t kWeightIndex = 1; +constexpr size_t kScaleIndex = 3; +constexpr size_t kPerTokenScaleIndex = 7; +constexpr size_t kGroupListIndex = 8; +constexpr size_t kSplitItemIndex = 12; +constexpr size_t kGroupTypeIndex = 13; +constexpr size_t kGroupListTypeIndex = 14; +constexpr size_t kActTypeIndex = 15; +constexpr size_t kWeightFormatIndex = 16; +constexpr size_t kOutDtypeIndex = 17; +constexpr int64_t kGroupTypeOffset = -5; + +class GroupedMatmulV4CopsFuncImpl : public OpFuncImpl { + public: + GroupedMatmulV4CopsFuncImpl() {} + ~GroupedMatmulV4CopsFuncImpl() = default; + + ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + TypeIdList InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override; + bool GeneralInferRegistered() const override { return true; }; + + protected: + void FetchGroupInfo(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const; + int64_t FetchGroupListIndex(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const; + int64_t FetchGroupListSize(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const; + ShapeArray InferShapeForSingleOutput(const PrimitivePtr &primitive, const ShapeArray &x_shapes, + const ShapeArray &w_shapes, int64_t group_list_size, int64_t group_type, + bool is_int4) const; + ShapeArray InferShapeForMultiOutput(const PrimitivePtr &primitive, const ShapeArray &x_shapes, + const ShapeArray &w_shapes) const; + std::pair FetchInputAndWeightShapes(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const; +}; +} // namespace ms_custom_ops + +#endif // MS_CUSTOM_OPS_OPS_C_API_GROUPED_MATMUL_V4_COPS_FUNC_IMPL_H_ diff --git a/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_op.yaml b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81cd61d78db4658cf701d57496ec10b745088362 --- /dev/null +++ b/ops/c_api/grouped_matmul_v4_cops/grouped_matmul_v4_cops_op.yaml @@ -0,0 +1,70 @@ +#operator grouped_matmul_v4_cops +grouped_matmul_v4_cops: + args: + x: + dtype: tuple[tensor] + type_cast: list[tensor] + weight: + dtype: tuple[tensor] + type_cast: list[tensor] + bias: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + scale: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + offset: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + antiquant_scale: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + antiquant_offset: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + pre_token_scale: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + group_list: + dtype: tensor + default: None + activation_input: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + activation_quant_scale: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + activation_quant_offset: + dtype: tuple[tensor] + type_cast: list[tensor] + default: None + split_item: + dtype: int + default: 0 + group_type: + dtype: int + default: -1 + 
group_list_type: + dtype: int + default: 0 + act_type: + dtype: int + default: 0 + weight_format: + dtype: str + default: "'ND'" + output_dtype: + dtype: TypeId + arg_handler: dtype_to_type_id + default: None + returns: + out: + dtype: tuple[tensor] diff --git a/ops/framework/utils.h b/ops/framework/utils.h index 6c321cf1b6f2d171896891f5d137817e06d3e6c7..6e8e5d2f8d3d53e5d91c0517e50f8ef10361d23d 100644 --- a/ops/framework/utils.h +++ b/ops/framework/utils.h @@ -17,6 +17,7 @@ #ifndef __MS_CUSTOM_OPS_CCSRC_UTILS_UTILS_H__ #define __MS_CUSTOM_OPS_CCSRC_UTILS_UTILS_H__ +#include #include #include #include @@ -176,5 +177,20 @@ inline mindspore::TensorStorageInfoPtr GetNZFormatStorageInfo(const mindspore::S auto storage_info = std::make_shared(nd_shape, strides, nz_shape, strides, true); return storage_info; } + +class GilReleaseWithCheck { + public: + GilReleaseWithCheck() { + if (Py_IsInitialized() != 0 && PyGILState_Check() != 0) { + release_ = std::make_unique(); + } + } + + ~GilReleaseWithCheck() { + release_ = nullptr; + } + private: + std::unique_ptr release_; +}; } // namespace ms_custom_ops #endif // __MS_CUSTOM_OPS_CCSRC_UTILS_UTILS_H__ diff --git a/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.cc b/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.cc new file mode 100644 index 0000000000000000000000000000000000000000..f63dc4062336f8c98ed45b1370e902bfdcd209a9 --- /dev/null +++ b/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.h" +#include +#include +#include "mindspore/ops/op_def/array_ops.h" +#include "mindspore/ops/op_def/structure_ops.h" +#include "mindspore/ops/op_def/nn_ops.h" +#include "mindspore/ops/op_def/framework_ops.h" +#include "mindspore/ccsrc/include/backend/optimizer/helper.h" +#include "mindspore/ccsrc/include/common/utils/anfalgo.h" +#include "mindspore/ccsrc/include/backend/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +constexpr auto kCustomGroupedMatmulV4 = "Custom_grouped_matmul_v4_cops"; +const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const { + VarPtr V = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfUtils::IsRealKernel(node)) { + return nullptr; + } + static const std::unordered_set need_unfold_calculate_node = {kCustomGroupedMatmulV4}; + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + auto node_name = prim->name(); + if (need_unfold_calculate_node.find(node_name) != need_unfold_calculate_node.end()) { + return ConvertMakeTupleInputToPlantInputs(func_graph, cnode); + } + return nullptr; +} +REGISTER_PASS(ConvertTupleInputToDynamicInput) +} // namespace opt +} // namespace mindspore diff --git a/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.h b/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.h new file mode 100644 index 0000000000000000000000000000000000000000..e751c1ee08eaf8e2750e516c0be7be147ecb064d --- /dev/null +++ b/pass/convert_tuple_input_to_dynamic_input/convert_tuple_input_to_dynamic_input.h @@ -0,0 +1,39 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MS_CUSTOM_OPS_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#define MS_CUSTOM_OPS_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ + +#include +#include "pass/pass_registry.h" +#include "include/backend/optimizer/optimizer.h" +#include "include/backend/visible.h" + +namespace mindspore { +namespace opt { +class BACKEND_COMMON_EXPORT ConvertTupleInputToDynamicInput : public PatternProcessPass { + public: + explicit ConvertTupleInputToDynamicInput(bool multigraph = true) + : PatternProcessPass("convert_tuple_input_to_dynamic_input", multigraph) {} + + ~ConvertTupleInputToDynamicInput() override = default; + + const BaseRef DefinePattern() const override; + + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MS_CUSTOM_OPS_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/python/ms_custom_ops/__init__.py b/python/ms_custom_ops/__init__.py index 4eed0fd201a11d66eaf13da8415cdb249e5d9b01..6ce46ed877cc5d15a690957b6149fa095e025d1d 100644 --- a/python/ms_custom_ops/__init__.py +++ b/python/ms_custom_ops/__init__.py @@ -84,6 +84,7 @@ def register_custom_pass(pass_name, backend="ascend"): _init_env() +register_custom_pass("ConvertTupleInputToDynamicInput") # pylint: disable=wrong-import-position from .ms_custom_ops import * diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/mark_utils.py b/tests/mark_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..834a8b8caea9b589acc964cc1aac763ec2fba832 --- /dev/null +++ b/tests/mark_utils.py @@ -0,0 +1,65 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" define marks """ +import numpy as np +import pytest + +def get_numpy_global_seed(): + return 1967515154 + +def set_numpy_global_seed(): + np.random.seed(get_numpy_global_seed()) + +def arg_mark(plat_marks, level_mark, card_mark, essential_mark): + """ + Decorator generator for adding platform, level, card count, and essentiality marks to test functions. 
+ + Args: + plat_marks (list): List of platform marks, optional values are in optional_plat_marks + level_mark (str): Test level mark, optional values are in optional_level_marks + card_mark (str): Card count mark, optional values are in optional_card_marks + essential_mark (str): Essentiality mark, optional values are in optional_essential_marks + + Returns: + function: Decorator function for adding pytest marks + + Raises: + ValueError: When the provided mark values are not within the optional values range + """ + optional_plat_marks = ['platform_ascend', 'platform_ascend910b', 'platform_ascend310p', 'platform_gpu', + 'cpu_linux', 'cpu_windows', 'cpu_macos'] + optional_level_marks = ['level0', 'level1', 'level2', 'level3', 'level4'] + optional_card_marks = ['onecard', 'allcards', 'dryrun', 'dryrun_only'] + optional_essential_marks = ['essential', 'unessential'] + if not plat_marks or not set(plat_marks).issubset(set(optional_plat_marks)): + raise ValueError("wrong plat_marks values") + if level_mark not in optional_level_marks: + raise ValueError("wrong level_mark value") + if card_mark not in optional_card_marks: + raise ValueError("wrong card_mark value") + if essential_mark not in optional_essential_marks: + raise ValueError("wrong essential_mark value") + + def decorator(func): + for plat_mark in plat_marks: + func = getattr(pytest.mark, plat_mark)(func) + func = getattr(pytest.mark, level_mark)(func) + func = getattr(pytest.mark, card_mark)(func) + func = getattr(pytest.mark, essential_mark)(func) + set_numpy_global_seed() + return func + + return decorator diff --git a/tests/st/test_grouped_matmul_v4.py b/tests/st/test_grouped_matmul_v4.py new file mode 100644 index 0000000000000000000000000000000000000000..eba48e7ae57324f29ef04c4fd419b7aedfe626f6 --- /dev/null +++ b/tests/st/test_grouped_matmul_v4.py @@ -0,0 +1,520 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit test module for grouped_matmul_v4 operator + +This module contains tests for various functions and scenarios of the grouped_matmul_v4 operator, +including correctness verification under different input dimensions, quantization formats, +group types, and other conditions. 
+""" +import numpy as np +import pytest +import mindspore as ms +from mindspore import dtype as mstype +from mindspore import context, Tensor +from mindspore.nn import Cell +from tests.mark_utils import arg_mark +import ms_custom_ops + +def split_x(x, group_list): + """ + Split input tensor according to group list + + Args: + x: Input tensor + group_list: Group index list + + Returns: + list: List of split tensors + """ + x_split = [] + for i, end_idx in enumerate(group_list): + if i == 0: + x_split.append(x[0: end_idx,]) + else: + x_split.append(x[group_list[i - 1]: end_idx,]) + return x_split + + +def split_w(w): + """ + Split weight tensor along the first dimension + + Args: + w: Weight tensor to split + + Returns: + list: List of split weight tensors + """ + tmp_split = np.split(w, w.shape[0], axis=0) + w_split = [] + for t in tmp_split: + w_split.append(np.squeeze(t, 0)) + return w_split + + +class GroupedMatmulV4Net(Cell): + """ + Network wrapper class for grouped_matmul_v4 operator + + Used to encapsulate the grouped_matmul_v4_cops operator in tests, providing a unified interface for computation. + """ + def __init__(self): + super().__init__() + self.gmm_v4 = ms_custom_ops.grouped_matmul_v4_cops + + def construct(self, x, weight, bias=None, scale=None, offset=None, antiquant_scale=None, + antiquant_offset=None, pertoken_scale=None, group_list=None, split_item=3, + group_type=-1, group_list_type=0, weight_format="ND", output_dtype=None): + out = self.gmm_v4(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, + pertoken_scale, group_list, split_item=split_item, group_type=group_type, + group_list_type=group_list_type, weight_format=weight_format, output_dtype=output_dtype) + return out + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +def test_grouped_matmul_v4_x2d_w2d_splititem0_grouptypeneg1_none(mode): + """ + Feature: Test grouped_matmul + Description: semi_auto_parallel + Expectation: shape is as expected. 
+ """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 0 + group_type = -1 + + m0 = 16 + k0 = 256 + n0 = 128 + + m1 = 127 + k1 = 88 + n1 = 64 + + # numpy calculate + np_x0 = np.random.uniform(1, 2, size=[2, 3, 4, 5, m0, k0]).astype(np.float32) + np_w0 = np.random.uniform(1, 2, size=[k0, n0]).astype(np.float32) + np_b0 = np.random.uniform(1, 5, size=[n0]).astype(np.float32) + + np_x1 = np.random.uniform(1, 2, size=[2, 3, 4, 5, m1, k1]).astype(np.float32) + np_w1 = np.random.uniform(1, 2, size=[k1, n1]).astype(np.float32) + np_b1 = np.random.uniform(1, 5, size=[n1]).astype(np.float32) + + except0 = np.matmul(np_x0, np_w0) + np_b0 + except1 = np.matmul(np_x1, np_w1) + np_b1 + + # ms calculate + x = [ms.Tensor(np_x0, dtype=mstype.bfloat16), ms.Tensor(np_x1, dtype=mstype.bfloat16)] + w = [ms.Tensor(np_w0, dtype=mstype.bfloat16), ms.Tensor(np_w1, dtype=mstype.bfloat16)] + b = [ms.Tensor(np_b0), ms.Tensor(np_b1)] + + res = gmm_v4_net(x, w, b, split_item=split_item, group_type=group_type) + + # compare + np.testing.assert_allclose(except0, res[0].float().asnumpy(), rtol=4e-3) + np.testing.assert_allclose(except1, res[1].float().asnumpy(), rtol=4e-3) + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_a16w8(mode): + """ + Feature: Test grouped_matmul + Description: semi_auto_parallel + Expectation: shape is as expected. + """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 3 + group_type = 0 + group_list_type = 0 + + m0 = 32 + k0 = 256 + n0 = 128 + e0 = 8 + group_list_np = [1, 3, 10, 14, 18, 22, 24, 30] # last value can be less than total token numbers. 
+ + # numpy calculate + np_x_all = np.random.uniform(-128, 127, size=[m0, k0]).astype(np.float16) + np_w_all = np.random.uniform(-128, 127, size=[e0, k0, n0]).astype(np.int8) + antiquant_scale0 = np.array(np.full([e0, n0], 0.01)).astype(np.float16) + antiquant_offset0 = np.array(np.full([e0, n0], 1)).astype(np.float16) + + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) + np_s = split_w(antiquant_scale0) + np_o = split_w(antiquant_offset0) + res_np = [np.matmul(x0, (w0 + o0) * s0) for x0, w0, s0, o0 in zip(np_x, np_w, np_s, np_o)] + except_np = np.concatenate(res_np, axis=0) + + # ms calculate + x = [ms.Tensor(np_x_all)] + w = [ms.Tensor(np_w_all)] + antiquant_scale = [ms.Tensor(antiquant_scale0)] + antiquant_offset = [ms.Tensor(antiquant_offset0)] + + b = None + scale = None + offset = None + pertoken_scale = None + group_list = ms.Tensor(group_list_np, dtype=mstype.int64) + + res = gmm_v4_net(x, w, b, scale, offset, antiquant_scale, antiquant_offset, pertoken_scale, group_list, + split_item, group_type, group_list_type) + + # compare + np.testing.assert_allclose(except_np, res[0][:30].asnumpy(), rtol=1e-3) + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_a16w4(mode): + """ + Feature: Test grouped_matmul + Description: semi_auto_parallel + Expectation: shape is as expected. + """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 3 + group_type = 0 + group_list_type = 0 + + m0 = 32 + k0 = 256 + n0 = 128 + e0 = 8 + group_list_np = [1, 3, 10, 14, 18, 22, 24, 30] # last value can be less than total token numbers + + # numpy calculate + np_x_all = np.random.uniform(-128, 127, size=[m0, k0]).astype(np.float16) + np_w_all = np.random.uniform(0, 2, size=[e0, k0, n0]).astype(np.int8) + antiquant_scale0 = np.array(np.full([e0, n0], 0.01)).astype(np.float16) + antiquant_offset0 = np.array(np.full([e0, n0], 1)).astype(np.float16) + + for i in range(e0): + for j in range(k0): + for k in range(n0): + np_w_all[i, j, k] = np_w_all[i, j, k] & 0xf + + np_w_all_int4 = np.ones((e0 * k0 * n0 // 2,), dtype=np.int8) + np_w_all_one_rank = np_w_all.reshape(-1,) + for i in range(e0 * k0 * n0 // 2): + np_w_all_int4[i] = np_w_all_one_rank[i * 2] | ((np_w_all_one_rank[(i * 2) + 1] & 15) << 4) + + np_w_all_int4_3_rank = np_w_all_int4.reshape((e0, k0, n0 // 2)) + + np_x = split_x(np_x_all, group_list_np) + np_w = split_w(np_w_all) + np_s = split_w(antiquant_scale0) + np_o = split_w(antiquant_offset0) + res_np = [np.matmul(x0, (w0 + o0) * s0) for x0, w0, s0, o0 in zip(np_x, np_w, np_s, np_o)] + expect_np = np.concatenate(res_np, axis=0) + + # ms calculate + x = [ms.Tensor(np_x_all)] + w = [ms.Tensor(np_w_all_int4_3_rank, dtype=ms.qint4x2)] + antiquant_scale = [ms.Tensor(antiquant_scale0)] + antiquant_offset = [ms.Tensor(antiquant_offset0)] + + b = None + scale = None + offset = None + pertoken_scale = None + group_list = ms.Tensor(group_list_np, dtype=mstype.int64) + + res = gmm_v4_net(x, w, b, scale, offset, antiquant_scale, antiquant_offset, pertoken_scale, group_list, + split_item, group_type, group_list_type) + + # compare + np.testing.assert_allclose(expect_np, res[0][:30].asnumpy(), rtol=1e-3, atol=1e-3) + + 
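+# A vectorized sketch of the int4 packing done with explicit loops in the a16w4 test above:
+# two 4-bit values are packed into one int8 byte, low nibble first (illustrative helper only,
+# not used by the tests; relies on the module-level numpy import).
+def pack_int4_pairs(values):
+    """Pack adjacent pairs of 4-bit values (one per element) into int8 bytes, low nibble first."""
+    flat = values.reshape(-1).astype(np.uint8)
+    packed = (flat[0::2] & 0xF) | ((flat[1::2] & 0xF) << 4)
+    return packed.view(np.int8)
+
+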
+@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential')
+@pytest.mark.parametrize('mode', ['KBK', 'pynative'])
+def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_none_pertoken(mode):
+    """
+    Feature: Test grouped_matmul_v4
+    Description: single card, split_item=3, group_type=0, a8w8 with per-token quantization scale.
+    Expectation: results match the numpy reference.
+    """
+    context.set_context(device_target="Ascend")
+    if mode == 'KBK':
+        ms.set_context(mode=ms.GRAPH_MODE)
+        ms.set_context(jit_level='O0')
+    elif mode == 'pynative':
+        ms.set_context(mode=ms.PYNATIVE_MODE)
+    gmm_v4_net = GroupedMatmulV4Net()
+
+    split_item = 3
+    group_type = 0
+    group_list_type = 1
+
+    m0 = 32
+    k0 = 256
+    n0 = 128
+    e0 = 8
+    group_list_np = [1, 2, 7, 4, 4, 4, 2, 8]
+
+    # numpy calculate
+    np_x_all = np.random.uniform(-128, 127, size=[m0, k0]).astype(np.int8)
+    np_w_all = np.random.uniform(-128, 127, size=[e0, k0, n0]).astype(np.int8)
+    np_s_all = np.array(np.full([e0, n0], 10)).astype(np.float32)
+    np_pts = np.array([10] * m0).astype(np.float32)
+
+    np_x = split_x(np_x_all, np.cumsum(group_list_np))
+    np_w = split_w(np_w_all)
+    np_s = split_w(np_s_all)
+    res_np = [np.matmul(x0, w0 * s0) for x0, w0, s0 in zip(np_x, np_w, np_s)]
+    expect_np = np.concatenate(res_np, axis=0) * np_pts.reshape(m0, 1)
+
+    # ms calculate
+    x = [ms.Tensor(np_x_all)]
+    w = [ms.Tensor(np_w_all)]
+    scale = [ms.Tensor(np_s_all, dtype=mstype.bfloat16)]
+    pertoken_scale = [ms.Tensor(np_pts)]
+
+    b = None
+    offset = None
+    antiquant_scale = None
+    antiquant_offset = None
+    group_list = ms.Tensor(group_list_np, dtype=mstype.int64)
+
+    res = gmm_v4_net(x, w, b, scale, offset, antiquant_scale, antiquant_offset, pertoken_scale, group_list,
+                     split_item, group_type, group_list_type)
+
+    # compare
+    np.testing.assert_allclose(expect_np, res[0].float().asnumpy(), rtol=4e-3)
+
+
+@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential')
+@pytest.mark.parametrize('mode', ['KBK', 'pynative'])
+def test_grouped_matmul_v4_x2d_w3d_splititem3_grouptype0_none_perchannel(mode):
+    """
+    Feature: Test grouped_matmul_v4
+    Description: single card, split_item=3, group_type=0, a8w8 with per-channel scale and bias.
+    Expectation: results match the numpy reference.
+ """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 3 + group_type = 0 + group_list_type = 1 + + m0 = 32 + k0 = 256 + n0 = 128 + e0 = 8 + group_list_np = [1, 2, 7, 4, 4, 4, 2, 8] + + # numpy calculate + np_x_all = np.random.uniform(-128, 127, size=[m0, k0]).astype(np.int8) + np_w_all = np.random.uniform(-128, 127, size=[e0, k0, n0]).astype(np.int8) + np_s_all = np.array(np.full([e0, n0], 10)).astype(np.float32) + np_b_all = np.array(np.full([e0, n0], 1)).astype(np.float32) + + np_x = split_x(np_x_all, np.cumsum(group_list_np)) + np_w = split_w(np_w_all) + np_s = split_w(np_s_all) + np_b = split_w(np_b_all) + res_np = [np.matmul(x0, w0 * s0) + b0 * s0 for x0, w0, s0, b0 in zip(np_x, np_w, np_s, np_b)] + except_np = np.concatenate(res_np, axis=0) + + # ms calculate + x = [ms.Tensor(np_x_all)] + w = [ms.Tensor(np_w_all)] + scale = [ms.Tensor(np_s_all, dtype=mstype.bfloat16)] + bias = [ms.Tensor(np_b, dtype=mstype.int32)] + + offset = None + antiquant_scale = None + antiquant_offset = None + group_list = ms.Tensor(group_list_np, dtype=mstype.int64) + + res = gmm_v4_net(x, w, bias, scale, offset, antiquant_scale, antiquant_offset, None, group_list, + split_item, group_type, group_list_type) + + # compare + np.testing.assert_allclose(except_np, res[0].float().asnumpy(), rtol=4e-3) + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +def test_ops_grouped_matmul_v4_multi_dyn(mode): + """ + Feature: Pyboost function. + Description: Test GroupedMatmulV4 forward with dynamic rank/shape. + Expectation: Success. 
+ """ + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 0 + group_type = -1 + group_list_type = 0 + weight_format = "ND" + + x = ms.mutable([Tensor(shape=(None, None), dtype=mstype.float16), Tensor(shape=(None, None), dtype=mstype.float16)]) + weight = ms.mutable([Tensor(shape=(None, None), dtype=mstype.float16), + Tensor(shape=(None, None), dtype=mstype.float16)]) + gmm_v4_net.set_inputs(x, weight, None, None, None, None, None, None, None, split_item, + group_type, group_list_type, weight_format, None) + + np_x0 = np.random.uniform(0.1, 2, size=[16, 256]).astype(np.float32) + np_w0 = np.random.uniform(0.1, 1, size=[256, 128]).astype(np.float32) + expect0 = np.matmul(np_x0, np_w0) + + np_x1 = np.random.uniform(0.1, 2, size=[127, 88]).astype(np.float32) + np_w1 = np.random.uniform(0.1, 1, size=[88, 64]).astype(np.float32) + expect1 = np.matmul(np_x1, np_w1) + + x1 = ms.mutable([ms.Tensor(np_x0, dtype=mstype.float16), ms.Tensor(np_x1, dtype=mstype.float16)]) + weight1 = ms.mutable([ms.Tensor(np_w0, dtype=mstype.float16), ms.Tensor(np_w1, dtype=mstype.float16)]) + res1 = gmm_v4_net(x1, weight1, split_item=split_item, group_type=group_type) + np.testing.assert_allclose(expect0, res1[0].asnumpy(), rtol=1e-1) + np.testing.assert_allclose(expect1, res1[1].asnumpy(), rtol=1e-1) + + x2 = ms.mutable([ms.Tensor(np_x0, dtype=mstype.float16), ms.Tensor(np_x1, dtype=mstype.float16)]) + weight2 = ms.mutable([ms.Tensor(np_w0, dtype=mstype.float16), ms.Tensor(np_w1, dtype=mstype.float16)]) + res2 = gmm_v4_net(x2, weight2, split_item=split_item, group_type=group_type) + np.testing.assert_allclose(expect0, res2[0].asnumpy(), rtol=1e-1) + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +@pytest.mark.parametrize('output_dtype', [ms.float16, ms.bfloat16]) +def test_ops_grouped_mamtul_v4_a8w4(mode, output_dtype): + """ + Feature: pyboost function. + Description: test GroupedMatmulV4 forward with a8w4. + Expectation: success. 
+ """ + np.random.seed(1) + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + e = 8 + m = 32 + k = 256 + n = 128 + split_item = 3 + group_type = 0 + group_list_type = 1 + + x_np = np.random.randint(-5, 5, size=(m, k)).astype(np.int8) + w_np = np.random.randint(-5, 5, size=(e, k, n)).astype(np.int8) + w_int4_np = w_np.reshape(-1) & 0x000F + w_int4_np = w_int4_np[0::2] | (w_int4_np[1::2] << 4) + w_int4_np = w_int4_np.reshape(e, k, n // 2) + scale_np = np.random.normal(0, 0.01, size=(e, 1, n)).astype(np.float32) + scale_np_uint64 = np.frombuffer(scale_np.tobytes(), dtype=np.uint32).astype(np.uint64).reshape(e, 1, n) + bias_np = 8 * (w_np.astype(np.float32) * scale_np).sum(axis=1) + pertoken_scale_np = np.random.normal(0, 0.01, (m, 1)).astype(np.float32) + group_list_np = np.array([1, 2, 7, 4, 4, 4, 2, 8], dtype=np.int64) + + index = np.cumsum(group_list_np) + x_np_split = np.split(x_np, index, axis=0) + pertoken_scale_np_split = np.split(pertoken_scale_np, index, axis=0) + out_list = [] + scale_fp32 = scale_np_uint64.astype(np.uint32) + scale_fp32.dtype = np.float32 + for i in range(e): + mm = np.matmul(x_np_split[i].astype(np.int32), w_np[i].astype(np.int32)).astype(np.float32) + mm = mm * scale_fp32[i] * pertoken_scale_np_split[i] + out_list.append(mm) + expect = np.concatenate(out_list, axis=0) + + x = [Tensor(x_np, ms.int8)] + weight = [Tensor(w_int4_np, dtype=ms.qint4x2)] + bias = [Tensor(bias_np, ms.float32)] + scale = [Tensor(scale_np_uint64, ms.uint64)] + pertoken_scale = [Tensor(pertoken_scale_np, ms.float32)] + group_list = Tensor(group_list_np, ms.int64) + out = gmm_v4_net(x, weight, bias=bias, scale=scale, pertoken_scale=pertoken_scale, + group_list=group_list, split_item=split_item, group_type=group_type, + group_list_type=group_list_type, output_dtype=output_dtype)[0] + cnt = expect.shape[0] + np.testing.assert_allclose(expect.astype(np.float32), out[:cnt].astype(ms.float32).asnumpy(), rtol=5e-3, atol=5e-3) + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', ['KBK', 'pynative']) +@pytest.mark.parametrize('weight_format', ["ND", "FRACTAL_NZ"]) +def test_ops_grouped_mamtul_v4_fractal_nz(mode, weight_format): + """ + Feature: pyboost function. + Description: test GroupedMatmulV4 forward with fractal_nz. + Expectation: success. + """ + np.random.seed(1) + context.set_context(device_target="Ascend") + if mode == 'KBK': + ms.set_context(mode=ms.GRAPH_MODE) + ms.set_context(jit_level='O0') + elif mode == 'pynative': + ms.set_context(mode=ms.PYNATIVE_MODE) + gmm_v4_net = GroupedMatmulV4Net() + + split_item = 0 + group_type = -1 + + np_x0 = np.random.uniform(0.1, 2, size=[16, 256]).astype(np.float16) + np_w0 = np.random.uniform(0.1, 1, size=[256, 128]).astype(np.float16) + expect0 = np.matmul(np_x0, np_w0) + + ms_x0 = [ms.Tensor(np_x0, dtype=mstype.float16)] + if weight_format == "FRACTAL_NZ": + ms_w0 = [ms_custom_ops.trans_data(ms.Tensor(np_w0, dtype=mstype.float16), transdata_type=1)] + else: + ms_w0 = [ms.Tensor(np_w0, dtype=mstype.float16)] + + res = gmm_v4_net(ms_x0, ms_w0, split_item=split_item, group_type=group_type, weight_format=weight_format) + np.testing.assert_allclose(expect0, res[0].asnumpy(), rtol=1e-1)