diff --git a/ops/c_api/batch_matmul/batch_matmul.cc b/ops/c_api/batch_matmul/batch_matmul.cc new file mode 100644 index 0000000000000000000000000000000000000000..1961cceda3585df1819c48d6f2e464d6ca4241d1 --- /dev/null +++ b/ops/c_api/batch_matmul/batch_matmul.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include +#include +#include +#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h" +#include "ops/framework/utils.h" +#include "ops/c_api/utils/common_utils.h" + +namespace ms_custom_ops { + +constexpr size_t kBatchMatmulMatSize = 2; // Last 2 dimensions for matrix multiplication +constexpr size_t kBatchMatmulDims = 3; // Must be exactly 3D tensor: [batch, M, K] or [batch, K, N] + +enum class BatchMatmulInputIndex : size_t { + kBatchMatmulInputX1Index = 0, + kBatchMatmulInputX2Index, + kBatchMatmulInputTransposeAIndex, + kBatchMatmulInputTransposeBIndex, + kBatchMatmulInputCubeMathTypeIndex, + kBatchMatmulInputsNum, +}; + +enum class BatchMatmulOutputIndex : size_t { kBatchMatmulOutputIndex = 0 }; + +/** + * @brief Batch matrix multiplication operator function implementation for shape/type inference + * This class handles the shape and type inference logic for batch matmul 
operations in graph mode. + */ +class OPS_API BatchMatMulCustomOpFuncImpl : public OpFuncImpl { + public: + /** + * @brief Infer the output shape of batch matmul operation + * @param primitive The primitive object containing operator information + * @param input_infos List of input information including shapes + * @return ShapeArray containing the inferred output shape + */ + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + // Check transpose parameters first, before shape inference + // aclnnBatchMatMul does not support transpose, so we need to check early + bool transpose_a = input_infos[static_cast( + BatchMatmulInputIndex::kBatchMatmulInputTransposeAIndex)]->GetScalarValueWithCheck(); + bool transpose_b = input_infos[static_cast( + BatchMatmulInputIndex::kBatchMatmulInputTransposeBIndex)]->GetScalarValueWithCheck(); + CheckTransposeUnsupported(transpose_a, transpose_b, "BatchMatMul"); + + auto x1_shape = + input_infos[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX1Index)]->GetShape(); + auto x2_shape = + input_infos[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX2Index)]->GetShape(); + + if (IsDynamicRank(x1_shape) || IsDynamicRank(x2_shape)) { + return {ShapeVector({abstract::Shape::kShapeRankAny})}; + } + + // Check that inputs must be exactly 3D tensors + // In Graph mode, MS_LOG(EXCEPTION) will be converted to Python exception by MindSpore framework + if (x1_shape.size() != kBatchMatmulDims || x2_shape.size() != kBatchMatmulDims) { + MS_LOG(EXCEPTION) << "For 'BatchMatMul', inputs must be exactly 3D tensors, " + << "but got x1 with " << x1_shape.size() << "D and x2 with " << x2_shape.size() << "D"; + } + + // Since transpose_a and transpose_b are guaranteed to be False at this point, + // we can pass false to BatchMatMulMakeShape + ShapeVector out_shape = + BatchMatMulMakeShape(x1_shape, x2_shape, false, false, kBatchMatmulMatSize, "BatchMatMul"); + return {out_shape}; + } + + /** + 
* @brief Infer the output data type of batch matmul operation + * @param primitive The primitive object containing operator information + * @param input_infos List of input information including data types + * @return std::vector containing the inferred output data type + */ + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + // Check data types of both inputs + TypeId x1_type = input_infos[static_cast( + BatchMatmulInputIndex::kBatchMatmulInputX1Index)]->GetType(); + TypeId x2_type = input_infos[static_cast( + BatchMatmulInputIndex::kBatchMatmulInputX2Index)]->GetType(); + + // Check if data types are supported + CheckBatchMatMulDataType(x1_type, "BatchMatMul"); + CheckBatchMatMulDataType(x2_type, "BatchMatMul"); + + // Check if both inputs have the same type + if (x1_type != x2_type) { + MS_LOG(EXCEPTION) << "For 'BatchMatMul', inputs x1 and x2 must have the same data type, " + << "but got x1 with " << TypeIdToString(x1_type) + << " and x2 with " << TypeIdToString(x2_type); + } + + // Use the first input's type as output type + return {x1_type}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +/** + * @brief Ascend kernel implementation for batch matrix multiplication using aclnnBatchMatMul + * This class provides the kernel execution interface for batch matmul operations on Ascend devices. 
+ */ +class BatchMatMulCustomAscend : public AclnnCustomKernelMod { + public: + BatchMatMulCustomAscend() : AclnnCustomKernelMod("aclnnBatchMatMul") {} + ~BatchMatMulCustomAscend() = default; + + /** + * @brief Launch the batch matmul kernel execution + * @param inputs Input tensors including x1, x2, transpose flags and cube math type + * @param workspace Workspace memory for kernel execution + * @param outputs Output tensors + * @param stream_ptr Stream pointer for asynchronous execution + * @return bool indicating success or failure + */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + MS_EXCEPTION_IF_NULL(stream_ptr); + bool transpose_a = inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputTransposeAIndex)] + ->GetValueWithCheck(); + bool transpose_b = inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputTransposeBIndex)] + ->GetValueWithCheck(); + + // aclnnBatchMatMul does not support transpose, so we need to check + CheckTransposeUnsupported(transpose_a, transpose_b, "BatchMatMul"); + + int8_t cube_math_type = static_cast( + inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputCubeMathTypeIndex)] + ->GetValueWithCheck()); + RunOp(stream_ptr, workspace, inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX1Index)], + inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX2Index)], + outputs[static_cast(BatchMatmulOutputIndex::kBatchMatmulOutputIndex)], cube_math_type); + return true; + } + + /** + * @brief Get workspace information required for kernel execution + * @param inputs Input tensors + * @param outputs Output tensors + */ + void GetWorkSpaceInfo(const std::vector &inputs, + const std::vector &outputs) override { + int8_t cube_math_type = static_cast( + inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputCubeMathTypeIndex)] + ->GetValueWithCheck()); + 
GetWorkspaceForResize(inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX1Index)], + inputs[static_cast(BatchMatmulInputIndex::kBatchMatmulInputX2Index)], + outputs[static_cast(BatchMatmulOutputIndex::kBatchMatmulOutputIndex)], + cube_math_type); + } + + private: + DEFINE_GET_WORKSPACE_FOR_RESIZE(); +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(batch_matmul, ms_custom_ops::BatchMatMulCustomOpFuncImpl, + ms_custom_ops::BatchMatMulCustomAscend); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +/** + * @brief PyBoost mode implementation of batch matrix multiplication + * @param x1 First input tensor with shape [..., M, K] + * @param x2 Second input tensor with shape [..., K, N] + * @param transpose_a Whether to transpose x1 (must be false for aclnnBatchMatMul) + * @param transpose_b Whether to transpose x2 (must be false for aclnnBatchMatMul) + * @param cube_math_type Cube math type for precision control + * @return ms::Tensor Result tensor with shape [..., M, N] + */ +ms::Tensor batch_matmul_custom(const ms::Tensor &x1, const ms::Tensor &x2, bool transpose_a, bool transpose_b, + int64_t cube_math_type) { + // Check transpose parameters first, before shape inference + // aclnnBatchMatMul does not support transpose, so we need to check early + CheckTransposeUnsupported(transpose_a, transpose_b, "BatchMatMul"); + + // Check data types of both inputs + TypeId x1_type = x1.data_type(); + TypeId x2_type = x2.data_type(); + + // Check if data types are supported + CheckBatchMatMulDataType(x1_type, "BatchMatMul"); + CheckBatchMatMulDataType(x2_type, "BatchMatMul"); + + // Check if both inputs have the same type + if (x1_type != x2_type) { + MS_EXCEPTION(ValueError) << "For 'BatchMatMul', inputs x1 and 
x2 must have the same data type, " + << "but got x1 with " << TypeIdToString(x1_type) + << " and x2 with " << TypeIdToString(x2_type); + } + + auto x1_shape = x1.shape(); + auto x2_shape = x2.shape(); + + // Check that inputs must be exactly 3D tensors + if (x1_shape.size() != kBatchMatmulDims || x2_shape.size() != kBatchMatmulDims) { + MS_EXCEPTION(ValueError) << "For 'BatchMatMul', inputs must be exactly 3D tensors, " + << "but got x1 with " << x1_shape.size() << "D and x2 with " << x2_shape.size() << "D"; + } + + // Since transpose_a and transpose_b are guaranteed to be False at this point, + // we can pass false to BatchMatMulMakeShape + auto output_shape = BatchMatMulMakeShape(x1_shape, x2_shape, false, false, kBatchMatmulMatSize, "BatchMatMul"); + TypeId out_dtype = x1.data_type(); + auto out = ms::Tensor(out_dtype, output_shape); + + int8_t cube_math_type_int8 = static_cast(cube_math_type); + auto runner = std::make_shared("BatchMatMul"); + runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnBatchMatMul, x1, x2, out, cube_math_type_int8)); + runner->Run({x1, x2}, {out}); + return out; +} +} // namespace ms_custom_ops + +auto pyboost_batch_matmul(const ms::Tensor &x1, const ms::Tensor &x2, bool transpose_a, bool transpose_b, + int64_t cube_math_type) { + return ms::pynative::PyboostRunner::Call(ms_custom_ops::batch_matmul_custom, x1, x2, + transpose_a, transpose_b, cube_math_type); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("batch_matmul", &pyboost_batch_matmul, "BatchMatMul", pybind11::arg("x1"), pybind11::arg("x2"), + pybind11::arg("transpose_a") = false, pybind11::arg("transpose_b") = false, + pybind11::arg("cube_math_type") = 0); +} + diff --git a/ops/c_api/batch_matmul/batch_matmul.md b/ops/c_api/batch_matmul/batch_matmul.md new file mode 100644 index 0000000000000000000000000000000000000000..8d27950f68c4ce7986454893a178fdda658ceb97 --- /dev/null +++ b/ops/c_api/batch_matmul/batch_matmul.md @@ -0,0 +1,182 @@ +# batch_matmul 算子 + +## 描述 + +batch_matmul 
算子用于执行批量矩阵乘法操作。该算子支持3维的Tensor输入,第一维是batch维度,最后两个维度做矩阵乘法。也支持其中一个输入的batch轴为1时做broadcast。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|----------------|--------------------|-----------|----------|---------|--------|------------------------------------------------------| +| x1 | Tensor(float16/bfloat16/float32) | 3维 | No | No | ND | 第一个输入矩阵,shape为[A, M, K]。数据类型支持说明见下方 | +| x2 | Tensor(float16/bfloat16/float32) | 3维 | No | No | ND | 第二个输入矩阵,shape为[B, K, N]。 | +| transpose_a | bool | - | Yes | - | - | 是否对 x1(A 矩阵)进行转置,默认为 False。**注意:aclnnBatchMatMul 不支持 transpose,此参数必须为 False,如果为 True 会在 shape inference 阶段抛出异常** | +| transpose_b | bool | - | Yes | - | - | 是否对 x2(B 矩阵)进行转置,默认为 False。**注意:aclnnBatchMatMul 不支持 transpose,此参数必须为 False,如果为 True 会在 shape inference 阶段抛出异常** | +| cube_math_type | int | - | Yes | - | - | Cube单元的计算逻辑,控制矩阵乘法的精度模式,默认为0。详细说明见下方"cube_math_type 参数说明" | + +## cube_math_type 参数说明 + +`cube_math_type` 参数用于控制 Ascend 芯片上 Cube 单元(矩阵乘法加速单元)的计算精度模式。不同取值对应不同的精度策略: + +- **0 (KEEP_DTYPE)**:保持输入数据类型,不进行精度转换。这是默认值,推荐在大多数场景下使用。算子会按照输入张量的原始数据类型进行计算,不会自动进行精度转换。 + +- **1 (ALLOW_FP32_DOWN_PRECISION)**:允许将 FP32 降精度到 FP16 进行计算,以提高性能。当输入为 FP32 类型时,Cube 单元会将其转换为 FP16 进行计算,从而获得更好的性能,但可能会损失一定的精度。适用于对精度要求不高的场景,或者需要平衡性能和精度的场景。 + +- **2 (USE_FP16 / FORCE_FP16)**:强制使用 FP16 精度进行计算,无论输入数据类型如何。即使输入是其他数据类型,也会强制转换为 FP16 进行计算。适用于需要统一使用 FP16 的场景,或者需要最大化性能的场景。 + +- **3 (USE_HF32 / FORCE_HF32)**:强制使用 HF32(Half Float 32)精度进行计算。HF32 是 Ascend 芯片特有的高精度浮点格式,提供比 FP16 更高的精度(接近 FP32),同时保持比 FP32 更好的性能。适用于对精度要求较高,但又需要一定性能的场景。 + +**使用建议**: + +- 默认情况下使用 `0 (KEEP_DTYPE)`,这样可以保持计算的准确性。 +- 如果对性能要求较高且可以接受一定的精度损失,可以使用 `1 (ALLOW_FP32_DOWN_PRECISION)` 或 `2 (USE_FP16)`。 +- 如果对精度要求较高,可以使用 `3 (USE_HF32)`,它在精度和性能之间提供了良好的平衡。 + +## 输出参数 + +| Name | DType | Shape | Description | +|------|--------------------|------------------------|------------------| +| y | Tensor(float16/bfloat16/float32) | 符合批量矩阵乘法规则的形状 | 批量矩阵乘法的计算结果。**注意:输出数据类型与输入 x1 的数据类型相同** | + +## 计算公式 + +out = x1 @ x2 + +其中: + +- 
x1的shape是[A, M, K],x2的shape是[A, K, N],输出out的shape是[A, M, N] +- 第一维相等,后两维做矩阵乘运算 +- 如果x1的shape是[A, M, K],x2的shape是[1, K, N],输出out的shape是[A, M, N](B矩阵第一维为1,会broadcast到A) +- 如果x1的shape是[1, M, K],x2的shape是[B, K, N],输出out的shape是[B, M, N](A矩阵第一维为1,会broadcast到B) + +## 支持产品 + +- Atlas A3 训练系列产品、Atlas A3 推理系列产品 +- Atlas A2 训练系列产品、Atlas A2 推理系列产品 +- Atlas 训练系列产品 +- Atlas 推理系列产品 + +## 约束说明 + +1. 仅支持3维的Tensor传入,第一维是batch维度,最后两个维度做矩阵乘法 +2. 支持其中一个输入的batch轴为1时做broadcast +3. **数据类型限制**:aclnnBatchMatMul 支持的数据类型取决于芯片型号: + - **Atlas A2 / Atlas A3 系列**:支持 float16、bfloat16 和 float32 + - **其他芯片(如 Atlas 910 等)**:仅支持 float16 + + 如果需要在不支持的芯片上使用其他数据类型,请使用 MindSpore 内置的 BatchMatMul 算子。详细说明见下方"数据类型支持说明" +4. mat2的Reduce维度需要与self的Reduce维度大小相等 +5. **注意**:aclnnBatchMatMul 不支持 transpose 参数。如果需要在调用前转置输入,请先使用 transpose 算子对输入进行转置,然后再调用 batch_matmul(此时 transpose_a 和 transpose_b 应设置为 False)。 + +### 数据类型支持说明 + +`aclnnBatchMatMul` API 在不同芯片上的数据类型支持情况: + +- **Atlas A2/A3 训练系列产品、Atlas A2/A3 推理系列产品**: + - ✅ 支持 **float16**:标准的半精度浮点数,提供良好的性能和精度平衡 + - ✅ 支持 **bfloat16**:Brain Float 16,提供比 float16 更大的数值范围,在某些场景下精度更好 + - ✅ 支持 **float32**:单精度浮点数,提供最高精度,但性能相对较低 + +- **其他芯片(Atlas 910 等)**: + - ✅ 支持 **float16**:标准的半精度浮点数 + - ❌ 不支持 **bfloat16**:如果使用会抛出异常 + - ❌ 不支持 **float32**:如果使用会抛出异常 + +**使用建议**: + +- 在 A2/A3 芯片上,可以根据精度和性能需求选择合适的数据类型 +- 在其他芯片上,只能使用 float16,如果传入其他数据类型会抛出异常 +- 如果需要跨芯片兼容,建议统一使用 float16 + +### 关于 transpose_a 和 transpose_b 参数的说明 + +**为什么这两个参数必须为 False,但还要保留它们?** + +虽然 `aclnnBatchMatMul` API 不支持 transpose 操作,但 `batch_matmul` 算子仍然保留了 `transpose_a` 和 `transpose_b` 参数,原因如下: + +1. **API 兼容性**:MindSpore 内置的 `BatchMatMul` 算子也包含这两个参数。保留它们可以: + - 保持与 MindSpore 内置算子的 API 一致性 + - 方便用户从内置算子迁移到自定义算子,无需修改调用代码 + - 提供统一的接口规范 + +2. **早期错误检测**:在 shape inference 阶段就会检查这两个参数。如果用户传入 `transpose_a=True` 或 `transpose_b=True`,会在图构建阶段就抛出异常,而不是等到运行时才发现错误,这样可以更早地发现问题。 + +3. 
**未来扩展性**:如果将来 `aclnnBatchMatMul` API 支持 transpose 操作,可以更容易地启用这些参数,而无需修改算子接口。 + +**使用建议**: + +- 如果需要对输入进行转置,请先使用 MindSpore 的 `transpose` 算子对输入进行转置,然后再调用 `batch_matmul`(此时 `transpose_a` 和 `transpose_b` 应设置为 `False`) +- 如果直接传入 `transpose_a=True` 或 `transpose_b=True`,算子会在 shape inference 阶段抛出异常,提示用户先转置输入 + +## 使用示例 + +### 基本使用示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_context(device_target="Ascend") + +@ms.jit +def batch_matmul_func(x1, x2, transpose_a=False, transpose_b=False, cube_math_type=0): + return ms_custom_ops.batch_matmul(x1, x2, transpose_a=transpose_a, transpose_b=transpose_b, + cube_math_type=cube_math_type) + +# 示例1:基础批量矩阵乘法 +# x1 shape: [A, M, K] = [2, 128, 256] +# x2 shape: [A, K, N] = [2, 256, 128] +# 输出 shape: [2, 128, 128] +batch = 2 +m = 128 +k = 256 +n = 128 +x1 = np.random.randn(batch, m, k).astype(np.float16) +x2 = np.random.randn(batch, k, n).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) +output = batch_matmul_func(ms_x1, ms_x2) +print("Output shape:", output.shape) +``` + +### Broadcast 示例 + +```python +# 示例2:Broadcast batch维度 +# x1 shape: [A, M, K] = [2, 128, 256] +# x2 shape: [1, K, N] = [1, 256, 128] (batch维度为1,会broadcast) +# 输出 shape: [2, 128, 128] +x1 = np.random.randn(2, 128, 256).astype(np.float16) +x2 = np.random.randn(1, 256, 128).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) +output = batch_matmul_func(ms_x1, ms_x2) +print("Output shape:", output.shape) +``` + +### 转置示例 + +```python +# 示例3:使用转置(需要先转置输入,然后设置 transpose_a=False, transpose_b=False) +# 原始 x1 shape: [2, 256, 128],需要转置为 [2, 128, 256] +# 原始 x2 shape: [2, 128, 256],需要转置为 [2, 256, 128] +# 输出 shape: [2, 128, 128] +x1 = np.random.randn(2, 256, 128).astype(np.float16) +x2 = np.random.randn(2, 128, 256).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) + +# 方法1:使用 transpose 算子先转置 +x1_transposed = ms.ops.transpose(ms_x1, (0, 2, 1)) # [2, 128, 256] +x2_transposed = 
ms.ops.transpose(ms_x2, (0, 2, 1)) # [2, 256, 128] +output = batch_matmul_func(x1_transposed, x2_transposed, transpose_a=False, transpose_b=False) +print("Output shape:", output.shape) + +# 注意:以下代码会抛出异常,因为 aclnnBatchMatMul 不支持 transpose +# output = batch_matmul_func(ms_x1, ms_x2, transpose_a=True, transpose_b=True) # 会报错! +``` + diff --git a/ops/c_api/batch_matmul/batch_matmul_op.yaml b/ops/c_api/batch_matmul/batch_matmul_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e12848cdb59a983e2b1441282e4d34d90742cb2 --- /dev/null +++ b/ops/c_api/batch_matmul/batch_matmul_op.yaml @@ -0,0 +1,22 @@ +#operator batch_matmul +batch_matmul: + args: + x1: + dtype: tensor + x2: + dtype: tensor + transpose_a: + dtype: bool + default: false + transpose_b: + dtype: bool + default: false + cube_math_type: + dtype: int + default: 0 # 0: KEEP_DTYPE, 1: ALLOW_FP32_DOWN_PRECISION, 2: USE_FP16, 3: USE_HF32 + args_signature: + dtype_group: (x1, x2) + returns: + y: + dtype: tensor + diff --git a/ops/c_api/utils/common_utils.cc b/ops/c_api/utils/common_utils.cc index 4d23a84dd88cc57627b1e827be5c88eee194c980..2d325073321e04f1ca5dd8f8d96221a1011eb026 100644 --- a/ops/c_api/utils/common_utils.cc +++ b/ops/c_api/utils/common_utils.cc @@ -57,12 +57,126 @@ static bool CheckSocVersion(const std::string &expected_soc_version) { return !version.empty() && (version == expected_soc_version); } -bool IsSoc910b() { return CheckSocVersion(kAscendVersion910b); } +bool IsSoc910b() { return CheckSocVersion("ascend910b"); } -bool IsSoc910_93() { return CheckSocVersion(kAscendVersion910_93); } +bool IsSoc910_93() { return CheckSocVersion("ascend910_93"); } -bool IsSoc310p() { return CheckSocVersion(kAscendVersion310p); } -bool IsSoc910a() { return CheckSocVersion(kAscendVersion910); } +bool IsSoc310p() { return CheckSocVersion("ascend310p"); } +bool IsSoc910a() { return CheckSocVersion("ascend910"); } bool IsSoc910BC() { return IsSoc910b() || IsSoc910_93(); } + +/** + * @brief 
Calculate output shape for batch matrix multiplication + * @param x1_shape Shape of first input tensor + * @param x2_shape Shape of second input tensor + * @param transpose_a Whether first tensor is transposed + * @param transpose_b Whether second tensor is transposed + * @param offset Number of dimensions to consider for matrix multiplication (typically 2) + * @param op_name Operator name for error messages + * @return ShapeVector Output shape vector + */ +ShapeVector BatchMatMulMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, + bool transpose_a, bool transpose_b, size_t offset, + const std::string &op_name) { + constexpr size_t kMatSize = 2; // Last 2 dimensions for matrix multiplication + if (x1_shape.size() < kMatSize || x2_shape.size() < kMatSize) { + MS_LOG(EXCEPTION) << "For '" << op_name << "', the dimension of 'x1' and 'x2' should be at least 2, " + << "but got " << x1_shape.size() << " and " << x2_shape.size(); + } + + ShapeVector out_shape; + ShapeVector long_shape = x1_shape.size() > x2_shape.size() ? x1_shape : x2_shape; + ShapeVector short_shape = x1_shape.size() > x2_shape.size() ? 
x2_shape : x1_shape; + size_t size_diff = long_shape.size() - short_shape.size(); + + // Handle batch dimensions (all dimensions except the last 2) + for (size_t i = 0; i < long_shape.size() - offset; i++) { + if (long_shape[i] < 0) { + out_shape.push_back(abstract::Shape::kShapeDimAny); + } else if (i >= size_diff) { + // Broadcast: if one dimension is 1, use the other; otherwise they must be equal + int64_t long_dim = long_shape[i]; + int64_t short_dim = short_shape[i - size_diff]; + if (long_dim == 1) { + out_shape.push_back(short_dim); + } else if (short_dim == 1) { + out_shape.push_back(long_dim); + } else if (long_dim == short_dim) { + out_shape.push_back(long_dim); + } else { + MS_LOG(EXCEPTION) << "For '" << op_name << "', batch dimensions must be equal or one of them must be 1, " + << "but got " << long_dim << " and " << short_dim; + } + } else { + out_shape.push_back(long_shape[i]); + } + } + + // Handle matrix dimensions (last 2 dimensions) + size_t x1_offset = x1_shape.size() - offset; + size_t x2_offset = x2_shape.size() - offset; + + // Output shape: [..., M, N] where M comes from x1 and N comes from x2 + int64_t x1_m = transpose_a ? x1_shape[x1_offset + 1] : x1_shape[x1_offset]; + int64_t x1_k = transpose_a ? x1_shape[x1_offset] : x1_shape[x1_offset + 1]; + int64_t x2_k = transpose_b ? x2_shape[x2_offset + 1] : x2_shape[x2_offset]; + int64_t x2_n = transpose_b ? 
x2_shape[x2_offset] : x2_shape[x2_offset + 1]; + + // Check K dimensions match + if (x1_k != abstract::Shape::kShapeDimAny && x2_k != abstract::Shape::kShapeDimAny && x1_k != x2_k) { + MS_LOG(EXCEPTION) << "For '" << op_name << "', the K dimension of 'x1' and 'x2' must be equal, " + << "but got " << x1_k << " and " << x2_k; + } + + out_shape.push_back(x1_m); + out_shape.push_back(x2_n); + + return out_shape; +} + +/** + * @brief Check if transpose operations are unsupported and throw exception if they are + * @param transpose_a Whether first tensor is transposed + * @param transpose_b Whether second tensor is transposed + * @param op_name Operator name for error messages + */ +void CheckTransposeUnsupported(bool transpose_a, bool transpose_b, const std::string &op_name) { + if (transpose_a || transpose_b) { + std::string error_msg = "For '" + op_name + "' with aclnnBatchMatMul, transpose_a and transpose_b must be False. " + "Please transpose the input tensors before calling this operator. " + "Got transpose_a=" + (transpose_a ? "True" : "False") + + ", transpose_b=" + (transpose_b ? "True" : "False") + "."; + MS_LOG(EXCEPTION) << error_msg; + } +} + +/** + * @brief Check if data type is supported for batch_matmul operation + * @param dtype Data type to check + * @param op_name Operator name for error messages + * + * Atlas A2 series support float16, bfloat16, and float32 + * Other chips (like Atlas A3, Atlas 910, etc.) 
only support float16 + */ +void CheckBatchMatMulDataType(TypeId dtype, const std::string &op_name) { + // Check if it's A2 series chip + bool is_a2_series = IsSoc910b() || IsSoc910_93(); + + if (is_a2_series) { + // A2 series support float16, bfloat16, and float32 + if (dtype != TypeId::kNumberTypeFloat16 && dtype != TypeId::kNumberTypeBFloat16 && + dtype != TypeId::kNumberTypeFloat32) { + MS_LOG(EXCEPTION) << "For '" << op_name << "' on Atlas A2 series, only float16, bfloat16, and float32 are " + << "supported, but got " << TypeIdToString(dtype); + } + } else { + // Other chips only support float16 + if (dtype != TypeId::kNumberTypeFloat16) { + MS_LOG(EXCEPTION) << "For '" << op_name << "' on this chip, only float16 is supported, " + << "but got " << TypeIdToString(dtype) << ". " + << "Please use MindSpore built-in BatchMatMul operator for other data types."; + } + } +} } // namespace ms_custom_ops diff --git a/ops/c_api/utils/common_utils.h b/ops/c_api/utils/common_utils.h index c4d6468c93d585db60d542da07badb966db3df2a..b4a55048853abc27c82ef0099377f848f8d1aab8 100644 --- a/ops/c_api/utils/common_utils.h +++ b/ops/c_api/utils/common_utils.h @@ -28,6 +28,22 @@ bool IsSoc910_93(); bool IsSoc310p(); bool IsSoc910a(); bool IsSoc910BC(); + +// Common function for batch matrix multiplication shape inference +// Used by both batch_matmul and quant_batch_matmul operators +ShapeVector BatchMatMulMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, + bool transpose_a, bool transpose_b, size_t offset, + const std::string &op_name = "BatchMatMul"); + +// Common function to check if transpose is unsupported +// Used by batch_matmul operator to ensure transpose_a and transpose_b are False +// since aclnnBatchMatMul does not support transpose operations +void CheckTransposeUnsupported(bool transpose_a, bool transpose_b, const std::string &op_name = "BatchMatMul"); + +// Common function to check if data type is supported for batch_matmul +// Atlas A2 series support 
float16, bfloat16, and float32 +// Other chips (like Atlas A3, Atlas 910, etc.) only support float16 +void CheckBatchMatMulDataType(TypeId dtype, const std::string &op_name = "BatchMatMul"); } // namespace ms_custom_ops #endif // __MS_CUSTOM_OPS_C_API_UTILS_COMMON_UTILS_H_ diff --git a/tests/st/test_custom_batch_matmul.py b/tests/st/test_custom_batch_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..dceef5f635c8239c0ff71f5d3f31ed1f32625ee1 --- /dev/null +++ b/tests/st/test_custom_batch_matmul.py @@ -0,0 +1,296 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +""" tests batch_matmul """ + +from functools import wraps + +import numpy as np +import pytest +import mindspore as ms +from mindspore import Tensor, context +from st_utils import custom_compare +import ms_custom_ops + + +def jit(func): + @wraps(func) + def decorator(*args, **kwargs): + if ms.get_context("mode") == ms.PYNATIVE_MODE: + return func(*args, **kwargs) + return ms.jit(func, jit_level="O0", infer_boost="on")(*args, **kwargs) + return decorator + + +class BatchMatmulNet(ms.nn.Cell): + def __init__(self): + super().__init__() + self.batch_matmul = ms_custom_ops.batch_matmul + + @jit + def construct(self, x1, x2, transpose_a=False, transpose_b=False, cube_math_type=0): + out = self.batch_matmul(x1, x2, transpose_a, transpose_b, cube_math_type) + return out + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_basic(exec_mode, mstype): + """ + Feature: Test batch_matmul basic functionality. + Description: Test batch_matmul operation with 3D tensors. + Expectation: Assert that results are consistent with expected. + """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 2 + m = 128 + k = 256 + n = 128 + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, "batch_matmul compare fail." 
+ + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_broadcast(exec_mode, mstype): + """ + Feature: Test batch_matmul with broadcast. + Description: Test batch_matmul operation when one input has batch dimension 1. + Expectation: Assert that results are consistent with expected. + """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 3 + m = 64 + k = 128 + n = 64 + # x1 has batch dimension 1, should broadcast to batch + x1 = np.random.randn(1, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, "batch_matmul broadcast compare fail." + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_broadcast_reverse(exec_mode, mstype): + """ + Feature: Test batch_matmul with broadcast (reverse direction). + Description: Test batch_matmul operation when x2 has batch dimension 1. + Expectation: Assert that results are consistent with expected. 
+ """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 3 + m = 64 + k = 128 + n = 64 + # x2 has batch dimension 1, should broadcast to batch + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(1, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, "batch_matmul broadcast reverse compare fail." + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +@pytest.mark.parametrize('cube_math_type', [0, 1]) +def test_custom_batch_matmul_cube_math_type(exec_mode, mstype, cube_math_type): + """ + Feature: Test batch_matmul with different cube_math_type. + Description: Test batch_matmul operation with different cube math types. + Expectation: Assert that results are consistent with expected. 
+ """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 2 + m = 128 + k = 256 + n = 128 + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype), cube_math_type=cube_math_type) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, f"batch_matmul cube_math_type={cube_math_type} compare fail." + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_large_batch(exec_mode, mstype): + """ + Feature: Test batch_matmul with large batch size. + Description: Test batch_matmul operation with larger batch dimension. + Expectation: Assert that results are consistent with expected. + """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 8 + m = 64 + k = 128 + n = 64 + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, "batch_matmul large batch compare fail." 
+ + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_different_shapes(exec_mode, mstype): + """ + Feature: Test batch_matmul with different matrix shapes. + Description: Test batch_matmul operation with various M, K, N dimensions. + Expectation: Assert that results are consistent with expected. + """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + batch = 2 + m = 32 + k = 64 + n = 96 + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, "batch_matmul different shapes compare fail." + + +@pytest.mark.level2 +@pytest.mark.env_onecard +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE]) +@pytest.mark.parametrize('mstype', [ms.float16]) +def test_custom_batch_matmul_dynamic_shape(exec_mode, mstype): + """ + Feature: Test batch_matmul with dynamic shape. + Description: Test batch_matmul operation with dynamic batch dimension (Graph mode only). + Expectation: Assert that results are consistent with expected. 
+ """ + ms.set_device("Ascend") + if exec_mode == context.GRAPH_MODE: + ms.set_context(mode=exec_mode, jit_syntax_level=ms.STRICT) + else: + ms.set_context(mode=exec_mode) + batch_matmul = BatchMatmulNet() + + m = 64 + k = 128 + n = 64 + + # Set dynamic input shape + # Note: set_inputs only accepts tensor parameters, scalar parameters (transpose_a, transpose_b, cube_math_type) + # are handled via default values in the construct method + x1_dyn = Tensor(shape=[None, m, k], dtype=mstype) + x2_dyn = Tensor(shape=[None, k, n], dtype=mstype) + # Only pass tensor parameters to set_inputs + batch_matmul.set_inputs(x1_dyn, x2_dyn) + + # Test with different batch sizes + for batch in [1, 2, 4]: + x1 = np.random.randn(batch, m, k).astype(np.float16) + x2 = np.random.randn(batch, k, n).astype(np.float16) + expected = np.matmul(x1.astype(np.float32), x2.astype(np.float32)) + + output = batch_matmul(Tensor(x1, mstype), Tensor(x2, mstype)) + output_np = output.astype(ms.float32).asnumpy() + + res = custom_compare(output_np, expected, mstype) + assert res, f"batch_matmul dynamic shape batch={batch} compare fail."