From 119ba35edb7fb646bfc81c6769e7dd1776890970 Mon Sep 17 00:00:00 2001 From: lihui Date: Thu, 27 Nov 2025 16:22:31 +0800 Subject: [PATCH] add ub pass Signed-off-by: lihui --- .../include/akg/Dialect/CMakeLists.txt | 1 + .../include/akg/Dialect/Tensor/CMakeLists.txt | 3 + .../include/akg/Dialect/Tensor/Passes.h | 31 +++++ .../include/akg/Dialect/Tensor/Passes.td | 33 +++++ .../akg/Dialect/Tensor/Transforms/Tensor2UB.h | 36 ++++++ akg-mlir/compiler/lib/Dialect/CMakeLists.txt | 3 +- .../lib/Dialect/Tensor/CMakeLists.txt | 22 ++++ .../Dialect/Tensor/Transforms/Tensor2UB.cpp | 119 ++++++++++++++++++ .../Pipelines/AscendPipelines/CMakeLists.txt | 1 + .../compiler/lib/Pipelines/CMakeLists.txt | 1 + .../lib/Pipelines/GPUPipelines/CMakeLists.txt | 1 + .../compiler/tools/akg-opt/CMakeLists.txt | 2 + akg-mlir/compiler/tools/akg-opt/akg-opt.cpp | 2 + .../tests/ut/Dialect/Tensor/Tensor_2_UB.mlir | 35 ++++++ 14 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 akg-mlir/compiler/include/akg/Dialect/Tensor/CMakeLists.txt create mode 100644 akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.h create mode 100644 akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.td create mode 100644 akg-mlir/compiler/include/akg/Dialect/Tensor/Transforms/Tensor2UB.h create mode 100644 akg-mlir/compiler/lib/Dialect/Tensor/CMakeLists.txt create mode 100644 akg-mlir/compiler/lib/Dialect/Tensor/Transforms/Tensor2UB.cpp create mode 100644 akg-mlir/tests/ut/Dialect/Tensor/Tensor_2_UB.mlir diff --git a/akg-mlir/compiler/include/akg/Dialect/CMakeLists.txt b/akg-mlir/compiler/include/akg/Dialect/CMakeLists.txt index 563022a983..874bc63d79 100644 --- a/akg-mlir/compiler/include/akg/Dialect/CMakeLists.txt +++ b/akg-mlir/compiler/include/akg/Dialect/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(Linalg) add_subdirectory(Fusion) add_subdirectory(MindSpore) add_subdirectory(CPU) +add_subdirectory(Tensor) \ No newline at end of file diff --git a/akg-mlir/compiler/include/akg/Dialect/Tensor/CMakeLists.txt b/akg-mlir/compiler/include/akg/Dialect/Tensor/CMakeLists.txt new file mode 100644 index 0000000000..eec58b95ee --- /dev/null +++ b/akg-mlir/compiler/include/akg/Dialect/Tensor/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name AKGTensor) +add_public_tablegen_target(AKGTensorPassIncGen) \ No newline at end of file diff --git a/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.h b/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.h new file mode 100644 index 0000000000..a73b161e04 --- /dev/null +++ b/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023-2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPILER_INCLUDE_AKG_DIALECT_TENSOR_PASSES_H_ +#define COMPILER_INCLUDE_AKG_DIALECT_TENSOR_PASSES_H_ + +#include "akg/Dialect/Tensor/Transforms/Tensor2UB.h" + +namespace mlir { + +/// Generate the code for registering transforms passes. +#ifndef GEN_PASS_REGISTRATION +#define GEN_PASS_REGISTRATION +#include "akg/Dialect/Tensor/Passes.h.inc" +#endif +} // namespace mlir + +#endif // COMPILER_INCLUDE_AKG_DIALECT_TENSOR_PASSES_H_ \ No newline at end of file diff --git a/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.td b/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.td new file mode 100644 index 0000000000..d50fc5c94e --- /dev/null +++ b/akg-mlir/compiler/include/akg/Dialect/Tensor/Passes.td @@ -0,0 +1,33 @@ +/** + * Copyright 2023-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AKG_MLIR_DIALECT_TENSOR_PASSES +#define AKG_MLIR_DIALECT_TENSOR_PASSES + +include "mlir/Pass/PassBase.td" + +def Tensor2UBPass : Pass<"tensor-2-ub", "mlir::func::FuncOp"> { + let summary = "Move the tensor to UB (Unified Buffer) on Ascend"; + let description = [{ + This pass is a placeholder or diagnostic pass that prints a message + indicating tensor-to-UB transformation is being considered. + It does not modify the IR in minimal form. + }]; + + let constructor = "mlir::tensor::createTensor2UBPass()"; +} + +#endif // AKG_MLIR_DIALECT_TENSOR_PASSES \ No newline at end of file diff --git a/akg-mlir/compiler/include/akg/Dialect/Tensor/Transforms/Tensor2UB.h b/akg-mlir/compiler/include/akg/Dialect/Tensor/Transforms/Tensor2UB.h new file mode 100644 index 0000000000..6fad4f6266 --- /dev/null +++ b/akg-mlir/compiler/include/akg/Dialect/Tensor/Transforms/Tensor2UB.h @@ -0,0 +1,36 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef COMPILER_INCLUDE_AKG_DIALECT_TENSOR_TRANSFORMS_TENSOR2UB_H_ +#define COMPILER_INCLUDE_AKG_DIALECT_TENSOR_TRANSFORMS_TENSOR2UB_H_ + +#include +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace func { +class FuncOp; +} // namespace func + +namespace tensor { + +std::unique_ptr> +createTensor2UBPass(); + +} // namespace tensor +} // namespace mlir + +#endif // COMPILER_INCLUDE_AKG_DIALECT_TENSOR_TRANSFORMS_TENSOR2UB_H_ \ No newline at end of file diff --git a/akg-mlir/compiler/lib/Dialect/CMakeLists.txt b/akg-mlir/compiler/lib/Dialect/CMakeLists.txt index 8493fac50c..874bc63d79 100644 --- a/akg-mlir/compiler/lib/Dialect/CMakeLists.txt +++ b/akg-mlir/compiler/lib/Dialect/CMakeLists.txt @@ -5,4 +5,5 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(Fusion) add_subdirectory(MindSpore) -add_subdirectory(CPU) \ No newline at end of file +add_subdirectory(CPU) +add_subdirectory(Tensor) \ No newline at end of file diff --git a/akg-mlir/compiler/lib/Dialect/Tensor/CMakeLists.txt b/akg-mlir/compiler/lib/Dialect/Tensor/CMakeLists.txt new file mode 100644 index 0000000000..c64b25e6df --- /dev/null +++ b/akg-mlir/compiler/lib/Dialect/Tensor/CMakeLists.txt @@ -0,0 +1,22 @@ +file(GLOB_RECURSE SRC_LIST *.cpp) + +add_mlir_dialect_library(AKGTensorPasses + ${SRC_LIST} + + ADDITIONAL_HEADER_DIRS + ${AKG_MLIR_SOURCE_DIR}/akg/Dialect/Tensor + ${AKG_MLIR_SOURCE_DIR}/akg/Dialect/Tensor/Transforms + + + DISABLE_INSTALL + DEPENDS + AKGTensorPassIncGen + MLIRTensorDialect + + LINK_LIBS PUBLIC + MLIRIR + MLIRTensorDialect + MLIRFuncDialect + MLIRSupport + MLIRPass + ) \ No newline at end of file diff --git a/akg-mlir/compiler/lib/Dialect/Tensor/Transforms/Tensor2UB.cpp b/akg-mlir/compiler/lib/Dialect/Tensor/Transforms/Tensor2UB.cpp new file mode 100644 index 0000000000..3385753918 --- /dev/null +++ b/akg-mlir/compiler/lib/Dialect/Tensor/Transforms/Tensor2UB.cpp @@ -0,0 +1,119 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//===----------------------------------------------------------------------===// +// Tensor2UB.cpp +// +// This pass instruments the IR by inserting a `linalg.abs` operation after +// every `tensor.extract_slice` op. Specifically: +// • For each `tensor.extract_slice`, create a new `tensor.empty` as output, +// then apply `linalg.abs` to the slice result. +// • Redirect all subsequent uses of the original slice to the output of `abs`. +// +// This transformation is intended as a placeholder or instrumentation step +// (e.g., for testing buffer allocation, UB modeling, or pass scheduling), +// and does not perform actual Tensor-to-Uniform-Buffer conversion. +// Note: The use of `linalg.abs` assumes element types support absolute value +// (e.g., signed integers or floating-point); unsigned or unsupported types +// may cause verifier errors if not handled externally. +//===----------------------------------------------------------------------===// + +#include "akg/Dialect/Affine/Transforms/AffineForVectorize.h" +#include "akg/Dialect/Tensor/Transforms/Tensor2UB.h" +#include "akg/Dialect/Affine/Passes.h" +#include "akg/Dialect/Tensor/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tensor-2-ub" + +namespace mlir { +#define GEN_PASS_DECL_TENSOR2UBPASS +#define GEN_PASS_DEF_TENSOR2UBPASS +#include "akg/Dialect/Tensor/Passes.h.inc" +} + +using namespace mlir; +using namespace mlir::tensor; + +namespace { + +class Tensor2UBPass + : public mlir::impl::Tensor2UBPassBase { + public: + Tensor2UBPass() = default; + Tensor2UBPass(const Tensor2UBPass &) = default; + StringRef getArgument() const override { return "tensor-2-ub"; } + StringRef getDescription() const override { + return "Vectorize each innermost affine.for loop (1-D). " + "For dynamic trip count or ub==1, use --default-vf."; + } + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *ctx = funcOp.getContext(); + OpBuilder builder(ctx); + + llvm::SmallVector slices; + + funcOp.walk([&](tensor::ExtractSliceOp op) { + slices.push_back(op); + }); + + for (auto sliceOp : slices) { + Value sliceValue = sliceOp.getResult(); + auto loc = sliceOp.getLoc(); + auto sliceTy = sliceValue.getType().cast(); + builder.setInsertionPointAfter(sliceOp); + + // 1. create output tensor, structured op need outs + auto empty = builder.create( + loc, sliceTy.getShape(), sliceTy.getElementType()); + + // 2. create linalg.abs, inputs + outputs + auto absOp = builder.create( + loc, + ValueRange{sliceValue}, + ValueRange{empty} + ); + + Value absValue = absOp.getResult(0); + llvm::SmallVector operandsToReplace; + for (OpOperand &use : sliceValue.getUses()) { + Operation *user = use.getOwner(); + if (user == absOp) continue; + operandsToReplace.push_back(&use); + } + for (auto *opOperand : operandsToReplace) { + opOperand->set(absValue); + } + } + } +}; +} // namespace + +std::unique_ptr> +mlir::tensor::createTensor2UBPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/CMakeLists.txt b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/CMakeLists.txt index b8060947ce..ee2bf4f249 100644 --- a/akg-mlir/compiler/lib/Pipelines/AscendPipelines/CMakeLists.txt +++ b/akg-mlir/compiler/lib/Pipelines/AscendPipelines/CMakeLists.txt @@ -1,6 +1,7 @@ set(AKG_MLIR_LIBS AKGTransformsPasses AKGLinalgPasses + AKGTensorPasses AKGLLVMIRPasses AKGMLIRAnalysis ) diff --git a/akg-mlir/compiler/lib/Pipelines/CMakeLists.txt b/akg-mlir/compiler/lib/Pipelines/CMakeLists.txt index d2d8a391bd..281e8f8474 100644 --- a/akg-mlir/compiler/lib/Pipelines/CMakeLists.txt +++ b/akg-mlir/compiler/lib/Pipelines/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(AscendPipelines) set(AKG_MLIR_LIBS MindSporePasses AKGAffinePasses + AKGTensorPasses AKGTransformsPasses AKGGPUPasses AKGSCFPasses diff --git a/akg-mlir/compiler/lib/Pipelines/GPUPipelines/CMakeLists.txt b/akg-mlir/compiler/lib/Pipelines/GPUPipelines/CMakeLists.txt index 70c1fc4283..e0a23e5af5 100644 --- a/akg-mlir/compiler/lib/Pipelines/GPUPipelines/CMakeLists.txt +++ b/akg-mlir/compiler/lib/Pipelines/GPUPipelines/CMakeLists.txt @@ -1,5 +1,6 @@ set(AKG_MLIR_LIBS AKGAffinePasses + AKGTensorPasses AKGTransformsPasses AKGGPUPasses AKGLinalgPasses diff --git a/akg-mlir/compiler/tools/akg-opt/CMakeLists.txt b/akg-mlir/compiler/tools/akg-opt/CMakeLists.txt index 8f285f4d4d..9f6d02ce8b 100644 --- a/akg-mlir/compiler/tools/akg-opt/CMakeLists.txt +++ b/akg-mlir/compiler/tools/akg-opt/CMakeLists.txt @@ -1,5 +1,6 @@ set(AKG_MLIR_LIBS AKGAffinePasses + AKGTensorPasses AKGTransformsPasses AKGSCFPasses AKGGPUPasses @@ -25,6 +26,7 @@ set(LIBS ${conversion_libs} ${AKG_MLIR_LIBS} AKGLinalgPasses + AKGTensorPasses MLIRAnalysis MLIRDialect MLIROptLib diff --git a/akg-mlir/compiler/tools/akg-opt/akg-opt.cpp b/akg-mlir/compiler/tools/akg-opt/akg-opt.cpp index ffa6dac649..14824a514e 100644 --- a/akg-mlir/compiler/tools/akg-opt/akg-opt.cpp +++ b/akg-mlir/compiler/tools/akg-opt/akg-opt.cpp @@ -16,6 +16,7 @@ #include "akg/Conversion/Passes.h" #include "akg/Dialect/Affine/Passes.h" +#include "akg/Dialect/Tensor/Passes.h" #include "akg/Dialect/Fusion/IR/Fusion.h" #include "akg/Dialect/GPU/Passes.h" #include "akg/Dialect/LLVMIR/Passes.h" @@ -62,6 +63,7 @@ int main(int argc, char **argv) { registerMindSporePasses(); registerAKGAffinePasses(); + registerAKGTensorPasses(); registerMindSporePasses(); registerAKGLinalgPasses(); registerAKGTransformsPasses(); diff --git a/akg-mlir/tests/ut/Dialect/Tensor/Tensor_2_UB.mlir b/akg-mlir/tests/ut/Dialect/Tensor/Tensor_2_UB.mlir new file mode 100644 index 0000000000..794cff6f7c --- /dev/null +++ b/akg-mlir/tests/ut/Dialect/Tensor/Tensor_2_UB.mlir @@ -0,0 +1,35 @@ +// RUN: akg-opt %s --tensor-2-ub -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: func.func @tensor_2_ub(%arg0: tensor<86016xbf16>, %arg1: tensor<86016xbf16>) -> tensor<86016xbf16> attributes {OperatorType = "Elementwise", compute_capability = "", mindspore_kernel, process = "aicore"} { +// CHECK-NEXT: %0 = tensor.empty() : tensor<86016xbf16> +// CHECK-NEXT: %1 = affine.for %arg2 = 0 to 86016 step 32 iter_args(%arg3 = %0) -> (tensor<86016xbf16>) { +// CHECK-NEXT: %2 = affine.for %arg4 = #map(%arg2) to #map1(%arg2) step 512 iter_args(%arg5 = %arg3) -> (tensor<86016xbf16>) { +// CHECK-NEXT: %extracted_slice = tensor.extract_slice %arg0[0] [512] [1] : tensor<86016xbf16> to tensor<512xbf16> +// CHECK-NEXT: %3 = tensor.empty() : tensor<512xbf16> +// CHECK-NEXT: %4 = linalg.abs ins(%extracted_slice : tensor<512xbf16>) outs(%3 : tensor<512xbf16>) -> tensor<512xbf16> +// CHECK-NEXT: %extracted_slice_0 = tensor.extract_slice %arg1[0] [512] [1] : tensor<86016xbf16> to tensor<512xbf16> +// CHECK-NEXT: %5 = tensor.empty() : tensor<512xbf16> +// CHECK-NEXT: %6 = linalg.abs ins(%extracted_slice_0 : tensor<512xbf16>) outs(%5 : tensor<512xbf16>) -> tensor<512xbf16> +// CHECK-NEXT: %7 = arith.addf %4, %6 : tensor<512xbf16> +// CHECK-NEXT: %inserted_slice = tensor.insert_slice %7 into %arg5[0] [512] [1] : tensor<512xbf16> into tensor<86016xbf16> +// CHECK-NEXT: affine.yield %inserted_slice : tensor<86016xbf16> +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %2 : tensor<86016xbf16> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<86016xbf16> +// CHECK-NEXT: } + +func.func @tensor_2_ub(%arg0: tensor<86016xbf16>, %arg1: tensor<86016xbf16>) -> tensor<86016xbf16> attributes {OperatorType = "Elementwise", compute_capability = "", mindspore_kernel, process = "aicore"} { + %0 = tensor.empty() : tensor<86016xbf16> + %1 = affine.for %arg2 = 0 to 86016 step 32 iter_args(%arg3 = %0) -> (tensor<86016xbf16>) { + %2 = affine.for %arg4 = affine_map<(d0) -> (d0)>(%arg2) to affine_map<(d0) -> (d0 + 32)>(%arg2) step 512 iter_args(%arg5 = %arg3) -> (tensor<86016xbf16>) { + %extracted_slice = tensor.extract_slice %arg0[0] [512] [1] : tensor<86016xbf16> to tensor<512xbf16> + %extracted_slice_0 = tensor.extract_slice %arg1[0] [512] [1] : tensor<86016xbf16> to tensor<512xbf16> + %3 = arith.addf %extracted_slice, %extracted_slice_0 : tensor<512xbf16> + %inserted_slice = tensor.insert_slice %3 into %arg5[0] [512] [1] : tensor<512xbf16> into tensor<86016xbf16> + affine.yield %inserted_slice : tensor<86016xbf16> + } + affine.yield %2 : tensor<86016xbf16> + } + return %1 : tensor<86016xbf16> +} \ No newline at end of file -- Gitee