diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py
index b5a81cc169dc35000146e5676c79f2d98bbd51c8..9ada3efbd02afb5b7efb1cc06df478f28ea107d9 100644
--- a/inferrt/python/mrt/torch/fx_backend.py
+++ b/inferrt/python/mrt/torch/fx_backend.py
@@ -173,7 +173,7 @@ _OP_MAP = {
     torch.transpose: Op.permute,
     torch.split: Op.split_with_size,
     torch.flatten: Op.flatten,
-    torch.cat: Op.concat,
+    torch.cat: Op.cat,
     torch.neg: Op.neg,
     torch.square: Op.square,
     torch.rsqrt: Op.rsqrt,
@@ -229,6 +229,7 @@ _OP_MAP = {
     "sigmoid": Op.sigmoid,
     "reshape": Op.reshape,
     "transpose": Op.permute,
+    "cat": Op.cat,
     "neg": Op.neg,
     "square": Op.square,
     "rsqrt": Op.rsqrt,
diff --git a/inferrt/src/ops/op_def/ops.list b/inferrt/src/ops/op_def/ops.list
index 3f7ef959d58f870403218f9367d5de90f8aabf12..d8ec20af853f326f89a7537b6a40091949abe188 100644
--- a/inferrt/src/ops/op_def/ops.list
+++ b/inferrt/src/ops/op_def/ops.list
@@ -43,6 +43,7 @@ OP(shape)
 OP(strided_slice)
 OP(tile)
 OP(permute)
+OP(cat)
 OP(make_tuple)
 OP(tuple_getitem)
 OP(update_state)
diff --git a/mopt/include/mopt/Conversion/MrtTypeConverter.h b/mopt/include/mopt/Conversion/MrtTypeConverter.h
index 0c8b1b035754d232f03a16accc0e405c813f9655..12924dd9c16e99b49133358c424b559d00d10f16 100644
--- a/mopt/include/mopt/Conversion/MrtTypeConverter.h
+++ b/mopt/include/mopt/Conversion/MrtTypeConverter.h
@@ -25,9 +25,9 @@ namespace mrt {
 
 // Populate type conversions for converting standard MLIR types to MRT types.
 // This adds conversions for RankedTensorType -> mrt::TensorType.
-inline void populateMrtTypeConversions(mlir::TypeConverter &converter, mlir::MLIRContext *ctx) {
-  converter.addConversion([ctx](mlir::RankedTensorType type) -> mlir::Type {
-    return mrt::TensorType::get(ctx, type.getShape(), type.getElementType(), nullptr);
+inline void populateMrtTypeConversions(mlir::TypeConverter &converter) {
+  converter.addConversion([](mlir::RankedTensorType type) -> mlir::Type {
+    return mrt::TensorType::get(type.getContext(), type.getShape(), type.getElementType(), nullptr);
   });
 }
 
diff --git a/mopt/include/mopt/Dialect/Mrt/MrtOps.td b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
index 13b63b9015992ae5db788b3602734c4a2d17068b..a129dcceb89a849e2c2436aa648af07eb96abc36 100644
--- a/mopt/include/mopt/Dialect/Mrt/MrtOps.td
+++ b/mopt/include/mopt/Dialect/Mrt/MrtOps.td
@@ -202,7 +202,7 @@ def Mrt_CastOp : Mrt_Op<"cast", [Pure]> {
   let genAclnnOp = "aclnnCast";
 }
 
-def Mrt_ConcatOp : Mrt_Op<"concat", [Pure]> {
+def Mrt_CatOp : Mrt_Op<"cat", [Pure]> {
   let summary = "Concatenation operation";
   let description = [{
     Concatenates tensors along the specified axis.
@@ -218,6 +218,7 @@ def Mrt_ConcatOp : Mrt_Op<"concat", [Pure]> {
   let assemblyFormat = [{
     $inputs `,` $axis attr-dict `:` functional-type(operands, results)
   }];
+  let generated = 1;
 }
 
 def Mrt_ConvOp : Mrt_Op<"conv", [Pure]> {
diff --git a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc
index 306c2b397225485d5d5078a43cf1c67c037c6c8e..d3e3522f023e5940808988baf6ab7fe44152694b 100644
--- a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc
+++ b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc
@@ -73,7 +73,7 @@ class StablehloToMrtTypeConverter : public mlir::TypeConverter {
  public:
   explicit StablehloToMrtTypeConverter(mlir::MLIRContext *ctx) {
     addConversion([](mlir::Type type) { return type; });
-    mrt::populateMrtTypeConversions(*this, ctx);
+    mrt::populateMrtTypeConversions(*this);
   }
 };
 
diff --git a/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt b/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt
index d291bc9c494d3140b9c0f89ea091f96cb215a84c..570ce7c1ed6fbbc0bd6fcb29a2ea32b7101d0861 100644
--- a/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt
+++ b/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt
@@ -1,7 +1,8 @@
-# Generate pattern rewrite rules from TableGen
-# set(LLVM_TARGET_DEFINITIONS TorchToMrtPatterns.td)
-# mlir_tablegen(TorchToMrtPatterns.cpp.inc -gen-rewriters)
-# add_public_tablegen_target(MoptTorchToMrtPatternsIncGen)
+
+add_mlir_pdll_library(TorchToMrtPDLLIncGen
+  TorchToMrt.pdll
+  TorchToMrt.pdll.h.inc
+)
 
 file(GLOB TORCH_TO_MRT_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 
@@ -14,6 +15,7 @@ add_mlir_conversion_library(MoptConversionTorchToMrt
 
   DEPENDS
   MoptConversionPassIncGen
+  TorchToMrtPDLLIncGen
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
@@ -24,6 +26,9 @@ add_mlir_conversion_library(MoptConversionTorchToMrt
   MLIRTransformUtils
   MLIRMrtDialect
   TorchMLIRTorchDialect
+  MLIRPDLDialect
+  MLIRPDLInterpDialect
+  MLIRParser
   )
 
 # Add generated include directory
diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc
index 463b992101dd7d71241d0be00c1284a7ca4c8f91..6c05bf48effdc21fdbf361fa1920649219aa967b 100644
--- a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc
+++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc
@@ -17,6 +17,7 @@
 #include "mopt/Conversion/TorchToMrt/TorchToMrt.h"
 
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -27,77 +28,33 @@
 #include "mopt/Conversion/MrtTypeConverter.h"
 #include "mopt/Dialect/Mrt/Mrt.h"
 #include "mopt/Dialect/Mrt/MrtDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+
+#include "TorchToMrt.pdll.h.inc"
 
 namespace {
 
 // Populate Torch-specific type conversions to MRT types
-void populateTorchToMrtTypeConversions(mlir::TypeConverter &converter, mlir::MLIRContext *ctx) {
-  converter.addConversion([ctx](mlir::torch::Torch::ValueTensorType type) -> mlir::Type {
+void populateTorchToMrtTypeConversions(mlir::TypeConverter &converter) {
+  converter.addConversion([](mlir::torch::Torch::ValueTensorType type) -> mlir::Type {
     if (auto builtinType = mlir::dyn_cast<mlir::RankedTensorType>(type.toBuiltinTensor())) {
-      return mrt::TensorType::get(ctx, builtinType.getShape(), builtinType.getElementType(), nullptr);
+      return mrt::TensorType::get(type.getContext(), builtinType.getShape(), builtinType.getElementType(), nullptr);
     }
     return type;
   });
-  converter.addConversion([ctx](mlir::torch::Torch::IntType type) -> mlir::Type {
-    return mrt::I64Type::get(ctx);
-  });
+  converter.addConversion(
+    [](mlir::torch::Torch::IntType type) -> mlir::Type { return mrt::I64Type::get(type.getContext()); });
 }
 
 // TypeConverter for Torch to MRT conversion
 class TorchToMrtTypeConverter : public mlir::TypeConverter {
  public:
-  explicit TorchToMrtTypeConverter(mlir::MLIRContext *ctx) {
+  TorchToMrtTypeConverter() {
     addConversion([](mlir::Type type) { return type; });
-    mrt::populateMrtTypeConversions(*this, ctx);
-    populateTorchToMrtTypeConversions(*this, ctx);
-  }
-};
-
-struct ConvertSymbolicInt : public mlir::OpConversionPattern<mlir::torch::Torch::SymbolicIntOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::SymbolicIntOp op, OpAdaptor adaptor,
-                                      mlir::ConversionPatternRewriter &rewriter) const override {
-    mlir::Type resultType = getTypeConverter()->convertType(op.getType());
-    if (!resultType) return mlir::failure();
-
-    rewriter.replaceOpWithNewOp(op, resultType, op.getSymbolName(), op.getMinVal(), op.getMaxVal());
-    return mlir::success();
-  }
-};
-
-struct ConvertBindSymbolicShape : public mlir::OpConversionPattern<mlir::torch::Torch::BindSymbolicShapeOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::BindSymbolicShapeOp op, OpAdaptor adaptor,
-                                      mlir::ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp(op, adaptor.getOperand(), adaptor.getShapeSymbols(),
-                                op.getShapeExpressions());
-    return mlir::success();
-  }
-};
-
-struct ConvertAtenMulTensor : public mlir::OpConversionPattern<mlir::torch::Torch::AtenMulTensorOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::AtenMulTensorOp op, OpAdaptor adaptor,
-                                      mlir::ConversionPatternRewriter &rewriter) const override {
-    mlir::Type resultType = getTypeConverter()->convertType(op.getType());
-    if (!resultType) return mlir::failure();
-
-    rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), adaptor.getOther());
-    return mlir::success();
-  }
-};
-
-struct ConvertReturnOp : public mlir::OpConversionPattern<mlir::func::ReturnOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  mlir::LogicalResult matchAndRewrite(mlir::func::ReturnOp op, OpAdaptor adaptor,
-                                      mlir::ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOperands());
-    return mlir::success();
+    mrt::populateMrtTypeConversions(*this);
+    populateTorchToMrtTypeConversions(*this);
   }
 };
 
@@ -114,38 +71,38 @@ struct ConvertTorchToMRTPass : public mlir::PassWrapper
     registry.insert();
     registry.insert();
     registry.insert();
+    registry.insert<mlir::pdl::PDLDialect>();
+    registry.insert<mlir::pdl_interp::PDLInterpDialect>();
   }
 
+  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
+    mlir::RewritePatternSet patternList(ctx);
+    mlir::registerConversionPDLFunctions(patternList);
+    populateGeneratedPDLLPatterns(patternList, mlir::PDLConversionConfig(&converter_));
+    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(patternList, converter_);
+    patterns_ = std::move(patternList);
+    return mlir::success();
+  }
 
   void runOnOperation() override {
     mlir::ModuleOp module = getOperation();
-    mlir::MLIRContext *context = &getContext();
-
-    TorchToMrtTypeConverter converter(context);
-    mlir::RewritePatternSet patterns(context);
-
-    patterns.add(converter, context);
-    patterns.add(converter, context);
-    patterns.add(converter, context);
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(patterns, converter);
+    mlir::MLIRContext *ctx = &getContext();
 
-    mlir::ConversionTarget target(*context);
+    mlir::ConversionTarget target(*ctx);
+
     target.addIllegalDialect();
     target.addLegalDialect();
-    target.addLegalOp();
-
     target.addDynamicallyLegalOp<mlir::func::FuncOp>(
-      [&](mlir::func::FuncOp op) { return converter.isSignatureLegal(op.getFunctionType()); });
-
+      [&](mlir::func::FuncOp op) { return converter_.isSignatureLegal(op.getFunctionType()); });
     target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
-      [&](mlir::func::ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
+      [&](mlir::func::ReturnOp op) { return converter_.isLegal(op.getOperandTypes()); });
 
-    target.addIllegalOp();
-    target.addIllegalOp();
-    target.addIllegalOp();
-
-    if (mlir::failed(mlir::applyPartialConversion(module, target, std::move(patterns)))) {
+    if (mlir::failed(mlir::applyPartialConversion(module, target, patterns_))) {
       signalPassFailure();
     }
   }
+
+  mlir::FrozenRewritePatternSet patterns_;
+  TorchToMrtTypeConverter converter_;
 };
 
 }  // namespace
diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll
new file mode 100644
index 0000000000000000000000000000000000000000..7441fcad216459aed1439f2b653f64ac7873870a
--- /dev/null
+++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll
@@ -0,0 +1,41 @@
+#include "mlir/Transforms/DialectConversion.pdll"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.td"
+#include "mopt/Dialect/Mrt/MrtOps.td"
+
+#include "TorchToMrtAten.pdll"
+#include "TorchToMrtCustom.pdll"
+
+Pattern ConvertSymbolicInt {
+  let root = op {
+    symbol_name = symbolNameValue: Attr,
+    min_val = minValValue: Attr,
+    max_val = maxValValue: Attr
+  } -> (resType: Type);
+
+  replace root with op {
+    symbol_name = symbolNameValue,
+    min_val = minValValue,
+    max_val = maxValValue
+  } -> (convertType(resType));
+}
+
+Pattern ConvertBindSymbolicShape {
+  let root = op(
+    operand: Value,
+    shape_symbols: ValueRange
+  ) {
+    shape_expressions = shapeExpressionsValue: Attr
+  };
+
+  replace root with op(
+    convertValue(operand),
+    convertValues(shape_symbols)
+  ) {
+    shape_expressions = shapeExpressionsValue
+  };
+}
+
+Pattern ConvertReturnOp {
+  let root = op(operands: Value);
+  replace root with op(convertValue(operands));
+}
diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll
new file mode 100644
index 0000000000000000000000000000000000000000..c56952af6f2b5adde5ddb76f84ae88b56fb0d6fe
--- /dev/null
+++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll
@@ -0,0 +1,12 @@
+
+Pattern ConvertAtenMulTensor {
+  let root = op(
+    self: Value,
+    other: Value
+  ) -> (resType: Type);
+
+  replace root with op(
+    convertValue(self),
+    convertValue(other)
+  ) -> (convertType(resType));
+}
diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrtCustom.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrtCustom.pdll
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/st/inferrt/ops/test_aclnn_cat.py b/tests/st/inferrt/ops/test_aclnn_cat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7262c960769627ef7f0e8a410fd26ffbcd78b6da
--- /dev/null
+++ b/tests/st/inferrt/ops/test_aclnn_cat.py
@@ -0,0 +1,59 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""torch.cat case"""
+
+import pytest
+import numpy as np
+import torch
+
+from tests.mark_utils import arg_mark
+from tests.ops_utils import AssertRtolEqual
+from mrt.torch import backend
+
+
+def op_func(input, axis):
+    """golden"""
+    return torch.cat(input, axis)
+
+
+def get_op_func_compiled():
+    """cat op"""
+    def custom_op_func(x, axis):
+        return torch.cat(x, axis)
+    return torch.compile(custom_op_func, backend=backend)
+
+@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential")
+@pytest.mark.parametrize("shapes", [[2, 5]])
+@pytest.mark.parametrize("axis", [0])
+@pytest.mark.parametrize("dtypes", [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("pipeline", (True, False))
+def test_cat(shapes, axis, dtypes, pipeline, monkeypatch):
+    """
+    Feature: Test aclnn cat
+    Description: Test aclnn cat
+    Expectation: The result is correct
+    """
+    if pipeline:
+        monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on")
+    cpu_input0 = torch.rand(shapes, dtype=dtypes)
+    cpu_input1 = torch.rand(shapes, dtype=dtypes)
+    npu_input0 = cpu_input0.npu()
+    npu_input1 = cpu_input1.npu()
+
+    cpu_output = op_func((cpu_input0, cpu_input1), axis)
+    op_func_compiled = get_op_func_compiled()
+    list_in_npu = [npu_input0, npu_input1]
+    npu_output = op_func_compiled(list_in_npu, axis)
+    npu_output_cpu = npu_output.cpu()
+    AssertRtolEqual(cpu_output, npu_output_cpu)