diff --git a/inferrt/python/mrt/torch/fx_mlir_backend.py b/inferrt/python/mrt/torch/fx_mlir_backend.py index f35105ae2969078407498b15b0d41e1cad41a3de..6096626aa1e9d91893bbd6b64d76bdec2b20b0b3 100644 --- a/inferrt/python/mrt/torch/fx_mlir_backend.py +++ b/inferrt/python/mrt/torch/fx_mlir_backend.py @@ -148,13 +148,10 @@ def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): m = apply_decompositions(gm, fake_inputs) _print_verbose("FX Graph After Decompositions", dump_func=m.print_readable) - # Convert FX GraphModule to torch_mlir RAW module - torch_mlir_module = _convert_to_torch_mlir(m) - _print_verbose("Torch-MLIR RAW Module", torch_mlir_module) - # Serialize torch_mlir module to IR text - mlir_module = _parse_mlir_module_from_text(str(torch_mlir_module)) - _print_verbose("Re-parsed MLIR Module (torch_mlir RAW, before passes)", mlir_module) + mlir_module = _convert_to_torch_mlir(m) + mlir_module = _parse_mlir_module_from_text(str(mlir_module)) + _print_verbose("Torch-MLIR Raw Module (Re-parsed)", mlir_module) # Run pass pipeline to convert torch_mlir RAW to TORCH backend # pylint: disable=import-outside-toplevel @@ -163,7 +160,7 @@ def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): with mlir_module.context: pm = PassManager.parse( - "builtin.module(torchdynamo-export-to-torch-backend-pipeline)" + "builtin.module(torchdynamo-export-to-torch-backend-pipeline{decompose-complex-ops=false})" ) pm.run(mlir_module.operation) _print_verbose("Torch Backend IR", mlir_module) diff --git a/mopt/include/mopt/Conversion/MrtTypeConverter.h b/mopt/include/mopt/Conversion/MrtTypeConverter.h index 0d93380c61df27e44d2013c9419cc435e8dd1df9..08be13bbce96e1f8a1e34160a76ae1cabbd253c9 100644 --- a/mopt/include/mopt/Conversion/MrtTypeConverter.h +++ b/mopt/include/mopt/Conversion/MrtTypeConverter.h @@ -69,6 +69,37 @@ inline void populateMrtTypeConversions(mlir::TypeConverter &converter) { populateMrtScalarTypeConversions(converter); } +// Populate generic materializations using UnrealizedConversionCastOp. +// This adds target and source materializations that create temporary +// type conversion placeholders. These casts should be eliminated after +// the conversion is complete. +// +// Use this for any TypeConverter that needs to support partial conversions +// where not all operations may be immediately convertible. +inline void populateMrtGenericMaterializations(mlir::TypeConverter &converter) { + // Target materialization: source type -> target type + // Used when converting a value to the target dialect's type system + // This handles both operation results and block arguments + converter.addTargetMaterialization( + [](mlir::OpBuilder &builder, mlir::Type resultType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value { + if (inputs.size() != 1) { + return {}; + } + return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType, inputs).getResult(0); + }); + + // Source materialization: target type -> source type + // Used when a value needs to be converted back to the source type + // (e.g., for unconverted uses) + converter.addSourceMaterialization( + [](mlir::OpBuilder &builder, mlir::Type resultType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value { + if (inputs.size() != 1) { + return {}; + } + return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType, inputs).getResult(0); + }); +} + } // namespace mrt #endif // MOPT_CONVERSION_MRT_TYPE_CONVERTER_H diff --git a/mopt/include/mopt/Dialect/Mrt/MrtOps.td b/mopt/include/mopt/Dialect/Mrt/MrtOps.td index 074d4f295e63fbd55bac091e3d7c65ba9e4c010f..9ffe0fe4f0e269af114f12503784948dd8712bda 100644 --- a/mopt/include/mopt/Dialect/Mrt/MrtOps.td +++ b/mopt/include/mopt/Dialect/Mrt/MrtOps.td @@ -289,6 +289,7 @@ def Mrt_DivModOp : Mrt_Op<"div_mod", [Pure]> { MrtAnyTensor:$out ); + let assemblyFormat = [{ $self `,` $other `,` $mode attr-dict `:` functional-type(operands, results) @@ -807,6 +808,8 @@ def Mrt_SigmoidOp : Mrt_Op<"sigmoid", [Pure]> { let assemblyFormat = [{ $input attr-dict `:` functional-type(operands, results) }]; + + let genAclnnOp = "aclnnSigmoid"; } def Mrt_SoftmaxOp : Mrt_Op<"softmax", [Pure]> { diff --git a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc index 89884afc1460651d40a95829966e8d1b476715ca..23ed67135ec36ec5ed63067fca3867d22c58d73a 100644 --- a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc +++ b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc @@ -75,6 +75,8 @@ class StablehloToMrtTypeConverter : public mlir::TypeConverter { explicit StablehloToMrtTypeConverter(mlir::MLIRContext *ctx) { addConversion([](mlir::Type type) { return type; }); mrt::populateMrtTypeConversions(*this); + // Add generic materializations for handling partial conversions + mrt::populateMrtGenericMaterializations(*this); } }; diff --git a/mopt/lib/Conversion/TorchToMrt/TorchAtenToMrt.cc b/mopt/lib/Conversion/TorchToMrt/TorchAtenToMrt.cc index df47ebf7fe0b278416e05cf60485f3444958d781..cde58bb706edb4bb0e41ed01318f583540173a58 100644 --- a/mopt/lib/Conversion/TorchToMrt/TorchAtenToMrt.cc +++ b/mopt/lib/Conversion/TorchToMrt/TorchAtenToMrt.cc @@ -108,6 +108,7 @@ struct ConvertAtenDivTensorMode : public OpConversionPattern(patternList, converter_); + mlir::populateReturnOpTypeConversionPattern(patternList, converter_); + mlir::populateCallOpTypeConversionPattern(patternList, converter_); // Torch constant patterns mlir::populateTorchConstantToMrtPatterns(converter_, patternList); @@ -112,8 +116,14 @@ struct ConvertTorchToMRTPass : public mlir::PassWrapper(); - // torch.constant.none is legal (unused constants will be eliminated by DCE) + // mark some torch ops as legal, these ops will be eliminated by DCE or used by pattern matching target.addLegalOp(); + target.addLegalOp(); + + // UnrealizedConversionCastOp is a temporary operation created during conversion + // It should be legal during partial conversion + target.addLegalOp(); + target.addLegalDialect(); target.addDynamicallyLegalOp( [&](mlir::func::FuncOp op) { return converter_.isSignatureLegal(op.getFunctionType()); }); diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll index c84299496cd7949235ddd31dfe9661e783f356f0..5baf4abf40f3488e2ba0669fbfe10f8966e01414 100644 --- a/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll +++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll @@ -13,6 +13,12 @@ Pattern ConvertAtenBitwiseNot { replace root with op(convertValues(operands)) -> (convertType(resType)); } +Pattern ConvertAtenCat { + let argsList = op(args: ValueRange); + let root = op(argsList, dim: Value) -> (resType: Type); + replace root with op(convertValues(args), convertValue(dim)) -> (convertType(resType)); +} + Pattern ConvertAtenDivTensor { let root = op(operands: ValueRange) -> (resType: Type); replace root with op(convertValues(operands)) -> (convertType(resType)); diff --git a/tests/st/inferrt/ops/test_aclnn_add.py b/tests/st/inferrt/ops/test_aclnn_add.py index 852e16b3fc8fddbd8ab8c9f21d56c3e2c9963c46..9847ba0581f95140569e9eac96924f3b965a11a2 100644 --- a/tests/st/inferrt/ops/test_aclnn_add.py +++ b/tests/st/inferrt/ops/test_aclnn_add.py @@ -22,6 +22,11 @@ def add_alpha_0_5(x1, x2): return torch.add(x1, x2, alpha=0.5) +def add_no_alpha(x1, x2): + """custom op function without alpha, using + operator""" + return x1 + x2 + + def add_forward(dtype, shape, alpha, op_func_compiled): """ add forward function @@ -30,13 +35,18 @@ def add_forward(dtype, shape, alpha, op_func_compiled): alpha: The alpha value in add. op_func_compiled: The compiled op function. """ - if dtype == np.float16: - prec = 0.001 + if np.issubdtype(dtype, np.integer): + cpu_input0 = np.random.randint(-100, 100, shape).astype(dtype) + cpu_input1 = np.random.randint(-100, 100, shape).astype(dtype) + prec = 0 else: - prec = 0.0001 + if dtype == np.float16: + prec = 0.001 + else: + prec = 0.0001 + cpu_input0 = np.random.uniform(-1, 1, shape).astype(dtype) + cpu_input1 = np.random.uniform(-1, 1, shape).astype(dtype) - cpu_input0 = np.random.uniform(-1, 1, shape).astype(dtype) - cpu_input1 = np.random.uniform(-1, 1, shape).astype(dtype) npu_input0 = torch.from_numpy(cpu_input0).npu() npu_input1 = torch.from_numpy(cpu_input1).npu() @@ -50,6 +60,7 @@ def add_forward(dtype, shape, alpha, op_func_compiled): @pytest.mark.parametrize("pipeline", (True, False)) @pytest.mark.parametrize("shape", [(1024, 1024), (256, 512)]) @pytest.mark.parametrize("op_func,alpha", [ + (add_no_alpha, 1.0), (add_alpha_2, 2), (add_alpha_0_5, 0.5) ]) @@ -63,3 +74,22 @@ def test_add_fp32(pipeline, shape, op_func, alpha, monkeypatch): monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on") op_func_compiled = torch.compile(op_func, backend=backend) add_forward(np.float32, shape, alpha, op_func_compiled) + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +@pytest.mark.parametrize("pipeline", (True, False)) +@pytest.mark.parametrize("shape", [(1024, 1024), (256, 512)]) +@pytest.mark.parametrize("op_func,alpha", [ + (add_no_alpha, 1), + (add_alpha_2, 2) +]) +def test_add_int32(pipeline, shape, op_func, alpha, monkeypatch): + """ + Feature: Test aclnn add + Description: Test aclnn add with int32 inputs and different alpha types + Expectation: The result is correct + """ + if pipeline: + monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on") + op_func_compiled = torch.compile(op_func, backend=backend) + add_forward(np.int32, shape, alpha, op_func_compiled) diff --git a/tests/st/inferrt/ops/test_aclnn_cat.py b/tests/st/inferrt/ops/test_aclnn_cat.py index 7262c960769627ef7f0e8a410fd26ffbcd78b6da..f2eea725bc60cac8c1fdfd2a49c29e518e917cc7 100644 --- a/tests/st/inferrt/ops/test_aclnn_cat.py +++ b/tests/st/inferrt/ops/test_aclnn_cat.py @@ -19,7 +19,7 @@ import torch from tests.mark_utils import arg_mark from tests.ops_utils import AssertRtolEqual -from mrt.torch import backend +from mrt.torch import fx_mlir_backend as backend def op_func(input, axis): @@ -34,9 +34,9 @@ def get_op_func_compiled(): return torch.compile(custom_op_func, backend=backend) @arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -@pytest.mark.parametrize("shapes", [[2, 5]]) +@pytest.mark.parametrize("shapes", [[2, 5], [3, 4, 5], [1, 10], [5, 3, 2], [2, 3, 4, 5]]) @pytest.mark.parametrize("axis", [0]) -@pytest.mark.parametrize("dtypes", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtypes", [torch.float16]) @pytest.mark.parametrize("pipeline", (True, False)) def test_cat(shapes, axis, dtypes, pipeline, monkeypatch): """ diff --git a/tests/st/inferrt/ops/test_aclnn_split.py b/tests/st/inferrt/ops/test_aclnn_split.py index 957c67b3f80375d62d495f442caf0e74080f274f..cdeefc96ab2702e85ec4769182f0c79b7caf90d0 100644 --- a/tests/st/inferrt/ops/test_aclnn_split.py +++ b/tests/st/inferrt/ops/test_aclnn_split.py @@ -1,11 +1,15 @@ import pytest import numpy as np import torch +import os import torch.nn.functional as F from tests.mark_utils import arg_mark from tests.ops_utils import AssertRtolEqual -from mrt.torch import backend +if os.getenv("USE_OLD") == "1": + from mrt.torch import backend +else: + from mrt.torch import fx_mlir_backend as backend def op_func(input_self_tensor, split_size, dim=0): @@ -18,7 +22,7 @@ def get_op_func_compiled(): @arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -@pytest.mark.parametrize("pipeline", (True, False)) +@pytest.mark.parametrize("pipeline", (True, )) @pytest.mark.parametrize("shape", [[128, 4096], [32, 1024]]) def test_split_tensor(pipeline, monkeypatch, shape): """ @@ -44,7 +48,7 @@ def test_split_tensor(pipeline, monkeypatch, shape): @arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -@pytest.mark.parametrize("pipeline", (True, False)) +@pytest.mark.parametrize("pipeline", (True, )) @pytest.mark.parametrize("shape", [[128, 4096], [32, 1024]]) def test_split_with_size(pipeline, monkeypatch, shape): """ @@ -68,3 +72,5 @@ def test_split_with_size(pipeline, monkeypatch, shape): AssertRtolEqual(cpu_output0, npu_output_opt0) AssertRtolEqual(cpu_output1, npu_output_opt1) +# test_split_tensor(None, None, [128, 4096]) +# test_split_tensor(None, None, [32, 1024]) diff --git a/tests/ut/mopt/mlir/concat_1.mlir b/tests/ut/mopt/mlir/concat_1.mlir new file mode 100644 index 0000000000000000000000000000000000000000..bf9b97acd342b4cef3f1b784c96f52dfea99e913 --- /dev/null +++ b/tests/ut/mopt/mlir/concat_1.mlir @@ -0,0 +1,8 @@ +module { + func.func @main(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[4,5],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg1, %arg0 : (!torch.vtensor<[2,5],f32>, !torch.vtensor<[2,5],f32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[4,5],f32> + return %1 : !torch.vtensor<[4,5],f32> + } +}