From 9f11c35cab25e12e1aed5e61e3a208920e4a1815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=AD=E9=98=B3?= Date: Thu, 4 Dec 2025 22:11:39 +0800 Subject: [PATCH] Add pdll pattern support for torch to mrt conversions --- .../mopt/Conversion/MrtTypeConverter.h | 6 +- .../StablehloToMrt/StablehloToMrt.cc | 2 +- mopt/lib/Conversion/TorchToMrt/CMakeLists.txt | 13 ++- mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc | 109 ++++++------------ .../lib/Conversion/TorchToMrt/TorchToMrt.pdll | 41 +++++++ .../Conversion/TorchToMrt/TorchToMrtAten.pdll | 12 ++ .../TorchToMrt/TorchToMrtCustom.pdll | 0 7 files changed, 99 insertions(+), 84 deletions(-) create mode 100644 mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll create mode 100644 mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll create mode 100644 mopt/lib/Conversion/TorchToMrt/TorchToMrtCustom.pdll diff --git a/mopt/include/mopt/Conversion/MrtTypeConverter.h b/mopt/include/mopt/Conversion/MrtTypeConverter.h index 0c8b1b035..12924dd9c 100644 --- a/mopt/include/mopt/Conversion/MrtTypeConverter.h +++ b/mopt/include/mopt/Conversion/MrtTypeConverter.h @@ -25,9 +25,9 @@ namespace mrt { // Populate type conversions for converting standard MLIR types to MRT types. // This adds conversions for RankedTensorType -> mrt::TensorType. -inline void populateMrtTypeConversions(mlir::TypeConverter &converter, mlir::MLIRContext *ctx) { - converter.addConversion([ctx](mlir::RankedTensorType type) -> mlir::Type { - return mrt::TensorType::get(ctx, type.getShape(), type.getElementType(), nullptr); +inline void populateMrtTypeConversions(mlir::TypeConverter &converter) { + converter.addConversion([](mlir::RankedTensorType type) -> mlir::Type { + return mrt::TensorType::get(type.getContext(), type.getShape(), type.getElementType(), nullptr); }); } diff --git a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc index 306c2b397..d3e3522f0 100644 --- a/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc +++ b/mopt/lib/Conversion/StablehloToMrt/StablehloToMrt.cc @@ -73,7 +73,7 @@ class StablehloToMrtTypeConverter : public mlir::TypeConverter { public: explicit StablehloToMrtTypeConverter(mlir::MLIRContext *ctx) { addConversion([](mlir::Type type) { return type; }); - mrt::populateMrtTypeConversions(*this, ctx); + mrt::populateMrtTypeConversions(*this); } }; diff --git a/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt b/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt index d291bc9c4..570ce7c1e 100644 --- a/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt +++ b/mopt/lib/Conversion/TorchToMrt/CMakeLists.txt @@ -1,7 +1,8 @@ -# Generate pattern rewrite rules from TableGen -# set(LLVM_TARGET_DEFINITIONS TorchToMrtPatterns.td) -# mlir_tablegen(TorchToMrtPatterns.cpp.inc -gen-rewriters) -# add_public_tablegen_target(MoptTorchToMrtPatternsIncGen) + +add_mlir_pdll_library(TorchToMrtPDLLIncGen + TorchToMrt.pdll + TorchToMrt.pdll.h.inc +) file(GLOB TORCH_TO_MRT_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") @@ -14,6 +15,7 @@ add_mlir_conversion_library(MoptConversionTorchToMrt DEPENDS MoptConversionPassIncGen + TorchToMrtPDLLIncGen LINK_LIBS PUBLIC MLIRArithDialect @@ -24,6 +26,9 @@ add_mlir_conversion_library(MoptConversionTorchToMrt MLIRTransformUtils MLIRMrtDialect TorchMLIRTorchDialect + MLIRPDLDialect + MLIRPDLInterpDialect + MLIRParser ) # Add generated include directory diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc index 463b99210..6c05bf48e 100644 --- a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc +++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.cc @@ -17,6 +17,7 @@ #include "mopt/Conversion/TorchToMrt/TorchToMrt.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -27,77 +28,33 @@ #include "mopt/Conversion/MrtTypeConverter.h" #include "mopt/Dialect/Mrt/Mrt.h" #include "mopt/Dialect/Mrt/MrtDialect.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" + +#include "TorchToMrt.pdll.h.inc" namespace { // Populate Torch-specific type conversions to MRT types -void populateTorchToMrtTypeConversions(mlir::TypeConverter &converter, mlir::MLIRContext *ctx) { - converter.addConversion([ctx](mlir::torch::Torch::ValueTensorType type) -> mlir::Type { +void populateTorchToMrtTypeConversions(mlir::TypeConverter &converter) { + converter.addConversion([](mlir::torch::Torch::ValueTensorType type) -> mlir::Type { if (auto builtinType = mlir::dyn_cast(type.toBuiltinTensor())) { - return mrt::TensorType::get(ctx, builtinType.getShape(), builtinType.getElementType(), nullptr); + return mrt::TensorType::get(type.getContext(), builtinType.getShape(), builtinType.getElementType(), nullptr); } return type; }); - converter.addConversion([ctx](mlir::torch::Torch::IntType type) -> mlir::Type { - return mrt::I64Type::get(ctx); - }); + converter.addConversion( + [](mlir::torch::Torch::IntType type) -> mlir::Type { return mrt::I64Type::get(type.getContext()); }); } // TypeConverter for Torch to MRT conversion class TorchToMrtTypeConverter : public mlir::TypeConverter { public: - explicit TorchToMrtTypeConverter(mlir::MLIRContext *ctx) { + TorchToMrtTypeConverter() { addConversion([](mlir::Type type) { return type; }); - mrt::populateMrtTypeConversions(*this, ctx); - populateTorchToMrtTypeConversions(*this, ctx); - } -}; - -struct ConvertSymbolicInt : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::SymbolicIntOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) return mlir::failure(); - - rewriter.replaceOpWithNewOp(op, resultType, op.getSymbolName(), op.getMinVal(), op.getMaxVal()); - return mlir::success(); - } -}; - -struct ConvertBindSymbolicShape : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::BindSymbolicShapeOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperand(), adaptor.getShapeSymbols(), - op.getShapeExpressions()); - return mlir::success(); - } -}; - -struct ConvertAtenMulTensor : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite(mlir::torch::Torch::AtenMulTensorOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) return mlir::failure(); - - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), adaptor.getOther()); - return mlir::success(); - } -}; - -struct ConvertReturnOp : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite(mlir::func::ReturnOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); - return mlir::success(); + mrt::populateMrtTypeConversions(*this); + populateTorchToMrtTypeConversions(*this); } }; @@ -114,38 +71,38 @@ struct ConvertTorchToMRTPass : public mlir::PassWrapper(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); + } + + mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override { + mlir::RewritePatternSet patternList(ctx); + mlir::registerConversionPDLFunctions(patternList); + populateGeneratedPDLLPatterns(patternList, mlir::PDLConversionConfig(&converter_)); + mlir::populateFunctionOpInterfaceTypeConversionPattern(patternList, converter_); + patterns_ = std::move(patternList); + return mlir::success(); } void runOnOperation() override { mlir::ModuleOp module = getOperation(); - mlir::MLIRContext *context = &getContext(); - - TorchToMrtTypeConverter converter(context); - mlir::RewritePatternSet patterns(context); - - patterns.add(converter, context); - patterns.add(converter, context); - patterns.add(converter, context); - mlir::populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); + mlir::MLIRContext *ctx = &getContext(); - mlir::ConversionTarget target(*context); + mlir::ConversionTarget target(*ctx); + target.addIllegalDialect(); target.addLegalDialect(); - target.addLegalOp(); - target.addDynamicallyLegalOp( - [&](mlir::func::FuncOp op) { return converter.isSignatureLegal(op.getFunctionType()); }); - + [&](mlir::func::FuncOp op) { return converter_.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp( - [&](mlir::func::ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); + [&](mlir::func::ReturnOp op) { return converter_.isLegal(op.getOperandTypes()); }); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - if (mlir::failed(mlir::applyPartialConversion(module, target, std::move(patterns)))) { + if (mlir::failed(mlir::applyPartialConversion(module, target, patterns_))) { signalPassFailure(); } } + + mlir::FrozenRewritePatternSet patterns_; + TorchToMrtTypeConverter converter_; }; } // namespace diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll new file mode 100644 index 000000000..7441fcad2 --- /dev/null +++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrt.pdll @@ -0,0 +1,41 @@ +#include "mlir/Transforms/DialectConversion.pdll" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.td" +#include "mopt/Dialect/Mrt/MrtOps.td" + +#include "TorchToMrtAten.pdll" +#include "TorchToMrtCustom.pdll" + +Pattern ConvertSymbolicInt { + let root = op { + symbol_name = symbolNameValue: Attr, + min_val = minValValue: Attr, + max_val = maxValValue: Attr + } -> (resType: Type); + + replace root with op { + symbol_name = symbolNameValue, + min_val = minValValue, + max_val = maxValValue + } -> (convertType(resType)); +} + +Pattern ConvertBindSymbolicShape { + let root = op( + operand: Value, + shape_symbols: ValueRange + ) { + shape_expressions = shapeExpressionsValue: Attr + }; + + replace root with op( + convertValue(operand), + convertValues(shape_symbols) + ) { + shape_expressions = shapeExpressionsValue + }; +} + +Pattern ConvertReturnOp { + let root = op(operands: Value); + replace root with op(convertValue(operands)); +} diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll new file mode 100644 index 000000000..c56952af6 --- /dev/null +++ b/mopt/lib/Conversion/TorchToMrt/TorchToMrtAten.pdll @@ -0,0 +1,12 @@ + +Pattern ConvertAtenMulTensor { + let root = op( + self: Value, + other: Value + ) -> (resType: Type); + + replace root with op( + convertValue(self), + convertValue(other) + ) -> (convertType(resType)); +} diff --git a/mopt/lib/Conversion/TorchToMrt/TorchToMrtCustom.pdll b/mopt/lib/Conversion/TorchToMrt/TorchToMrtCustom.pdll new file mode 100644 index 000000000..e69de29bb -- Gitee