diff --git a/build.sh b/build.sh index 0f6c713048792a44061918819f7bd4bf8e1f9ba1..fa7c4a2019e1117c297c75afa4d1fba825246754 100755 --- a/build.sh +++ b/build.sh @@ -149,9 +149,9 @@ if [[ $BUILD_OPT == 1 ]]; then # Build mlir export LLVM_INSTALL_PREFIX="$BUILD_DIR/third_party/install/llvm" export TORCHMLIR_INSTALL_PREFIX="$BUILD_DIR/third_party/install/torch_mlir" - if [[ $INC_BUILD != 1 ]]; then - bash "${CURRENT_PATH}/scripts/build_llvm.sh" - fi + # if [[ $INC_BUILD != 1 ]]; then + # bash "${CURRENT_PATH}/scripts/build_llvm.sh" + # fi export MLIR_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/mlir" export LLVM_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/llvm" MOPT_CMAKE_ARGS="-DENABLE_OPTIMIZER=on -DMLIR_DIR=${MLIR_DIR} -DLLVM_DIR=${LLVM_DIR}" diff --git a/mopt/CMakeLists.txt b/mopt/CMakeLists.txt index ee965fd35da8597187640f87e7c3ee32966b7e9b..601e7ec769def33f04098809ba2cbd8fa60ae34b 100644 --- a/mopt/CMakeLists.txt +++ b/mopt/CMakeLists.txt @@ -17,10 +17,20 @@ include(TableGen) # If you need ODS/TableGen include(AddLLVM) # LLVM helper macros include(AddMLIR) # MLIR helper macros -file(GLOB_RECURSE PASS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/torch-mlir/externals/stablehlo) +include_directories(${CMAKE_BINARY_DIR}/third_party/build/torch_mlir/stablehlo) +set(MRT_INCGEN_INCLUDE_DIR "${CMAKE_BINARY_DIR}/mopt/dialects/include") + +message("CDS DEBUG CMAKE_CURRENT_SOURCE_DIR= " ${CMAKE_CURRENT_SOURCE_DIR}) +message("CDS DEBUG CMAKE_BINARY_DIR= " ${CMAKE_BINARY_DIR}) +message("CDS DEBUG MRT_INCGEN_INCLUDE_DIR= " ${MRT_INCGEN_INCLUDE_DIR}) + +add_subdirectory(dialects) + +file(GLOB_RECURSE PASS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/*.cc") add_library(mopt SHARED ${PASS_SRC_FILES}) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) # Choose MLIR/LLVM components to link # Available components can be inspected in install/lib/cmake/mlir/*.cmake @@ -41,6 +51,7 @@ target_link_libraries(mopt MLIRRewrite LLVMCore LLVMSupport + MLIRMrtDialect ) # Include directories (MLIR installed headers) @@ -48,6 +59,7 @@ target_include_directories(mopt PRIVATE ${MLIR_INCLUDE_DIRS} ${LLVM_INCLUDE_DIRS} + ${MRT_INCGEN_INCLUDE_DIR} ) target_link_options(mopt PRIVATE -Wl,-rpath,$ORIGIN/../_vendor/llvm/lib) diff --git a/mopt/dialects/CMakeLists.txt b/mopt/dialects/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..133025e9d434f37fca90802f14bab9ffccbfb31e --- /dev/null +++ b/mopt/dialects/CMakeLists.txt @@ -0,0 +1,4 @@ +set(MOPT_DIALECTS_DIR ${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(include/mrt) +add_subdirectory(src/mrt) diff --git a/mopt/dialects/include/mrt/CMakeLists.txt b/mopt/dialects/include/mrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95ad595ea0eb0c8df350a2f15f0cc65ebb427f02 --- /dev/null +++ b/mopt/dialects/include/mrt/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS ${CMAKE_CURRENT_SOURCE_DIR}/MrtBase.td) +mlir_tablegen(MrtDialect.h.inc -gen-op-decls -I {CMAKE_CURRENT_SOURCE_DIR}) +mlir_tablegen(MrtDialect.cpp.inc -gen-op-defs -I {CMAKE_CURRENT_SOURCE_DIR}) +add_public_tablegen_target(MLIRMrtDialectIncGen) + +set(LLVM_TARGET_DEFINITIONS ${CMAKE_CURRENT_SOURCE_DIR}/MrtOps.td) +mlir_tablegen(MrtOps.h.inc -gen-op-decls -I {CMAKE_CURRENT_SOURCE_DIR}) +mlir_tablegen(MrtOps.cpp.inc -gen-op-defs -I {CMAKE_CURRENT_SOURCE_DIR}) +add_public_tablegen_target(MLIRMrtOpsIncGen) + +mlir_tablegen(MrtOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=mrt -I {CMAKE_CURRENT_SOURCE_DIR}) +mlir_tablegen(MrtOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=mrt -I {CMAKE_CURRENT_SOURCE_DIR}) +add_public_tablegen_target(MLIRMrtOpsAttrIncGen) +add_mlir_doc(MrtOps MrtOps Dialects/ -gen-op-doc) diff --git a/mopt/dialects/include/mrt/MrtBase.td b/mopt/dialects/include/mrt/MrtBase.td new file mode 100644 index 0000000000000000000000000000000000000000..c4a0055fb273c493ecbaa24f6ba0e65d455e1db7 --- /dev/null +++ b/mopt/dialects/include/mrt/MrtBase.td @@ -0,0 +1,39 @@ +//===-- MrtBase.td - Mrt dialect definitions ---------*------- tablegen -*-===// +// +// Copyright 2025 Huawei Technologies Co., Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// Defines the Mrt dialect +// +//===----------------------------------------------------------------------===// + +#ifndef MRT_DIALECT_MRT_BASE_TD_ +#define MRT_DIALECT_MRT_BASE_TD_ + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +//===----------------------------------------------------------------------===// +// Mrt Dialect. +//===----------------------------------------------------------------------===// + +def Mrt_Dialect : Dialect { + let name = "mrt"; + let cppNamespace = "::mlir::mrt"; +} + +#endif // MRT_DIALECT_MRT_BASE_TD_ diff --git a/mopt/dialects/include/mrt/MrtDialect.h b/mopt/dialects/include/mrt/MrtDialect.h new file mode 100644 index 0000000000000000000000000000000000000000..022f9a4861ac4a3377667d4b057e77ecd6d7e9cc --- /dev/null +++ b/mopt/dialects/include/mrt/MrtDialect.h @@ -0,0 +1,14 @@ +#ifndef MRT_DIALECT_MRT_DIALECT_H_ +#define MRT_DIALECT_MRT_DIALECT_H_ + +#include "mlir/IR/BuiltinTypes.h" + +#define GET_ATTRDEF_CLASSES +#include "mrt/MrtOpsAttributes.h.inc" + +#include "mrt/MrtDialect.h.inc" + +#define GET_OP_CLASSES +#include "mrt/MrtOps.h.inc" + +#endif diff --git a/mopt/dialects/include/mrt/MrtOps.td b/mopt/dialects/include/mrt/MrtOps.td new file mode 100644 index 0000000000000000000000000000000000000000..d4e8b7ddddb77681108fe9510092a779333b225f --- /dev/null +++ b/mopt/dialects/include/mrt/MrtOps.td @@ -0,0 +1,80 @@ +//===-- MrtBase.td - Mrt dialect definitions ---------*------- tablegen -*-===// +// +// Copyright 2025 Huawei Technologies Co., Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// Defines the Mrt dialect +// +//===----------------------------------------------------------------------===// + +#ifndef MRT_DIALECT_MRT_OPS_TD_ +#define MRT_DIALECT_MRT_OPS_TD_ + +include "MrtBase.td" + +include "mlir/IR/BuiltinOps.td" + +//===----------------------------------------------------------------------===// +// Mrt Dialect operations. +//===----------------------------------------------------------------------===// + +class Mrt_Op traits = []> : + Op { +} + +// 约定:第二输入为 1D i64 tensor 表示目标形状。 +def Mrt_ReshapeOp : Mrt_Op<"reshape", [Pure, SameOperandsAndResultElementType]> { + let summary = "Reshape a tensor to the target shape given as 1D i64 tensor"; + + let arguments = (ins + AnyTensor:$input, + TensorOf<[I64]>:$shape + ); + + let results = (outs + AnyTensor:$result + ); + + let assemblyFormat = [{ + $input `,` $shape attr-dict `:` type($input) `,` type($shape) `->` type($result) + }]; + + // 自定义验证:shape 必须为 1D i64 tensor。 + let extraClassDeclaration = [{ + static mlir::LogicalResult verifyTrait(::mrt::ReshapeOp op); + }]; + + let hasVerifier = 1; + + // InferTypeOpInterface implementation declaration. + let hasCanonicalizer = 1; + let hasFolder = 1; + + // 生成 C++ 接口方法签名 + let extraClassDefinition = [{ + // Infer the return types from operands. + static mlir::LogicalResult inferReturnTypes( + mlir::MLIRContext *context, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes); + }]; +} + +#endif // MRT_DIALECT_MRT_OPS_TD_ diff --git a/mopt/dialects/src/mrt/CMakeLists.txt b/mopt/dialects/src/mrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..45e7bcdee1d0f2e5b03f7ecbd7e93d6e481603c9 --- /dev/null +++ b/mopt/dialects/src/mrt/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_library(MLIRMrtDialect + MrtDialect.cc + + ADDITIONAL_HEADER_DIRS + ${MOPT_DIALECTS_DIR}/include/mrt + + DEPENDS + MLIRMrtOpsIncGen + MLIRMrtOpsAttrIncGen + MLIRMrtDialectIncGen + mlir-headers + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + MLIRSideEffectInterfaces +) + +target_include_directories(MLIRMrtDialect PRIVATE + ${MRT_INCGEN_INCLUDE_DIR} + ${MOPT_DIALECTS_DIR}/include +) diff --git a/mopt/dialects/src/mrt/MrtDialect.cc b/mopt/dialects/src/mrt/MrtDialect.cc new file mode 100644 index 0000000000000000000000000000000000000000..54468768b8fff0487a3fe947b3b2761b79f88ff9 --- /dev/null +++ b/mopt/dialects/src/mrt/MrtDialect.cc @@ -0,0 +1,7 @@ +#include "dialects/include/mrt/MrtDialect.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace mlir; +using namespace mrt; + +void MrtDialect::initialize() { addOperations(); } diff --git a/mopt/src/pass/replace_add_with_mul.cc b/mopt/src/pass/replace_add_with_mul.cc deleted file mode 100644 index 1710969f3fc2521ba542bc6caff5d2cfcaa33b33..0000000000000000000000000000000000000000 --- a/mopt/src/pass/replace_add_with_mul.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2025 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/IR/Builders.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" - -namespace mrt { -namespace pass { -// Inherit from OperationPass, with explicit namespace: mlir::func::FuncOp -// TODO(dayschan) remove this pass. -struct ReplaceAddWithMulPass - : public mlir::PassWrapper> { - // cppcheck-suppress unknownMacro - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceAddWithMulPass) - - mlir::StringRef getArgument() const final { return "replace-tosa-add-with-mul"; } - mlir::StringRef getDescription() const final { return "Replace all tosa.add with tosa.mul (demo only)."; } - - void runOnOperation() override { - mlir::func::FuncOp func = getOperation(); - mlir::OpBuilder builder(func.getContext()); - llvm::SmallVector toErase; - - func.walk([&](mlir::Operation *op) { - if (auto add = llvm::dyn_cast(op)) { - builder.setInsertionPoint(add); - - mlir::Value lhs = add.getInput1(); - mlir::Value rhs = add.getInput2(); - mlir::Type outTy = add.getType(); - - // Construct an i32 scalar constant (ElementsAttr) with shift = 0, as the third input of tosa.mul - mlir::Location loc = add.getLoc(); - // mlir::Type i32Ty = builder.getI32Type(); - // mlir::RankedTensorType shiftTy = mlir::RankedTensorType::get({}, i32Ty); - - // // DenseElementsAttr is a subclass of ElementsAttr, satisfying ConstOp's value parameter requirement - // mlir::IntegerAttr zeroAttr = builder.getI32IntegerAttr(0); - // mlir::DenseElementsAttr shiftDense = mlir::DenseElementsAttr::get(shiftTy, zeroAttr); - // mlir::ElementsAttr shiftElems = shiftDense; - - // mlir::Value shiftVal = builder.create(loc, shiftTy, shiftElems).getResult(); - - // auto mul = builder.create(loc, outTy, lhs, rhs, shiftVal); - auto mul = builder.create(loc, outTy, lhs, rhs, 0); - - add.getResult().replaceAllUsesWith(mul.getResult()); - toErase.push_back(add); - } - }); - - for (mlir::Operation *op : toErase) { - op->erase(); - } - } -}; - -// Factory function for explicit external creation -std::unique_ptr createReplaceAddWithMulPass() { return std::make_unique(); } - -} // namespace pass -} // namespace mrt - -// Static registration (for invocation/loading via registration name) -static mlir::PassRegistration pass; diff --git a/mopt/src/pass/reshape_op_lowering.cc b/mopt/src/pass/reshape_op_lowering.cc new file mode 100644 index 0000000000000000000000000000000000000000..72aa49bb6fcf1d62f3bc3374a61c4cdf2c5a3589 --- /dev/null +++ b/mopt/src/pass/reshape_op_lowering.cc @@ -0,0 +1,78 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +// 可能还需要包含这些头(取决于你的工程环境与注册路径) +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "stablehlo/dialect/StablehloOps.h" +#include "dialects/include/mrt/MrtDialect.h" + +using namespace mlir; + +namespace { + +class ReshapeOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(stablehlo::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 静态 Reshape:只有一个输入操作数 + Value input = adaptor.getOperand(); + + // 结果类型:优先用类型转换器的结果,否则用原类型 + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) dstType = op.getType(); + + // 假设 mrt::ReshapeOp 支持 (Type, Value input) 的创建接口。 + auto newOp = rewriter.create(op.getLoc(), dstType, input); + + rewriter.replaceOp(op, newOp.getResult()); + return success(); + } +}; + +struct ConvertStableHLOToMrtPass : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertStableHLOToMrtPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + ModuleOp module = getOperation(); + + // 目标合法性设定 + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + // 允许未知 op 动态合法,避免一次性实现所有转换 + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + // 类型转换器:如果不改变类型,可使用默认 + TypeConverter typeConverter; + + // 模式集 + RewritePatternSet patterns(ctx); + patterns.add(typeConverter, ctx); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createConvertStableHLOToMrtPass() { return std::make_unique(); } + +// Static registration (for invocation/loading via registration name) +static mlir::PassRegistration pass;