diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py
index 7d2e5042b780d2db4706118e3cfefe648b18d4ef..dddb3d909cf42d6e4a64bccf30dda0935b43c7c1 100644
--- a/inferrt/python/mrt/torch/fx_backend.py
+++ b/inferrt/python/mrt/torch/fx_backend.py
@@ -206,6 +206,10 @@ def _flatten_args(op: Op, node: Node) -> List[Argument]:
     kwargs = node.kwargs
     if not kwargs:
         return flat_args
+    # if kwargs has only one element, add the value to flat_args and return
+    if len(kwargs) == 1:
+        flat_args.append(next(iter(kwargs.values())))
+        return flat_args
     if not isinstance(node.target, OpOverloadPacket):
         raise RuntimeError(
             f"Unsupported node target for keyword only args: {node.target}"
diff --git a/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.cc b/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.cc
index 2faa7f5ba31ea6f264b49d99bbcc40c4528e656e..39165e9109e791f636f4b8969ee8dbaa2c177f58 100644
--- a/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.cc
+++ b/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.cc
@@ -47,5 +47,31 @@ aclDataType Convert(ir::DataType::Type dtype) {
   return ret;
 }
 
+template <typename T>
+aclScalar *CreateAclScalar(T val, aclDataType dtype) {
+  static const auto aclCreateScalar = GET_ACLNN_COMMON_META_FUNC(aclCreateScalar);
+  CHECK_IF_NULL(aclCreateScalar);
+  return aclCreateScalar(&val, dtype);
+}
+
+aclScalar *Convert(const ir::Value *value) {
+  if (value == nullptr) {
+    return nullptr;
+  }
+  if (value->IsInt()) {
+    return CreateAclScalar(value->ToInt(), ACL_INT64);
+  }
+  if (value->IsFloat()) {
+    return CreateAclScalar(value->ToFloat(), ACL_FLOAT);
+  }
+  if (value->IsDouble()) {
+    return CreateAclScalar(value->ToDouble(), ACL_DOUBLE);
+  }
+  if (value->IsBool()) {
+    return CreateAclScalar(value->ToBool(), ACL_BOOL);
+  }
+  LOG_EXCEPTION << "Invalid value: " << value;
+  return nullptr;
+}
 } // namespace ops
 } // namespace mrt
diff --git a/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.h b/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.h
index 4a2c16f45b4dfe0005ab5be7714e2a4430abb856..ae18e272b1599735c50f67d82a5db6c15029f4f4 100644
--- a/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.h
+++ b/inferrt/src/ops/ascend/aclnn/utils/aclnn_converter.h
@@ -37,6 +37,9 @@ namespace ops {
 // Convert dtype
 DA_API aclDataType Convert(ir::DataType::Type dtype);
 
+// Convert value to aclScalar
+DA_API aclScalar *Convert(const ir::Value *value);
+
 // Convert tensor
 inline aclTensor *Convert(const ir::TensorPtr &tensor) {
   static const auto aclCreateTensor = GET_ACLNN_COMMON_META_FUNC(aclCreateTensor);
diff --git a/mopt/include/mopt/Dialect/Mrt/Mrt.td b/mopt/include/mopt/Dialect/Mrt/Mrt.td
index 836a6d2dc41e25531f1309fd383231877bd6a0de..dcae9f4a8984275a081443b00f93d5e5ca18d17d 100644
--- a/mopt/include/mopt/Dialect/Mrt/Mrt.td
+++ b/mopt/include/mopt/Dialect/Mrt/Mrt.td
@@ -24,4 +24,4 @@
 include "mopt/Dialect/Mrt/MrtTypes.td"
 include "mopt/Dialect/Mrt/MrtConstantOps.td"
 include "mopt/Dialect/Mrt/MrtOps.td"
-#endif // MRT_DIALECT_MRT_TD
\ No newline at end of file
+#endif // MRT_DIALECT_MRT_TD
diff --git a/mopt/include/mopt/Dialect/Mrt/MrtConstantOps.td b/mopt/include/mopt/Dialect/Mrt/MrtConstantOps.td
index 87fb18af9c5bd455694c589f342da3bab6780efa..0c04e880cb5858ec5a727cb22537e7bc23d417cb 100644
--- a/mopt/include/mopt/Dialect/Mrt/MrtConstantOps.td
+++ b/mopt/include/mopt/Dialect/Mrt/MrtConstantOps.td
@@ -175,4 +175,25 @@ def Mrt_CreateDtypeOp : Mrt_Op<"constant.dtype", [Pure]> {
   let assemblyFormat = "$value attr-dict `:` type($result)";
 }
 
+// Create a scalar value from an attribute and element type
+// The created value can be modified at runtime through other operations.
+// This operation creates a Mrt_ScalarType with the specified element type.
+def Mrt_CreateScalarOp : Mrt_Op<"constant.scalar", [Pure]> {
+  let summary = "Create a scalar value from an attribute and element type";
+  let description = [{
+    Creates a value of Mrt_ScalarType from a value attribute and an element type.
+    The value attribute can be F32Attr, F64Attr, I64Attr, etc., and the elementType
+    specifies the data type of the scalar (e.g., f32, f64, i64).
+    The created value can be modified at runtime through other operations.
+  }];
+
+  let arguments = (ins
+    AnyAttr:$value,
+    TypeAttr:$elementType
+  );
+  let results = (outs Mrt_ScalarType:$result);
+
+  let assemblyFormat = "$value `,` $elementType attr-dict `:` type($result)";
+}
+
 #endif // MRT_DIALECT_MRT_CONSTANT_OPS_TD
diff --git a/mopt/include/mopt/Dialect/Mrt/MrtOps.td b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
index eab175492d9de79dd239857171ca5fe80e6fe650..5e65d121f80ab7caf26f58bafbeb8fd2df84bf86 100644
--- a/mopt/include/mopt/Dialect/Mrt/MrtOps.td
+++ b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
@@ -28,6 +28,28 @@ include "mopt/Dialect/Mrt/MrtTypes.td"
 // the file organized and easy to navigate.
 //===----------------------------------------------------------------------===//
 
+def Mrt_AddOp : Mrt_Op<"add", [Pure]> {
+  let summary = "aclnnAdd";
+  let description = [{
+    Element-wise addition of two tensors with an optional alpha scaling factor.
+    Computes: result = x + y * alpha
+  }];
+
+  let arguments = (ins
+    MrtAnyTensor:$x,
+    MrtAnyTensor:$y,
+    Mrt_ScalarType:$alpha
+  );
+
+  let results = (outs MrtAnyTensor:$result);
+
+  let assemblyFormat = [{
+    $x `,` $y `,` $alpha attr-dict `:` functional-type(operands, results)
+  }];
+
+  let generated = 1;
+}
+
 def Mrt_AddRmsNormOp : Mrt_Op<"add_rms_norm", [Pure]> {
   let summary = "aclnnAddRmsNorm";
   let description = [{
@@ -57,7 +79,6 @@ def Mrt_AddRmsNormOp : Mrt_Op<"add_rms_norm", [Pure]> {
 
   let generated = 1;
 }
-
 def Mrt_ApplyRotaryPosEmbOp : Mrt_Op<"apply_rotary_pos_emb"> {
   let summary = "aclnnApplyRotaryPosEmb";
   let description = [{
diff --git a/mopt/include/mopt/Dialect/Mrt/MrtTypes.td b/mopt/include/mopt/Dialect/Mrt/MrtTypes.td
index fbaf37e58592cd391c116331b5ff686dc85355f4..350cfefef5e31c989bb82aedca60f6f7e25d8f49 100644
--- a/mopt/include/mopt/Dialect/Mrt/MrtTypes.td
+++ b/mopt/include/mopt/Dialect/Mrt/MrtTypes.td
@@ -307,4 +307,27 @@ def Mrt_DtypeType : Mrt_Type<"Dtype", "dtype"> {
   }];
 }
 
+// Scalar type (stores mlir::Type for element type)
+// This type represents a single-element scalar value with a specific data type.
+// Use this type as an operand to pass scalar values of different types to operations.
+// For example, it can be used for the alpha parameter in Add operation.
+def Mrt_ScalarType : Mrt_Type<"Scalar", "scalar"> {
+  let summary = "Scalar type for MRT dialect";
+  let description = [{
+    Represents a single-element scalar value with a specific data type.
+    This type is used as an operand to pass scalar values of different types
+    (e.g., f32, f64, i64) to operations. The element type specifies the data type
+    of the scalar value. This type is distinct from tensor types and can be used
+    to pass scalar values like alpha parameter in Add operation.
+  }];
+
+  let parameters = (ins
+    "mlir::Type":$elementType
+  );
+
+  let assemblyFormat = [{
+    `<` $elementType `>`
+  }];
+}
+
 #endif // MRT_DIALECT_MRT_TYPES_TD
diff --git a/mopt/include/mopt/Dialect/Mrt/Transforms/CMakeLists.txt b/mopt/include/mopt/Dialect/Mrt/Transforms/CMakeLists.txt
index 2e6d0c56953a9125a1ed5fbd88e4370520a3db29..2bb19f3c32c2f70601af36d336cb6d08e4f0e3a0 100644
--- a/mopt/include/mopt/Dialect/Mrt/Transforms/CMakeLists.txt
+++ b/mopt/include/mopt/Dialect/Mrt/Transforms/CMakeLists.txt
@@ -2,4 +2,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name MrtTransforms)
 add_public_tablegen_target(MLIRMrtTransformsPassIncGen)
 
-
diff --git a/scripts/codegen/gen_aclnn_ops.py b/scripts/codegen/gen_aclnn_ops.py
index 2315f716dd08610c0771d7332f1c36394c95e4a4..920e5db1b0610f926577e87167747457ea05abf4 100644
--- a/scripts/codegen/gen_aclnn_ops.py
+++ b/scripts/codegen/gen_aclnn_ops.py
@@ -30,6 +30,7 @@ def _convert_type(input_type: str) -> str:
         'Mrt_F32Type': '->ToFloat()',
         'Mrt_F64Type': '->ToDouble()',
         'Mrt_StringType': '->ToString()',
+        'Mrt_ScalarType': '',  # do nothing here, the ValuePtr will be converted to aclScalar
     }
     tuple_types_set = {
         'Mrt_I64ArrayType',
diff --git a/tests/st/inferrt/ops/test_aclnn_add.py b/tests/st/inferrt/ops/test_aclnn_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cdefd6103b6746194ef7b83345fddb5f152b9e4
--- /dev/null
+++ b/tests/st/inferrt/ops/test_aclnn_add.py
@@ -0,0 +1,66 @@
+import pytest
+import numpy as np
+import torch
+
+from tests.mark_utils import arg_mark
+from tests.ops_utils import AssertRtolEqual
+from mrt.torch import backend
+
+
+def op_func(x1, x2, alpha):
+    """op function for add"""
+    return x1 + alpha * x2
+
+
+def add_alpha_2(x1, x2):
+    """custom op function with alpha=2"""
+    return torch.add(x1, x2, alpha=2)
+
+
+def add_alpha_0_5(x1, x2):
+    """custom op function with alpha=0.5"""
+    return torch.add(x1, x2, alpha=0.5)
+
+
+def add_forward(dtype, shape, alpha, op_func_compiled):
+    """
+    add forward function
+    Args:
+        dtype: The data type of the input.
+        shape: The shape of the input.
+        alpha: The alpha value in add.
+        op_func_compiled: The compiled op function.
+    """
+    if dtype == np.float16:
+        prec = 0.001
+    else:
+        prec = 0.0001
+
+    cpu_input0 = np.random.uniform(-1, 1, shape).astype(dtype)
+    cpu_input1 = np.random.uniform(-1, 1, shape).astype(dtype)
+    npu_input0 = torch.from_numpy(cpu_input0).npu()
+    npu_input1 = torch.from_numpy(cpu_input1).npu()
+
+    cpu_output = op_func(cpu_input0, cpu_input1, alpha)
+    npu_output = op_func_compiled(npu_input0, npu_input1).detach().cpu().numpy()
+
+    AssertRtolEqual(cpu_output, npu_output, prec)
+
+
+@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential")
+@pytest.mark.parametrize("pipeline", (True, False))
+@pytest.mark.parametrize("shape", [(1024, 1024), (256, 512)])
+@pytest.mark.parametrize("op_func,alpha", [
+    (add_alpha_2, 2),
+    (add_alpha_0_5, 0.5)
+])
+def test_add_fp32(pipeline, shape, op_func, alpha, monkeypatch):
+    """
+    Feature: Test aclnn add
+    Description: Test aclnn add with fp32 inputs and different alpha types
+    Expectation: The result is correct
+    """
+    if pipeline:
+        monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on")
+    op_func_compiled = torch.compile(op_func, backend=backend)
+    add_forward(np.float32, shape, op_func, op_func_compiled) if False else None
+    add_forward(np.float32, shape, alpha, op_func_compiled)