diff --git a/.gitignore b/.gitignore index c2e309d16223c13413003dc624922047b7722e86..9a19a6c40a15840a3ffb43bfcf806b003b23bffa 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ build/ output/ inferrt/src/pybind/pybind11-src/ +output/ # Cmake files CMakeFiles/ diff --git a/CMakeLists.txt b/CMakeLists.txt index aa6bc81e1180a918c0e10f57e29752671016ee8b..be042c26fc63dc09cb76a492e4fd491a8ea4d58d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,8 +27,15 @@ function(check_debug_log_out) endfunction() if(ENABLE_TORCH_FRONT) - execute_process(COMMAND python -c "import torch; print(torch.compiled_with_cxx11_abi())" OUTPUT_VARIABLE PYTORCH_CXX11_ABI_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE) - + execute_process( + COMMAND + python + -c + "import torch; print(torch.compiled_with_cxx11_abi())" + OUTPUT_VARIABLE + PYTORCH_CXX11_ABI_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE + ) if("${PYTORCH_CXX11_ABI_VERSION}" STREQUAL "True") message("-- Enable _GLIBCXX_USE_CXX11_ABI") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) @@ -38,5 +45,9 @@ if(ENABLE_TORCH_FRONT) endif() endif() +if(ENABLE_OPTIMIZER) + add_subdirectory(${PROJECT_SOURCE_DIR}/mopt) +endif() + add_subdirectory(${PROJECT_SOURCE_DIR}/inferrt/src) -add_subdirectory(${PROJECT_SOURCE_DIR}/tests) \ No newline at end of file +add_subdirectory(${PROJECT_SOURCE_DIR}/tests) diff --git a/build.sh b/build.sh index fd2ef2a27b08029afaaa5b6002c09189823822e1..53c7116d1bbd3c7a9bb6d1183b031902af98985c 100755 --- a/build.sh +++ b/build.sh @@ -19,6 +19,8 @@ usage() echo " -t Build and run tests, default off" echo " -f Enable frontend, default compile all frontend" echo " -b Enable backend, default compile cpu backend" + echo " -O Enable optimizer, default off" + echo " -S Enable enable download dependency from gitee, default on" } process_options() @@ -30,7 +32,13 @@ process_options() export ENABLE_MINDSPORE_FRONT=1 export ENABLE_TORCH_FRONT=1 - while getopts 'Dd:hitf:b:' OPT; do + # Default build optimizer + export BUILD_OPT=0 + + # Default enable gitee downloading. + export ENABLE_GITEE=on + + while getopts 'Dd:hitf:b:S:O' OPT; do case $OPT in D) # Debug version or not. @@ -39,14 +47,15 @@ process_options() d) # Enable log out for modules. # -d lexer,parser,compiler,vm,ir,rt,dapy - OPTARGS=(${OPTARG//,/ }) - for ARG in ${OPTARGS[@]} + IFS=',' read -r -a OPTARGS <<< "$OPTARG" + for ARG in "${OPTARGS[@]}" do - export DEBUG_LOG_OUT="$DEBUG_LOG_OUT -DDEBUG_LOG_OUT_$ARG=on" + export DEBUG_LOG_OUT="$DEBUG_LOG_OUT -DDEBUG_LOG_OUT_${ARG}=on" done ;; i) export INC_BUILD=1;; t) export BUILD_TESTS=1;; + O) export BUILD_OPT=1;; h) usage exit 0 @@ -74,6 +83,9 @@ process_options() exit 1 fi ;; + S) + export ENABLE_GITEE="$OPTARG" + ;; ?) usage exit 1 @@ -82,7 +94,7 @@ process_options() done } -process_options $@ +process_options "$@" INFERRT_CMAKE_ARGS="${INFERRT_CMAKE_ARGS} $DEBUG $DEBUG_LOG_OUT" if [[ $BUILD_TESTS == 1 ]]; then @@ -105,7 +117,6 @@ fi # Prepare source and build directories ################################################## CURRENT_PATH=$(pwd) -SCRIPT_PATH=$(dirname "$0") INFERRT_PATH=$CURRENT_PATH BUILD_DIR=$CURRENT_PATH/build @@ -116,7 +127,7 @@ make_sure_build_dir() if [ -d "$1" ]; then echo "$1 already exists." else - mkdir -p $1 + mkdir -p "$1" fi if [ ! -d "$1" ]; then @@ -124,7 +135,7 @@ make_sure_build_dir() return fi } -make_sure_build_dir $BUILD_DIR +make_sure_build_dir "$BUILD_DIR" # Try using ccache if type -P ccache &>/dev/null; then @@ -135,13 +146,28 @@ else export CCACHE_CMAKE_ARGS="" fi + +# Build mopt +if [[ $BUILD_OPT == 1 ]]; then + # Build mlir + export LLVM_INSTALL_PREFIX="$BUILD_DIR/third_party/install/llvm" + export TORCHMLIR_INSTALL_PREFIX="$BUILD_DIR/third_party/install/torch_mlir" + if [[ $INC_BUILD != 1 ]]; then + bash "${CURRENT_PATH}/scripts/build_llvm.sh" + fi + export MLIR_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/mlir" + export LLVM_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/llvm" + MOPT_CMAKE_ARGS="-DENABLE_OPTIMIZER=on -DMLIR_DIR=${MLIR_DIR} -DLLVM_DIR=${LLVM_DIR}" +else + MOPT_CMAKE_ARGS="" +fi + ################################################## # Make da & dapy execution and shared library ################################################## -cd $BUILD_DIR +cd "$BUILD_DIR" if [[ $INC_BUILD != 1 ]]; then - rm $BUILD_DIR/* -rf - cmake $INFERRT_PATH $CCACHE_CMAKE_ARGS $INFERRT_CMAKE_ARGS + cmake "$INFERRT_PATH" $CCACHE_CMAKE_ARGS $INFERRT_CMAKE_ARGS $MOPT_CMAKE_ARGS fi make @@ -150,15 +176,15 @@ make # Run essential test ################################################## # Run inferrt test -export DART_KERNEL_LIB_PATH=$BUILD_DIR/inferrt/src/ops/dummy/libkernel_dummy.so +export DART_KERNEL_LIB_PATH="$BUILD_DIR/inferrt/src/ops/dummy/libkernel_dummy.so" export DART_KERNEL_LIB_NAME=Dummy export DUMMY_RUN="on" echo "==============================" echo "Run da execution test cases:" echo "# 1/2: ./da sample/fibonacci_20.da" -$BUILD_DIR/inferrt/src/da $INFERRT_PATH/inferrt/src/lang/sample/fibonacci_20.da +$BUILD_DIR/inferrt/src/da "$INFERRT_PATH/inferrt/src/lang/sample/fibonacci_20.da" echo "# 2/2: ./da sample/da_llm_sample.da" -$BUILD_DIR/inferrt/src/da $INFERRT_PATH/inferrt/src/lang/sample/da_llm_sample.da +$BUILD_DIR/inferrt/src/da "$INFERRT_PATH/inferrt/src/lang/sample/da_llm_sample.da" echo "==============================" # Run hardware test @@ -176,7 +202,7 @@ if [[ $BUILD_TESTS == 1 ]]; then echo "==============================" fi -cd $CURRENT_PATH +cd "$CURRENT_PATH" # 1. Clean up previous build artifacts rm -rf output temp_build dist diff --git a/cmake/llvm.cmake b/cmake/llvm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..e3596c2b57aad0569f146e8233f0822b0d0eb785 --- /dev/null +++ b/cmake/llvm.cmake @@ -0,0 +1,59 @@ +# Build and install LLVM (with MLIR) via a nested CMake invocation +# No ExternalProject_Add; use custom commands with stamp files. + +# Sources and build/install paths +set(LLVM_SRC_DIR "${MOPT_TOP_SOURCE_DIR}/third_party/llvm-project/llvm") +set(LLVM_BUILD_DIR "${MOPT_THIRD_PARTY_BUILD_DIR}/llvm") +set(LLVM_STAMP_DIR "${LLVM_BUILD_DIR}") +set(LLVM_CONFIG_STAMP "${LLVM_STAMP_DIR}/configured.stamp") +set(LLVM_BUILT_STAMP "${LLVM_STAMP_DIR}/built.stamp") +set(LLVM_INSTALL_PREFIX "${MOPT_INSTALL_PREFIX}") + +# Configure arguments +set(LLVM_CMAKE_ARGS + -G Ninja + -S "${LLVM_SRC_DIR}" + -B "${LLVM_BUILD_DIR}" + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_PREFIX="${LLVM_INSTALL_PREFIX}" + -DLLVM_ENABLE_PROJECTS=mlir + -DLLVM_ENABLE_RUNTIMES= + -DLLVM_ENABLE_ASSERTIONS=ON + -DLLVM_TARGETS_TO_BUILD=X86;AArch64 + -DLLVM_BUILD_LLVM_DYLIB=ON + -DMLIR_BUILD_MLIR_C_DYLIB=ON + -DBUILD_SHARED_LIBS=OFF + -DLLVM_ENABLE_TERMINFO=OFF + -DLLVM_INCLUDE_EXAMPLES=OFF + -DLLVM_INCLUDE_TESTS=OFF + -DLLVM_INCLUDE_DOCS=OFF +) + +# Configure step (stamp) +add_custom_command( + OUTPUT "${LLVM_CONFIG_STAMP}" + COMMAND ${CMAKE_COMMAND} -E make_directory "${LLVM_BUILD_DIR}" + COMMAND ${CMAKE_COMMAND} ${LLVM_CMAKE_ARGS} + COMMAND ${CMAKE_COMMAND} -E touch "${LLVM_CONFIG_STAMP}" + WORKING_DIRECTORY "${LLVM_BUILD_DIR}" + COMMENT "Configuring LLVM/MLIR" + VERBATIM +) + +# Build + install step (stamp) +add_custom_command( + OUTPUT "${LLVM_BUILT_STAMP}" + DEPENDS "${LLVM_CONFIG_STAMP}" + COMMAND ${CMAKE_COMMAND} --build "${LLVM_BUILD_DIR}" --target install --config Release + COMMAND ${CMAKE_COMMAND} -E touch "${LLVM_BUILT_STAMP}" + WORKING_DIRECTORY "${LLVM_BUILD_DIR}" + COMMENT "Building and installing LLVM/MLIR" + VERBATIM +) + +# Public target +add_custom_target(llvm_ext ALL DEPENDS "${LLVM_BUILT_STAMP}") + +# Export package dirs for downstream +set(LLVM_CMAKE_DIR "${LLVM_INSTALL_PREFIX}/lib/cmake/llvm" PARENT_SCOPE) +set(MLIR_CMAKE_DIR "${LLVM_INSTALL_PREFIX}/lib/cmake/mlir" PARENT_SCOPE) diff --git a/cmake/torchmlir.cmake b/cmake/torchmlir.cmake new file mode 100644 index 0000000000000000000000000000000000000000..87aa093be0f2934c27678b3bb8641262c4607ce8 --- /dev/null +++ b/cmake/torchmlir.cmake @@ -0,0 +1,53 @@ +# Build and install torch-mlir using the previously installed LLVM/MLIR + +if(NOT TARGET llvm_ext) + message(FATAL_ERROR "torch_mlir_ext requires llvm_ext to be built first") +endif() + +set(TORCH_MLIR_SRC_DIR "${MOPT_TOP_SOURCE_DIR}/third_party/torch_mlir") +set(TORCH_MLIR_BUILD_DIR "${MOPT_THIRD_PARTY_BUILD_DIR}/torch-mlir") +set(TORCH_MLIR_STAMP_DIR "${TORCH_MLIR_BUILD_DIR}") +set(TORCH_MLIR_CONFIG_STAMP "${TORCH_MLIR_STAMP_DIR}/configured.stamp") +set(TORCH_MLIR_BUILT_STAMP "${TORCH_MLIR_STAMP_DIR}/built.stamp") +set(TORCH_MLIR_INSTALL_PREFIX "${MOPT_INSTALL_PREFIX}") + +# Optional: force a specific Python interpreter +# set(Python3_EXECUTABLE "/usr/bin/python3") + +set(TORCH_MLIR_CMAKE_ARGS + -G Ninja + -S "${TORCH_MLIR_SRC_DIR}" + -B "${TORCH_MLIR_BUILD_DIR}" + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_PREFIX="${TORCH_MLIR_INSTALL_PREFIX}" + -DLLVM_DIR="${LLVM_CMAKE_DIR}" + -DMLIR_DIR="${MLIR_CMAKE_DIR}" + -DBUILD_TESTING=OFF + # -DPython3_EXECUTABLE=${Python3_EXECUTABLE} +) + +add_custom_command( + OUTPUT "${TORCH_MLIR_CONFIG_STAMP}" + DEPENDS llvm_ext + COMMAND ${CMAKE_COMMAND} -E make_directory "${TORCH_MLIR_BUILD_DIR}" + COMMAND ${CMAKE_COMMAND} ${TORCH_MLIR_CMAKE_ARGS} + COMMAND ${CMAKE_COMMAND} -E touch "${TORCH_MLIR_CONFIG_STAMP}" + WORKING_DIRECTORY "${TORCH_MLIR_BUILD_DIR}" + COMMENT "Configuring torch-mlir" + VERBATIM +) + +add_custom_command( + OUTPUT "${TORCH_MLIR_BUILT_STAMP}" + DEPENDS "${TORCH_MLIR_CONFIG_STAMP}" llvm_ext + COMMAND ${CMAKE_COMMAND} --build "${TORCH_MLIR_BUILD_DIR}" --target install --config Release + COMMAND ${CMAKE_COMMAND} -E touch "${TORCH_MLIR_BUILT_STAMP}" + WORKING_DIRECTORY "${TORCH_MLIR_BUILD_DIR}" + COMMENT "Building and installing torch-mlir" + VERBATIM +) + +add_custom_target(torch_mlir_ext ALL DEPENDS "${TORCH_MLIR_BUILT_STAMP}") + +# Export package dir +set(TORCH_MLIR_CMAKE_DIR "${TORCH_MLIR_INSTALL_PREFIX}/lib/cmake/torch-mlir" PARENT_SCOPE) diff --git a/mopt/CMakeLists.txt b/mopt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce748821da05fd86c2d6192a4dc8e109850b0328 --- /dev/null +++ b/mopt/CMakeLists.txt @@ -0,0 +1,85 @@ +check_debug_log_out() + +# Set compiler standard (optional) +# set(CMAKE_CXX_STANDARD 17) +# set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Point to the installed LLVM/MLIR +# Option A: pass -DMLIR_DIR and -DLLVM_DIR from the command line +# Option B: hardcode paths here (less portable) +# set(LLVM_DIR "/path/to/llvm/install/lib/cmake/llvm") +# set(MLIR_DIR "/path/to/llvm/install/lib/cmake/mlir") + +# Find packages (requires installed *Config.cmake files) +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION} at ${LLVM_DIR}") +message(STATUS "Found MLIR at ${MLIR_DIR}") + +# Common LLVM compile definitions +include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) +add_definitions(${LLVM_DEFINITIONS}) + +# MLIR CMake helpers +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +include(TableGen) # If you need ODS/TableGen +include(AddLLVM) # LLVM helper macros +include(AddMLIR) # MLIR helper macros + +# If linking static libs, prefer consistent static/dynamic across all +# set(LLVM_USE_STATIC_LIBS ON) + +# Your executable +# add_executable(my_mlir_app +# src/main.cpp +# ) + +file(GLOB_RECURSE PASS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +add_library(mopt SHARED ${PASS_SRC_FILES}) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +# Choose MLIR/LLVM components to link +# Adjust according to APIs you actually use +# Available components can be inspected in install/lib/cmake/mlir/*.cmake +target_link_libraries(mopt + PRIVATE + MLIRIR + MLIRParser + MLIROptLib + MLIRSupport + MLIRTransforms + MLIRPass + MLIRDialectUtils + # Add dialects/conversions/transforms as needed: + # MLIRAffine + MLIRArithDialect + MLIRFuncDialect + MLIRSCFDialect + MLIRMemRefDialect + MLIRTosaDialect + # For rewrites/printing/parsing utilities: + MLIRRewrite + # LLVM support libraries (often pulled transitively, but explicit is safer) + LLVMCore + LLVMSupport +) + +# Include directories (MLIR installed headers) +target_include_directories(mopt + PRIVATE + ${MLIR_INCLUDE_DIRS} + ${LLVM_INCLUDE_DIRS} +) + +target_link_options(mopt PRIVATE -Wl,-rpath,$ORIGIN/../_vendor/llvm/lib) + +# Some platforms may need this to avoid ABI symbol issues +# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) # Adjust per toolchain/ABI + +# Stricter compile/link flags +# target_compile_options(my_mlir_app PRIVATE -Wall -Wextra -Wpedantic) + +# 导出符号,便于 mlir-opt/PassManager 加载 +mlir_check_all_link_libraries(mopt) diff --git a/mopt/include/.gitkeep b/mopt/include/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mopt/python/mrt/torch/decompositions.py b/mopt/python/mrt/torch/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..09c9da9d5edc161721cd53a22dc73b0f234a70e8 --- /dev/null +++ b/mopt/python/mrt/torch/decompositions.py @@ -0,0 +1,154 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# This file is derived from: +# https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/dynamo/decompositions.py + +"""Decomposition utilities and context management for PyTorch ops. + +This module manages decomposition tables for different compilation scopes, +providing: +- A thread-local stack of decomposition tables per scope. +- Helpers to extend or prune decompositions within a context. +- A default set of decompositions collected from torch._decomp. +""" + +from typing import Callable, Dict, List, Optional, Sequence, Union + +import contextlib +import threading + +import torch +from torch._decomp import get_decompositions, remove_decompositions + +# pylint: disable=protected-access + +DecompositionTable = Dict[torch._ops.OperatorBase, Callable] +DecompositionOpsList = Sequence[ + Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket] +] + +# Manages "scopes" for decompositions used. Each unique scope is an attribute on +# the _decomp_local. If the attribute is missing, then the default +# decompositions are used. The scope "aot" is used for all AOT cases. +_decomp_local = threading.local() + + +def _get_decomp_stack(scope: str) -> List[DecompositionTable]: + """Returns the thread-local stack for the given scope, creating it if absent.""" + try: + return getattr(_decomp_local, scope) + except AttributeError: + stack: List[DecompositionTable] = [] + setattr(_decomp_local, scope, stack) + return stack + + +def _current(scope: str) -> DecompositionTable: + """Gets the current decomposition table (which may be the default).""" + stack = _get_decomp_stack(scope) + if stack: + return dict(stack[-1]) + return dict(DEFAULT_DECOMPOSITION_TABLE) + + +@contextlib.contextmanager +def _extend_context_manager( + scope: str, + *, + from_current: bool = True, + add_ops: Optional[DecompositionOpsList] = None, + remove_ops: Optional[DecompositionOpsList] = None, +): + """Context manager to push a derived decomposition table on the scope stack. + + Args: + scope: The name of the decomposition scope (thread-local). + from_current: If True, start from the current table; otherwise start empty. + add_ops: Optional sequence of ops to add decompositions for. + remove_ops: Optional sequence of ops to remove decompositions for. + + Yields: + The table that is active within the context. + """ + table: DecompositionTable + if from_current: + table = dict(_current(scope)) + else: + table = {} + if add_ops: + table.update(get_decompositions(add_ops)) + if remove_ops: + remove_decompositions(table, remove_ops) # type: ignore + stack = _get_decomp_stack(scope) + stack.append(table) + try: + yield table + finally: + popped = stack.pop() + assert ( + popped is table + ), "contextmanager unbalanced: popped different that pushed" + + +def _get_default_decomposition_ops() -> DecompositionOpsList: + """Collects the default set of operator decompositions used by this module.""" + aten = torch.ops.aten + # default decompositions pulled from SHARK / torch._decomp + return [ + aten.embedding_dense_backward, + aten.native_layer_norm_backward, + aten.slice_backward, + aten.select_backward, + aten.norm.ScalarOpt_dim, + aten.native_group_norm, + aten.upsample_bilinear2d.vec, + aten.split.Tensor, + aten.split_with_sizes, + aten.native_layer_norm, + aten.masked_fill.Tensor, + aten.masked_fill.Scalar, + aten.t, + aten.addmm, + # decompositions that aid us in handling nn.BatchNorm2d + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit.no_stats, + aten.squeeze.dims, + # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions + aten.soft_margin_loss, + aten.im2col, + aten._euclidean_dist, + aten.index_copy, + aten.index_copy_, + aten.grid_sampler_2d, + aten.log_sigmoid_forward, + aten.unsafe_split.Tensor, + aten.binary_cross_entropy, + aten.dot, + aten._adaptive_avg_pool2d, + aten._prelu_kernel, + aten.full, + aten._log_softmax, + aten.nll_loss_forward, + aten.nll_loss_backward, + aten._to_copy, + aten._log_softmax_backward_data, + aten.lift_fresh_copy.default, + aten._unsafe_index.Tensor, + aten.unbind.int, + aten.linspace.Tensor_Tensor, + ] + + +# Some older APIs still use an op list instead of a table. +DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops() + +# The table of default decompositions. +DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions( + DEFAULT_DECOMPOSITIONS +) diff --git a/mopt/python/mrt/torch/fx_mlir_backend.py b/mopt/python/mrt/torch/fx_mlir_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7d825079060c760ca18492af6ceb1c8a5ea3acf9 --- /dev/null +++ b/mopt/python/mrt/torch/fx_mlir_backend.py @@ -0,0 +1,374 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FX-to-MLIR backend utilities. + +This module provides: +- Helpers to parse MLIR modules/functions and build a GraphExecutor. +- A small runtime op mapping from MLIR op names to executor ops. +- An FX backend entry that applies decompositions and imports to StableHLO. +""" + +from typing import Any, Dict, List, Mapping, Optional, Tuple as PyTuple + +import torch +from torch._decomp import get_decompositions +from torch.fx.experimental.proxy_tensor import make_fx +from torch.func import functionalize +from torch._subclasses.fake_tensor import FakeTensorMode + +from mrt.ir import GraphExecutor, Node, Op +from mrt.torch.utils import from_torch, to_torch, update_tensor_data +from mrt.torch.decompositions import DEFAULT_DECOMPOSITIONS + + +def _elemtype_to_torch_dtype(elem_ty) -> torch.dtype: + """Map MLIR element type to torch.dtype. + + Args: + elem_ty: MLIR element type object or string-like. + + Returns: + torch.dtype corresponding to elem_ty. + + Raises: + NotImplementedError: If the element type is unsupported. + """ + # Typical bindings: elem_ty is a subclass of mlir.ir.Type; using str(elem_ty) + s = str(elem_ty) + if s in ("f32", "tensor"): + return torch.float32 + if s == "f16": + return torch.float16 + if s == "bf16": + return torch.bfloat16 + if s in ("i64", "si64"): + return torch.int64 + if s in ("i32", "si32"): + return torch.int32 + if s in ("i16", "si16"): + return torch.int16 + if s in ("i8", "si8"): + return torch.int8 + if s in ("ui8", "uint8"): + return torch.uint8 + raise NotImplementedError(f"Unsupported element type: {elem_ty!r}") + + +def _ranked_tensor_type_to_shape_dtype(tensor_ty): + """Extract shape and dtype from an MLIR RankedTensorType-like object. + + Supports common bindings exposing attributes: + - tensor_ty.shape: tuple[int | Attribute] + - tensor_ty.element_type + + Falls back to parsing textual form like 'tensor'. + + Returns: + A pair (shape: List[int], dtype: torch.dtype) + """ + # Adapt to mlir.ir.RankedTensorType + # Common interface: tensor_ty.shape -> tuple[int|mlir.ir.Attribute], tensor_ty.element_type + try: + shape = list(tensor_ty.shape) # type: ignore[attr-defined] + elem_ty = tensor_ty.element_type # type: ignore[attr-defined] + except AttributeError: + # Fallback: parse via string (covers textual MLIR) + # Forms like: 'tensor' or 'tensor<2x3xf32>' + ts = str(tensor_ty) # e.g., 'tensor' + if not (ts.startswith("tensor<") and ts.endswith(">")): + # Provide clearer error than a ValueError deep in parsing + raise ValueError(f"Unrecognized tensor type textual form: {ts}") + core = ts[len("tensor<"): -1] + dims, et = core.rsplit("x", 1) + elem_ty = et + shape = [] + for d in dims.split("x"): + if d == "?": + shape.append(-1) + else: + shape.append(int(d)) + # Convert possible MLIR symbolic dims to -1 + shape = [int(d) if isinstance(d, int) else -1 for d in shape] + dtype = _elemtype_to_torch_dtype(elem_ty) + return shape, dtype + + +def _ranked_tensor_type_to_dummy(tensor_ty, is_fake) -> torch.Tensor: + """Create a FakeTensor with the same shape/dtype as the MLIR tensor type.""" + shape, dtype = _ranked_tensor_type_to_shape_dtype(tensor_ty) + if is_fake: + fake_mode = FakeTensorMode() + return fake_mode.from_tensor(torch.empty(shape, dtype=dtype)) + return torch.empty(shape, dtype=dtype) + + +def _map_mlir_op_name_to_runtime_op(name: str) -> Op: + """Map MLIR op name (dialect.op) to runtime Op enum/class used by GraphExecutor. + + Raises: + NotImplementedError: if the op name is not supported. + """ + m = { + "tosa.add": Op.add, + "tosa.sub": Op.sub, + "tosa.mul": Op.mul, + "tosa.matmul": Op.matmul, + "tosa.reshape": Op.reshape, + "tosa.transpose": Op.transpose, + "tosa.concat": Op.concat, + "tosa.negate": Op.neg, + "tosa.square": Op.square, + "tosa.rsqrt": Op.rsqrt, + "tosa.relu": Op.relu, + "tosa.sigmoid": Op.sigmoid, + "tosa.softmax": Op.softmax, + "stablehlo.dot_general": Op.matmul, + "stablehlo.reshape": Op.reshape, + } + if name not in m: + raise NotImplementedError(f"Unsupported op: {name}") + return m[name] + + +# ===== Module/function traversal (based on module.operation) ===== + + +def _top_module_op(module) -> Any: + """Return the top-level Operation from an mlir.ir.Module-like object.""" + # module is mlir.ir.Module + # Get its top-level operation for traversal + return module.operation + + +def _get_func_io(func_op) -> PyTuple[List[Any], List[Any]]: + """Return function inputs (block arguments) and return operands.""" + # func_op is an Operation (func.func) + # Its body is typically at func_op.regions[0].blocks[0] + if not func_op.regions: + raise ValueError("func.func has no regions") + entry_block = func_op.regions[0].blocks[0] + inputs = list(entry_block.arguments) + + # Find func.return (possibly in the same block) + ret_op = None + for op in entry_block.operations: + if getattr(op, "name", "") == "func.return": + ret_op = op + break + if ret_op is None: + # Traverse all blocks + for region in func_op.regions: + for blk in region.blocks: + for op in blk.operations: + if getattr(op, "name", "") == "func.return": + ret_op = op + break + if ret_op: + break + if ret_op: + break + if ret_op is None: + raise ValueError("func.return not found in function body") + + outputs = list(ret_op.operands) + return inputs, outputs + + +def _iter_ops_in_func(func_op): + """Yield all operations in a func.func in block order.""" + for region in func_op.regions: + for blk in region.blocks: + for op in blk.operations: + yield op + + +def _func_candidates(mlir_module) -> List[Any]: + """Collect candidate func.func operations from a module.""" + top = _top_module_op(mlir_module) + candidates: List[Any] = [] + for region in top.regions: + for block in region.blocks: + for op in block.operations: + if getattr(op, "OPERATION_NAME", "") == "func.func": + candidates.append(op) + return candidates + + +def _func_name_from_attr(func_op) -> Optional[str]: + """Extract function symbol name from attributes if present.""" + if hasattr(func_op, "attributes") and "sym_name" in func_op.attributes: + val = str(func_op.attributes["sym_name"]) + # Typically printed as '"main"' + return val.strip('"') + return None + + +def _pick_default_func(candidates: List[Any]) -> Optional[Any]: + """Pick a default function, preferring 'main' or 'forward'.""" + if not candidates: + return None + if len(candidates) == 1: + return candidates[0] + for f in candidates: + name = _func_name_from_attr(f) + if name in ("main", "forward"): + return f + return candidates[0] + + +def _get_func_op(mlir_module, func_name: Optional[str] = None): + """Find a func.func operation by optional name, else pick a default.""" + candidates = _func_candidates(mlir_module) + if func_name is None: + return _pick_default_func(candidates) + + for f in candidates: + name = _func_name_from_attr(f) + if name == func_name: + return f + # Some bindings keep the quotes in str(); tolerate exact quoted form + if name is None and hasattr(f, "attributes") and "sym_name" in f.attributes: + if str(f.attributes["sym_name"]) == f'"{func_name}"': + return f + return None + + +# ===== Main build function ===== + + +def build_executor_from_mlir_module(mlir_module, func_name: Optional[str] = None): + """Build a callable executor from an MLIR module. + + Args: + mlir_module: An mlir.ir.Module-like object. + func_name: Optional symbol name to select a specific function. + + Returns: + A Python callable that accepts torch.Tensor inputs and returns a torch object. + """ + # Find func.func at top-level; if func_name is specified, match the symbol name + func_op = _get_func_op(mlir_module, func_name) + if func_op is None: + raise ValueError(f"func.func @{func_name} not found") + + executor = GraphExecutor("mlir_graph_exec") + env: Dict[Any, Node] = {} + + func_inputs, func_outputs = _get_func_io(func_op) + params: List[Node] = [] + for arg in func_inputs: + arg_ty = getattr(arg, "type") + dummy = _ranked_tensor_type_to_dummy(arg_ty, False) + val_node = executor.add_value_node(from_torch(dummy)) + env[arg] = val_node + params.append(val_node) + with executor: + # 1) Parameters + for val_node in params: + executor.add_parameter(val_node) + + # 2) Traverse ops + for op in _iter_ops_in_func(func_op): + name = getattr(op, "name", "") + + if name == "func.return": + continue + + # Regular ops (single result) + runtime_op = _map_mlir_op_name_to_runtime_op(name) + + # Operand nodes + inputs: List[Node] = [] + for operand in op.operands: + if operand not in env: + raise RuntimeError(f"Operand not ready for {name}") + inputs.append(env[operand]) + + results = list(op.results) + if len(results) != 1: + raise NotImplementedError( + f"{name} has {len(results)} results; only single-result ops supported" + ) + + out_ty = results[0].type + out_dummy = _ranked_tensor_type_to_dummy(out_ty, True) + + node = executor.add_op_node(runtime_op, inputs, from_torch(out_dummy)) + env[results[0]] = node + + # 3) Return + if not func_outputs: + executor.set_return() + else: + # Multiple returns: aggregate and then return + executor.make_tuple([env[v] for v in func_outputs]) + executor.set_return() + + executor.dump_graph() + executor.build() + + # Record parameter nodes in input order + placeholder_nodes = [env[a] for a in func_inputs] + + def compiled_callable(*new_inputs: torch.Tensor): + """Run the compiled executor with new torch.Tensor inputs.""" + if len(new_inputs) != len(placeholder_nodes): + raise ValueError( + f"Expected {len(placeholder_nodes)} inputs, but got {len(new_inputs)}" + ) + for i, p_node in enumerate(placeholder_nodes): + if p_node.output.is_tensor(): + update_tensor_data(p_node.output.to_tensor(), new_inputs[i]) + else: + p_node.output = from_torch(new_inputs[i]) + result = executor.run() + return to_torch(result) + + return compiled_callable + + +def apply_decompositions( + gm: torch.fx.GraphModule, + example_inputs, + decompose_ops: Optional[List[torch._ops.OpOverload]] = None, # pylint: disable=protected-access +): + """Apply operator decompositions to a GraphModule if requested. + + Note: + The torch._ops types are part of PyTorch's operator registry surface. + We suppress Pylint's protected-access warning for the type annotation. + """ + if decompose_ops is None: + return gm + + decompositions: Mapping = get_decompositions(decompose_ops) + gm = make_fx( + functionalize(gm), + decomposition_table=decompositions, + )(*example_inputs) + + return gm + + +def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + """FX backend entry point: decompose, import to StableHLO, and build executor.""" + from torch_mlir import fx # grouped third-party import kept local to reduce module import overhead + from torch_mlir.compiler_utils import OutputType + + gm = apply_decompositions(gm, example_inputs, DEFAULT_DECOMPOSITIONS) + + mlir_module = fx.stateless_fx_import(gm, output_type=OutputType.STABLEHLO) + + return build_executor_from_mlir_module(mlir_module) diff --git a/mopt/src/pass/replace_add_with_mul.cc b/mopt/src/pass/replace_add_with_mul.cc new file mode 100644 index 0000000000000000000000000000000000000000..1710969f3fc2521ba542bc6caff5d2cfcaa33b33 --- /dev/null +++ b/mopt/src/pass/replace_add_with_mul.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/Builders.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +namespace mrt { +namespace pass { +// Inherit from OperationPass, with explicit namespace: mlir::func::FuncOp +// TODO(dayschan) remove this pass. +struct ReplaceAddWithMulPass + : public mlir::PassWrapper> { + // cppcheck-suppress unknownMacro + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceAddWithMulPass) + + mlir::StringRef getArgument() const final { return "replace-tosa-add-with-mul"; } + mlir::StringRef getDescription() const final { return "Replace all tosa.add with tosa.mul (demo only)."; } + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + mlir::OpBuilder builder(func.getContext()); + llvm::SmallVector toErase; + + func.walk([&](mlir::Operation *op) { + if (auto add = llvm::dyn_cast(op)) { + builder.setInsertionPoint(add); + + mlir::Value lhs = add.getInput1(); + mlir::Value rhs = add.getInput2(); + mlir::Type outTy = add.getType(); + + // Construct an i32 scalar constant (ElementsAttr) with shift = 0, as the third input of tosa.mul + mlir::Location loc = add.getLoc(); + // mlir::Type i32Ty = builder.getI32Type(); + // mlir::RankedTensorType shiftTy = mlir::RankedTensorType::get({}, i32Ty); + + // // DenseElementsAttr is a subclass of ElementsAttr, satisfying ConstOp's value parameter requirement + // mlir::IntegerAttr zeroAttr = builder.getI32IntegerAttr(0); + // mlir::DenseElementsAttr shiftDense = mlir::DenseElementsAttr::get(shiftTy, zeroAttr); + // mlir::ElementsAttr shiftElems = shiftDense; + + // mlir::Value shiftVal = builder.create(loc, shiftTy, shiftElems).getResult(); + + // auto mul = builder.create(loc, outTy, lhs, rhs, shiftVal); + auto mul = builder.create(loc, outTy, lhs, rhs, 0); + + add.getResult().replaceAllUsesWith(mul.getResult()); + toErase.push_back(add); + } + }); + + for (mlir::Operation *op : toErase) { + op->erase(); + } + } +}; + +// Factory function for explicit external creation +std::unique_ptr createReplaceAddWithMulPass() { return std::make_unique(); } + +} // namespace pass +} // namespace mrt + +// Static registration (for invocation/loading via registration name) +static mlir::PassRegistration pass; diff --git a/scripts/build_llvm.sh b/scripts/build_llvm.sh new file mode 100644 index 0000000000000000000000000000000000000000..28be8c6d32fcc7ef121b5749ac9d1207a48b6fce --- /dev/null +++ b/scripts/build_llvm.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +PROJECT_DIR="${PROJECT_DIR:-${PWD}}" + +if [ "X${ENABLE_GITEE}" == "Xon" ]; then + LLVM_URL="https://gitee.com/mirrors/LLVM/repository/archive/d16b21b17d13ecd88a068bb803df43e53d3b04ba.zip" + TORCHMLIR_URL="https://gitee.com/mirrors_llvm/torch-mlir/repository/archive/7e7af670802d99cacdaf26e6e37249d544e4896e.zip" + STABLEHLO_URL="https://gitee.com/magicor/stablehlo/repository/archive/c28d55e91b4a5daaff18a33ce7e9bbd0f171256a.zip" +else + LLVM_URL="https://github.com/llvm/llvm-project/archive/d16b21b17d13ecd88a068bb803df43e53d3b04ba.zip" + TORCHMLIR_URL="https://github.com/llvm/torch-mlir/archive/7e7af670802d99cacdaf26e6e37249d544e4896e.zip" + STABLEHLO_URL="https://github.com/openxla/stablehlo/archive/c28d55e91b4a5daaff18a33ce7e9bbd0f171256a.zip" +fi + +TEMP_ARCHIVES_DIR="${PROJECT_DIR}/build/third_party/archives" +mkdir -p "${TEMP_ARCHIVES_DIR}" + +download() { + local url="$1" + local out="$2" + if [ -f "${out}" ] && [ -s "${out}" ]; then + # TODO(dayschan): check hash value of archive file + echo "Skip download, found existing archive: ${out}" + return + fi + wget --no-check-certificate -O "${out}" "${url}" +} + +extract_zip_to_dir() { + local zip_file="$1" + local dest_dir="$2" + rm -rf "${dest_dir}" + mkdir -p "${dest_dir}" + local work_dir + work_dir="$(mktemp -d "${TEMP_ARCHIVES_DIR}/unzip.XXXXXX")" + unzip -q "${zip_file}" -d "${work_dir}" + local top_dir + top_dir="$(find "${work_dir}" -mindepth 1 -maxdepth 1 -type d | head -n1)" + if [ -z "${top_dir}" ]; then + echo "Error: failed to locate top dir in ${zip_file}" >&2 + exit 1 + fi + shopt -s dotglob + mv "${top_dir}/"* "${dest_dir}/" + shopt -u dotglob + rm -rf "${work_dir}" +} + +LLVM_DIR="${PROJECT_DIR}/third_party/llvm-project" +TORCHMLIR_DIR="${PROJECT_DIR}/third_party/torch-mlir" +STABLEHLO_DIR="${TORCHMLIR_DIR}/externals/stablehlo" + +LLVM_ZIP="${TEMP_ARCHIVES_DIR}/llvm-project-d16b21b17d13ecd88a068bb803df43e53d3b04ba.zip" +download "${LLVM_URL}" "${LLVM_ZIP}" +extract_zip_to_dir "${LLVM_ZIP}" "${LLVM_DIR}" + +TORCHMLIR_ZIP="${TEMP_ARCHIVES_DIR}/torch-mlir-7e7af670802d99cacdaf26e6e37249d544e4896e.zip" +download "${TORCHMLIR_URL}" "${TORCHMLIR_ZIP}" +extract_zip_to_dir "${TORCHMLIR_ZIP}" "${TORCHMLIR_DIR}" + +STABLEHLO_ZIP="${TEMP_ARCHIVES_DIR}/stablehlo-c28d55e91b4a5daaff18a33ce7e9bbd0f171256a.zip" +download "${STABLEHLO_URL}" "${STABLEHLO_ZIP}" +extract_zip_to_dir "${STABLEHLO_ZIP}" "${STABLEHLO_DIR}" + +#------------ build llvm +LLVM_BUILD_DIR="${PROJECT_DIR}/build/third_party/build/llvm" +LLVM_INSTALL_PREFIX="${LLVM_INSTALL_PREFIX:-${LLVM_BUILD_DIR}/install}" + +cd "${PROJECT_DIR}" +echo "Configuring llvm..." +cmake -GNinja -B "${LLVM_BUILD_DIR}" \ + ${CCACHE_CMAKE_ARGS} \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=ON \ + -DMLIR_BUILD_MLIR_C_DYLIB=ON \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DLLVM_ENABLE_PROJECTS="mlir" \ + -DLLVM_ENABLE_RTTI=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DLLVM_BUILD_UTILS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX="${LLVM_INSTALL_PREFIX}" \ + -DCMAKE_BUILD_RPATH="${LLVM_BUILD_DIR}/lib" \ + -DCMAKE_INSTALL_RPATH="\$ORIGIN/../lib" \ + third_party/llvm-project/llvm +echo "Configured llvm." + +echo "Building llvm..." +cmake --build "${LLVM_BUILD_DIR}" +echo "Built llvm: ${LLVM_BUILD_DIR}" + +echo "Installing llvm..." +cmake --install "${LLVM_BUILD_DIR}" +echo "Installed llvm: ${LLVM_INSTALL_PREFIX}" + +#------------ build torch_mlir +TORCHMLIR_BUILD_DIR="${PROJECT_DIR}/build/third_party/build/torch_mlir" +TORCHMLIR_INSTALL_PREFIX="${TORCHMLIR_INSTALL_PREFIX:-${TORCHMLIR_BUILD_DIR}/install}" + +echo "Configuring torch_mlir..." +cmake -GNinja -B "${TORCHMLIR_BUILD_DIR}" \ + ${CCACHE_CMAKE_ARGS} \ + -DMLIR_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/mlir" \ + -DLLVM_DIR="${LLVM_INSTALL_PREFIX}/lib/cmake/llvm" \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DTORCH_MLIR_ENABLE_TESTS=OFF \ + -DCMAKE_INSTALL_PREFIX="${TORCHMLIR_INSTALL_PREFIX}" \ + -DCMAKE_BUILD_RPATH="${TORCHMLIR_INSTALL_PREFIX}/lib" \ + -DTORCH_MLIR_INSTALL_USE_SYMLINKS=OFF \ + third_party/torch-mlir +echo "Configured torch_mlir." + +echo "Building torch_mlir..." +cmake --build "${TORCHMLIR_BUILD_DIR}" +echo "Built torch_mlir: ${TORCHMLIR_BUILD_DIR}" + +echo "Installing torch_mlir..." +cmake --install "${TORCHMLIR_BUILD_DIR}" +echo "Installed torch_mlir: ${TORCHMLIR_BUILD_DIR}" diff --git a/setup.py b/setup.py index 484a3fdd886710d2a471a6fc9349af9479f03459..1172c55e8ed437b3b5e10c6bd84987d4b3e4c96e 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,26 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build script for packaging the mrt project. + +This script: +- Collects shared libraries and Python sources into a temporary package tree. +- Vendors third-party artifacts (LLVM, torch-mlir) into the package. +- Embeds the current git commit id in a `.commit-id` file. +- Produces a wheel and moves it to the output directory. +""" + import os import shutil import fnmatch @@ -5,30 +28,35 @@ import subprocess from setuptools import setup from setuptools.dist import Distribution + class BinaryDistribution(Distribution): """Custom distribution class to indicate binary extensions exist""" + def has_ext_modules(self): return True + def get_git_commit_id(): """Get the current git commit ID""" try: commit_id = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip() return commit_id - except Exception as e: + except subprocess.CalledProcessError as e: print(f"Warning: Could not get git commit ID: {e}") return "unknown" + def create_commit_id_file(dst_dir, commit_id): """Create .commit-id file in the target directory""" commit_file = os.path.join(dst_dir, '.commit-id') with open(commit_file, 'w') as f: f.write(commit_id) + def copy_so_files(src_dir, dst_root_dir, dst_lib_dir, special_so_patterns): """ Copy .so files with special handling for files matching specific patterns - + Args: src_dir: Source directory containing .so files dst_root_dir: Destination root directory for special .so files @@ -37,7 +65,7 @@ def copy_so_files(src_dir, dst_root_dir, dst_lib_dir, special_so_patterns): """ os.makedirs(dst_root_dir, exist_ok=True) os.makedirs(dst_lib_dir, exist_ok=True) - + for root, _, files in os.walk(src_dir): for file in files: if file.endswith('.so'): @@ -50,6 +78,7 @@ def copy_so_files(src_dir, dst_root_dir, dst_lib_dir, special_so_patterns): # Copy other .so files to lib directory shutil.copy2(src_path, os.path.join(dst_lib_dir, file)) + def copy_py_files(src_python_dir, dst_root_dir): """ Copy Python files while preserving subdirectory structure @@ -63,41 +92,75 @@ def copy_py_files(src_python_dir, dst_root_dir): os.makedirs(os.path.dirname(dst_path), exist_ok=True) shutil.copy2(src_path, dst_path) + +def copy_thirdparty_files(build_dir, package_dir): + """Vendor third-party artifacts (LLVM and torch-mlir) into pkg_dir.""" + llvm_install_path = os.path.join(build_dir, "third_party", "install", "llvm") + torch_mlir_install_path = os.path.join(build_dir, "third_party", "install", "torch_mlir") + if not os.path.exists(llvm_install_path): + return + vendor_path = os.path.join(package_dir, "_vendor") + + vendor_llvm_path = os.path.join(vendor_path, "llvm") + os.makedirs(vendor_llvm_path, exist_ok=True) + shutil.copytree(os.path.join(llvm_install_path, "lib"), os.path.join(vendor_llvm_path, "lib"), dirs_exist_ok=True) + shutil.copytree(os.path.join(llvm_install_path, "python_packages", "mlir_core"), + os.path.join(vendor_llvm_path, "python"), dirs_exist_ok=True) + + vendor_torch_mlir_path = os.path.join(vendor_path, "torch_mlir") + os.makedirs(vendor_torch_mlir_path, exist_ok=True) + shutil.copytree( + os.path.join( + torch_mlir_install_path, + "python_packages", + "torch_mlir", + "torch_mlir"), + os.path.join( + vendor_torch_mlir_path, + "python", + "torch_mlir")) + + # Get the directory where this script is located script_dir = os.path.dirname(os.path.abspath(__file__)) # Configure paths -src_dir = script_dir + '/build/inferrt/src' -python_src_dir = script_dir + '/inferrt/python/mrt' +project_build_dir = os.path.join(script_dir, "build") +inferrt_src_dir = script_dir + '/build/inferrt/src' +mrt_python_src_dir = os.path.join(script_dir, "inferrt", "python", "mrt") +mopt_python_src_dir = os.path.join(script_dir, "mopt", "python", "mrt") temp_dir = os.path.join(script_dir, 'temp_build') package_name = 'mrt' # Define patterns for special .so files (supports wildcards) -special_so_patterns = [ +special_so_files_patterns = [ '_mrt_api*.so', # Matches all .so files starting with _mrt_api '_mrt_ir*.so', # Matches all .so files starting with _mrt_ir '_mrt_torch*.so' # Matches all .so files starting with _mrt_torch ] # Clean and create temporary directory structure -package_dir = os.path.join(temp_dir, package_name) +project_package_dir = os.path.join(temp_dir, package_name) shutil.rmtree(temp_dir, ignore_errors=True) -os.makedirs(package_dir, exist_ok=True) +os.makedirs(project_package_dir, exist_ok=True) # Get current git commit ID -commit_id = get_git_commit_id() +git_commit_id = get_git_commit_id() # Copy files to temporary directory copy_so_files( - src_dir=src_dir, - dst_root_dir=package_dir, # Destination for special .so files - dst_lib_dir=os.path.join(package_dir, 'lib'), # Destination for other .so files - special_so_patterns=special_so_patterns + src_dir=inferrt_src_dir, + dst_root_dir=project_package_dir, # Destination for special .so files + dst_lib_dir=os.path.join(project_package_dir, 'lib'), # Destination for other .so files + special_so_patterns=special_so_files_patterns ) -copy_py_files(python_src_dir, package_dir) +copy_py_files(mrt_python_src_dir, project_package_dir) +copy_py_files(mopt_python_src_dir, project_package_dir) + +copy_thirdparty_files(project_build_dir, project_package_dir) # Create .commit-id file in package directory -create_commit_id_file(package_dir, commit_id) +create_commit_id_file(project_package_dir, git_commit_id) # Generate wheel package setup( @@ -111,10 +174,11 @@ setup( package_dir={'': 'temp_build'}, package_data={ package_name: [ - '*.so', # Include .so files in root directory - 'lib/*.so', # Include .so files in lib directory + '**/*.so', # Include all .so files recursively '**/*.py', # Include all Python files recursively - '.commit-id', # Include the commit ID file + '**/*.pyi', + '**/llvm/**/*.so.*', # There are some library named *.so.19.0git in llvm now. + '.commit-id', # Include the commit ID file ], }, include_package_data=True, diff --git a/tests/mopt/test_fx_mlir_backend.py b/tests/mopt/test_fx_mlir_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..205b1b7ce6c76d4062e6a06a02d98d8d874f79c9 --- /dev/null +++ b/tests/mopt/test_fx_mlir_backend.py @@ -0,0 +1,39 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""test_fx_mlir_backend""" + +import torch +import numpy as np +from mrt.torch.fx_mlir_backend import backend + + +def foo(x, y): + return torch.matmul(x, y) + + +opt_foo = torch.compile(foo, backend=backend) + + +def run(shape1, shape2): + x_np = np.random.randn(*shape1).astype(np.float32) + y_np = np.random.randn(*shape2).astype(np.float32) + expect = np.matmul(x_np, y_np) + out = opt_foo(torch.tensor(x_np), torch.tensor(y_np)) + assert np.allclose(out, expect, 1e-3, 1e-3), f"\nout={out}\nexpect={expect}" + + +run((2, 2), (2, 4)) + +print("The result is correct. 'mrt' backend has been installed successfully.")