From fef9eb7f3e3fb030645a7134e53f0a3c63126fef Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:16:51 +0200 Subject: [PATCH 1/7] Ptr To LLVM --- .../mlir/Conversion/PtrToLLVM/PtrToLLVM.h | 27 +++++++ mlir/include/mlir/InitAllExtensions.h | 3 + mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt | 17 ++++ mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp | 77 +++++++++++++++++++ 5 files changed, 125 insertions(+) create mode 100644 mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h create mode 100644 mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt create mode 100644 mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp diff --git a/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h new file mode 100644 index 000000000000..0ff92bc85668 --- /dev/null +++ b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h @@ -0,0 +1,27 @@ +//===- PtrToLLVM.h - Ptr to LLVM dialect conversion -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H +#define MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +namespace ptr { +/// Populate the convert to LLVM patterns for the `ptr` dialect. +void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); +/// Register the convert to LLVM interface for the `ptr` dialect. +void registerConvertPtrToLLVMInterface(DialectRegistry ®istry); +} // namespace ptr +} // namespace mlir + +#endif // MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 20a4ab6f18a2..43835ab9c634 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -23,6 +23,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -44,6 +45,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" + #include namespace mlir { @@ -56,6 +58,7 @@ namespace mlir { inline void registerAllExtensions(DialectRegistry ®istry) { // Register all conversions to LLVM extensions. arith::registerConvertArithToLLVMInterface(registry); + ptr::registerConvertPtrToLLVMInterface(registry); registerConvertComplexToLLVMInterface(registry); cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 80c8b84d9ae8..9dcbba9eda5b 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -50,6 +50,7 @@ add_subdirectory(ReconcileUnrealizedCasts) add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToEmitC) add_subdirectory(SCFToGPU) +add_subdirectory(PtrToLLVM) add_subdirectory(SCFToOpenMP) add_subdirectory(SCFToSPIRV) add_subdirectory(ShapeToStandard) diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..d8c60d7ad0d1 --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRPtrToLLVM + PtrToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRPtrDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + ) \ No newline at end of file diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp new file mode 100644 index 000000000000..caebed24d0af --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp @@ -0,0 +1,77 @@ +//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/TypeUtilities.h" +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Ptr to LLVM. +struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &converter, + RewritePatternSet &patterns) const final { + llvm::errs() << "Populating Ptr to LLVM conversion patterns! \n"; + ptr::populatePtrToLLVMConversionPatterns(converter, patterns); + } +}; +} + +//===----------------------------------------------------------------------===// +// API +//===----------------------------------------------------------------------===// + +void mlir::ptr::populatePtrToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + + llvm::errs() << "Adding type conversion! \n"; + + // Add type conversions. + converter.addConversion([&](ptr::PtrType type) -> Type { + llvm::errs() << "Converting PtrType! \n"; + llvm::errs() << "MemorySpace: " << type.getMemorySpace() << "\n"; + std::optional maybeAttr = + converter.convertTypeAttribute(type, type.getMemorySpace()); + auto memSpace = + maybeAttr ? dyn_cast_or_null(*maybeAttr) : IntegerAttr(); + if (!memSpace) + return {}; + return LLVM::LLVMPointerType::get(type.getContext(), + memSpace.getValue().getSExtValue()); + }); + +} + +void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry ®istry) { + llvm::errs() << "Registering Ptr to LLVM interface! \n"; + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces(); + }); +} \ No newline at end of file -- Gitee From 61c6ebff9a2626d1052df820b0280ff0374d7038 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:17:35 +0200 Subject: [PATCH 2/7] [Backport][mlir][DataLayout] Add a default memory space entry to the data layout --- mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 3 + mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 3 + mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 6 +- .../Dialect/Ptr/IR/MemorySpaceInterfaces.td | 117 ++++++++++++++++++ .../mlir/Interfaces/DataLayoutInterfaces.h | 10 ++ mlir/lib/Dialect/DLTI/DLTI.cpp | 55 ++++---- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 12 +- mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 46 +++---- mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 27 ++++ .../lib/Dialect/DLTI/TestDataLayoutQuery.cpp | 5 + mlir/test/lib/Dialect/Test/TestTypes.cpp | 4 +- .../Interfaces/DataLayoutInterfacesTest.cpp | 19 +++ 12 files changed, 258 insertions(+), 49 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td index 443e3128b4ac..1b6faa9387e5 100644 --- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td +++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td @@ -81,6 +81,9 @@ def DLTI_DataLayoutSpecAttr : /// Returns the endiannes identifier. StringAttr getEndiannessIdentifier(MLIRContext *context) const; + + /// Returns the default memory space identifier. + StringAttr getDefaultMemorySpaceIdentifier(MLIRContext *context) const; /// Returns the alloca memory space identifier. StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const; diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td index e26fbdb14664..65c69d38c537 100644 --- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td +++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td @@ -55,6 +55,9 @@ def DLTI_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kDataLayoutStackAlignmentKey = "dlti.stack_alignment"; + + constexpr const static ::llvm::StringLiteral + kDataLayoutDefaultMemorySpaceKey = "dlti.default_memory_space"; }]; let useDefaultAttributePrinterParser = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 93733ccd4929..b9f03451b48c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -198,8 +198,10 @@ public: uint64_t getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const; - bool areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const; + bool areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const; LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td new file mode 100644 index 000000000000..cb7775c862a9 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td @@ -0,0 +1,117 @@ +//===-- MemorySpaceInterfaces.td - Memory space interfaces ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines memory space attribute interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_MEMORYSPACEINTERFACES +#define PTR_MEMORYSPACEINTERFACES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Memory space attribute interface. +//===----------------------------------------------------------------------===// + +def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { + let description = [{ + This interface defines a common API for interacting with the memory model of + a memory space and the operations in the pointer dialect. + + Furthermore, this interface allows concepts such as read-only memory to be + adequately modeled and enforced. + }]; + let cppNamespace = "::mlir::ptr"; + let methods = [ + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to load a value from the memory space + with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidLoad", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to store a value in the memory space + with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidStore", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an atomic operation in the + memory space with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAtomicOp", + /*args=*/ (ins "::mlir::ptr::AtomicBinOp":$op, + "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an atomic exchange operation + in the memory space with a specific type, alignment, and atomic + orderings. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAtomicXchg", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$successOrdering, + "::mlir::ptr::AtomicOrdering":$failureOrdering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an `addrspacecast` op + in the memory space. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAddrSpaceCast", + /*args=*/ (ins "::mlir::Type":$tgt, + "::mlir::Type":$src, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform a `ptrtoint` or `inttoptr` + op in the memory space. + The first type is expected to be integer-like, while the second must be a + ptr-like type. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidPtrIntCast", + /*args=*/ (ins "::mlir::Type":$intLikeTy, + "::mlir::Type":$ptrLikeTy, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + ]; +} + +#endif // PTR_MEMORYSPACEINTERFACES \ No newline at end of file diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h index ab65f92820a6..e64430189009 100644 --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h @@ -35,6 +35,8 @@ using DeviceIDTargetDeviceSpecPair = std::pair; using DeviceIDTargetDeviceSpecPairListRef = llvm::ArrayRef; +using DataLayoutIdentifiedEntryMap = + ::llvm::DenseMap<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface>; class DataLayoutOpInterface; class DataLayoutSpecInterface; class ModuleOp; @@ -79,6 +81,10 @@ Attribute getDefaultEndianness(DataLayoutEntryInterface entry); /// DataLayoutInterface if specified, otherwise returns the default. Attribute getDefaultAllocaMemorySpace(DataLayoutEntryInterface entry); +/// Default handler for the default memory space request. Dispatches to the +/// DataLayoutInterface if specified, otherwise returns the default. +Attribute getDefaultMemorySpace(DataLayoutEntryInterface entry); + /// Default handler for program memory space request. Dispatches to the /// DataLayoutInterface if specified, otherwise returns the default. Attribute getDefaultProgramMemorySpace(DataLayoutEntryInterface entry); @@ -231,6 +237,9 @@ public: /// Returns the memory space used for AllocaOps. Attribute getAllocaMemorySpace() const; + /// Returns the default memory space used for memory operations. + Attribute getDefaultMemorySpace() const; + /// Returns the memory space used for program memory operations. Attribute getProgramMemorySpace() const; @@ -281,6 +290,7 @@ private: mutable std::optional allocaMemorySpace; mutable std::optional programMemorySpace; mutable std::optional globalMemorySpace; + mutable std::optional defaultMemorySpace; /// Cache for stack alignment. mutable std::optional stackAlignment; diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp index 420c605d1a19..987de03a4685 100644 --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -170,25 +170,9 @@ combineOneSpec(DataLayoutSpecInterface spec, DenseMap newEntriesForID; spec.bucketEntriesByType(newEntriesForType, newEntriesForID); - // Try overwriting the old entries with the new ones. - for (auto &kvp : newEntriesForType) { - if (!entriesForType.count(kvp.first)) { - entriesForType[kvp.first] = std::move(kvp.second); - continue; - } - - Type typeSample = kvp.second.front().getKey().get(); - assert(&typeSample.getDialect() != - typeSample.getContext()->getLoadedDialect() && - "unexpected data layout entry for built-in type"); - - auto interface = llvm::cast(typeSample); - if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second)) - return failure(); - - overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second); - } - + + // Combine non-Type DL entries first so they are visible to the + // `type.areCompatible` method, allowing to query global properties. for (const auto &kvp : newEntriesForID) { StringAttr id = kvp.second.getKey().get(); Dialect *dialect = id.getReferencedDialect(); @@ -197,7 +181,7 @@ combineOneSpec(DataLayoutSpecInterface spec, continue; } - // Attempt to combine the enties using the dialect interface. If the + // Attempt to combine the entries using the dialect interface. If the // dialect is not loaded for some reason, use the default combinator // that conservatively accepts identical entries only. entriesForID[id] = @@ -208,6 +192,27 @@ combineOneSpec(DataLayoutSpecInterface spec, if (!entriesForID[id]) return failure(); } + // Try overwriting the old entries with the new ones. + for (auto &kvp : newEntriesForType) { + if (!entriesForType.count(kvp.first)) { + entriesForType[kvp.first] = std::move(kvp.second); + continue; + } + + Type typeSample = cast(kvp.second.front().getKey()); + assert(&typeSample.getDialect() != + typeSample.getContext()->getLoadedDialect() && + "unexpected data layout entry for built-in type"); + + auto interface = cast(typeSample); + // TODO: Revisit this method and call once + // https://github.com/llvm/llvm-project/issues/130321 gets resolved. + if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second, + spec, entriesForID)) + return failure(); + + overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second); + } return success(); } @@ -244,6 +249,12 @@ DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey); } +StringAttr DataLayoutSpecAttr::getDefaultMemorySpaceIdentifier( + MLIRContext *context) const { + return Builder(context).getStringAttr( + DLTIDialect::kDataLayoutDefaultMemorySpaceKey); +} + StringAttr DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr( @@ -417,7 +428,9 @@ public: << DLTIDialect::kDataLayoutEndiannessBig << "' or '" << DLTIDialect::kDataLayoutEndiannessLittle << "'"; } - if (entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey || + + if (entryName == DLTIDialect::kDataLayoutDefaultMemorySpaceKey || + entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey || entryName == DLTIDialect::kDataLayoutProgramMemorySpaceKey || entryName == DLTIDialect::kDataLayoutGlobalMemorySpaceKey || entryName == DLTIDialect::kDataLayoutStackAlignmentKey) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index cf3f38b71013..a3f1bef9b28c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -349,8 +349,10 @@ LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout, return dataLayout.getTypeIndexBitwidth(get(getContext())); } -bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { +bool LLVMPointerType::areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; @@ -596,8 +598,10 @@ static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) { .getValues()[static_cast(pos)]; } -bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { +bool LLVMStructType::areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const{ for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index 2866d4eb10fe..dcc4a1359416 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -23,13 +23,12 @@ using namespace mlir::ptr; constexpr const static unsigned kDefaultPointerSizeBits = 64; constexpr const static unsigned kBitsInByte = 8; -constexpr const static unsigned kDefaultPointerAlignment = 8; - -static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; } +constexpr const static unsigned kDefaultPointerAlignmentBits = 8; /// Searches the data layout for the pointer spec, returns nullptr if it is not /// found. -static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { +static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type, + Attribute defaultMemorySpace) { for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; @@ -41,20 +40,22 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { } // If not found, and this is the pointer to the default memory space, assume // 64-bit pointers. - if (type.getMemorySpace() == getDefaultMemorySpace(type)) + if (type.getMemorySpace() == defaultMemorySpace) return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, - kDefaultPointerAlignment, kDefaultPointerAlignment, - kDefaultPointerSizeBits); + kDefaultPointerAlignmentBits, + kDefaultPointerAlignmentBits, kDefaultPointerSizeBits); return nullptr; } bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { + DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; uint32_t size = kDefaultPointerSizeBits; - uint32_t abi = kDefaultPointerAlignment; + uint32_t abi = kDefaultPointerAlignmentBits; auto newType = llvm::cast(newEntry.getKey().get()); const auto *it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { @@ -65,10 +66,12 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, return false; }); if (it == oldLayout.end()) { + Attribute defaultMemorySpace = mlir::detail::getDefaultMemorySpace( + map.lookup(newSpec.getDefaultMemorySpaceIdentifier(getContext()))); it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { auto ptrTy = llvm::cast(type); - return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy); + return ptrTy.getMemorySpace() == defaultMemorySpace; } return false; }); @@ -90,43 +93,44 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getAbi() / kBitsInByte; - return dataLayout.getTypeABIAlignment( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeABIAlignment(get(getContext(), defaultMemorySpace)); } std::optional PtrType::getIndexBitwidth(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) { + Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) { return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() : spec.getIndex(); } - return dataLayout.getTypeIndexBitwidth( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeIndexBitwidth(get(getContext(), defaultMemorySpace)); } llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return llvm::TypeSize::getFixed(spec.getSize()); // For other memory spaces, use the size of the pointer to the default memory // space. - return dataLayout.getTypeSizeInBits( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeSizeInBits(get(getContext(), defaultMemorySpace)); } uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getPreferred() / kBitsInByte; return dataLayout.getTypePreferredAlignment( - get(getContext(), getDefaultMemorySpace(*this))); + get(getContext(), defaultMemorySpace)); } LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 2634245a4b7b..9b3885cc539d 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -246,6 +246,16 @@ Attribute mlir::detail::getDefaultEndianness(DataLayoutEntryInterface entry) { return entry.getValue(); } +// Returns the default memory space if specified in the given entry. If the +// entry is empty the default memory space represented by an empty attribute is +// returned. +Attribute mlir::detail::getDefaultMemorySpace(DataLayoutEntryInterface entry) { + if (!entry) + return Attribute(); + + return entry.getValue(); +} + // Returns the memory space used for alloca operations if specified in the // given entry. If the entry is empty the default memory space represented by // an empty attribute is returned. @@ -596,6 +606,23 @@ mlir::Attribute mlir::DataLayout::getEndianness() const { return *endianness; } +mlir::Attribute mlir::DataLayout::getDefaultMemorySpace() const { + checkValid(); + if (defaultMemorySpace) + return *defaultMemorySpace; + DataLayoutEntryInterface entry; + if (originalLayout) + entry = originalLayout.getSpecForIdentifier( + originalLayout.getDefaultMemorySpaceIdentifier( + originalLayout.getContext())); + if (auto iface = dyn_cast_or_null(scope)) + defaultMemorySpace = iface.getDefaultMemorySpace(entry); + else + defaultMemorySpace = detail::getDefaultMemorySpace(entry); + return *defaultMemorySpace; +} + + mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const { checkValid(); if (allocaMemorySpace) diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp index 56f309f150ca..a4f0fc6b2ff7 100644 --- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp @@ -42,6 +42,7 @@ struct TestDataLayoutQuery uint64_t preferred = layout.getTypePreferredAlignment(op.getType()); uint64_t index = layout.getTypeIndexBitwidth(op.getType()).value_or(0); Attribute endianness = layout.getEndianness(); + Attribute defaultMemorySpace = layout.getDefaultMemorySpace(); Attribute allocaMemorySpace = layout.getAllocaMemorySpace(); Attribute programMemorySpace = layout.getProgramMemorySpace(); Attribute globalMemorySpace = layout.getGlobalMemorySpace(); @@ -68,6 +69,10 @@ struct TestDataLayoutQuery builder.getNamedAttr("endianness", endianness == Attribute() ? builder.getStringAttr("") : endianness), + builder.getNamedAttr("default_memory_space", + defaultMemorySpace == Attribute() + ? builder.getUI32IntegerAttr(0) + : defaultMemorySpace), builder.getNamedAttr("alloca_memory_space", allocaMemorySpace == Attribute() ? builder.getUI32IntegerAttr(0) diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 1593b6d7d753..213c2e69329d 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -284,7 +284,9 @@ TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout, } bool TestTypeWithLayoutType::areCompatible( - DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { unsigned old = extractKind(oldLayout, "alignment"); return old == 1 || extractKind(newLayout, "alignment") <= old; } diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp index d1227b045d4e..c7350c218a18 100644 --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -23,6 +23,8 @@ using namespace mlir; namespace { constexpr static llvm::StringLiteral kAttrName = "dltest.layout"; constexpr static llvm::StringLiteral kEndiannesKeyName = "dltest.endianness"; +constexpr static llvm::StringLiteral kDefaultKeyName = + "dltest.default_memory_space"; constexpr static llvm::StringLiteral kAllocaKeyName = "dltest.alloca_memory_space"; constexpr static llvm::StringLiteral kProgramKeyName = @@ -83,6 +85,9 @@ struct CustomDataLayoutSpec StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(kAllocaKeyName); } + StringAttr getDefaultMemorySpaceIdentifier(MLIRContext *context) const { + return Builder(context).getStringAttr(kDefaultKeyName); + } StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(kProgramKeyName); } @@ -201,6 +206,15 @@ struct SingleQueryType return Attribute(); } + Attribute getDefaultMemorySpace(DataLayoutEntryInterface entry) { + static bool executed = false; + if (executed) + llvm::report_fatal_error("repeated call"); + + executed = true; + return Attribute(); + } + Attribute getProgramMemorySpace(DataLayoutEntryInterface entry) { static bool executed = false; if (executed) @@ -458,6 +472,7 @@ module {} EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -490,6 +505,7 @@ TEST(DataLayout, NullSpec) { EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -530,6 +546,7 @@ TEST(DataLayout, EmptySpec) { EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -552,6 +569,7 @@ TEST(DataLayout, SpecWithEntries) { #dlti.dl_entry, #dlti.dl_entry, #dlti.dl_entry<"dltest.endianness", "little">, + #dlti.dl_entry<"dltest.default_memory_space", 1 : i32>, #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>, #dlti.dl_entry<"dltest.program_memory_space", 3 : i32>, #dlti.dl_entry<"dltest.global_memory_space", 2 : i32>, @@ -588,6 +606,7 @@ TEST(DataLayout, SpecWithEntries) { EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Builder(&ctx).getStringAttr("little")); + EXPECT_EQ(layout.getDefaultMemorySpace(), Builder(&ctx).getI32IntegerAttr(1)); EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5)); EXPECT_EQ(layout.getProgramMemorySpace(), Builder(&ctx).getI32IntegerAttr(3)); EXPECT_EQ(layout.getGlobalMemorySpace(), Builder(&ctx).getI32IntegerAttr(2)); -- Gitee From 04a5be38d3579b23254ed4e682081aab62f24526 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:18:04 +0200 Subject: [PATCH 3/7] [Backport][mlir][Ptr] Add the MemorySpaceAttrInterface interface and dependencies. --- .../mlir/Dialect/Ptr/IR/CMakeLists.txt | 12 ++++ .../Dialect/Ptr/IR/MemorySpaceInterfaces.h | 32 +++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h | 2 + .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 10 +-- mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | 69 +++++++++++++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 1 + mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h | 1 + .../mlir/Interfaces/DataLayoutInterfaces.td | 22 +++++- mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 4 +- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 6 ++ mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 31 +++++---- 11 files changed, 170 insertions(+), 20 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h create mode 100644 mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt index df07b8d5a63d..255af4c486cb 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt @@ -5,3 +5,15 @@ set(LLVM_TARGET_DEFINITIONS PtrOps.td) mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr) mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr) add_public_tablegen_target(MLIRPtrOpsAttributesIncGen) + +set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td) +mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs) +add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS PtrOps.td) +mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRPtrOpsEnumsGen) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h new file mode 100644 index 000000000000..3714c1caa367 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h @@ -0,0 +1,32 @@ +//===-- MemorySpaceInterfaces.h - ptr memory space interfaces ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ptr dialect memory space interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H +#define MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class Operation; +namespace ptr { +enum class AtomicBinOp : uint64_t; +enum class AtomicOrdering : uint64_t; +} // namespace ptr +} // namespace mlir + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h index 72e767764d98..5ffe23e45fe1 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -18,4 +18,6 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" +#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc" + #endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 14d72c3001d9..9aa0215e9560 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -53,14 +53,14 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ memory-space ::= attribute-value ``` }]; - let parameters = (ins OptionalParameter<"Attribute">:$memorySpace); - let assemblyFormat = "(`<` $memorySpace^ `>`)?"; + let parameters = (ins "MemorySpaceAttrInterface":$memorySpace); + let assemblyFormat = "`<` $memorySpace `>`"; let builders = [ - TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{ - return $_get($_ctxt, memorySpace); + TypeBuilderWithInferredContext<(ins + "MemorySpaceAttrInterface":$memorySpace), [{ + return $_get(memorySpace.getContext(), memorySpace); }]> ]; - let skipDefaultBuilders = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td new file mode 100644 index 000000000000..59d2d0b34222 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td @@ -0,0 +1,69 @@ +//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_ENUMS +#define PTR_ENUMS + +include "mlir/IR/EnumAttr.td" + +//===----------------------------------------------------------------------===// +// Atomic binary op enum attribute. +//===----------------------------------------------------------------------===// + +def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">; +def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">; +def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">; +def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">; +def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">; +def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">; +def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">; +def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">; +def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">; +def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">; +def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">; +def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">; +def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">; +def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">; +def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">; +def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">; +def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">; + +def AtomicBinOp : I64EnumAttr< + "AtomicBinOp", + "ptr.atomicrmw binary operations", + [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd, + AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax, + AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd, + AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap, + AtomicBinOpUDecWrap]> { + let cppNamespace = "::mlir::ptr"; +} + +//===----------------------------------------------------------------------===// +// Atomic ordering enum attribute. +//===----------------------------------------------------------------------===// + +def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">; +def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">; +def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">; +def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">; +def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">; +def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">; +def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">; + +def AtomicOrdering : I64EnumAttr< + "AtomicOrdering", + "Atomic ordering for LLVM's memory model", + [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic, + AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel, + AtomicOrderingSeqCst + ]> { + let cppNamespace = "::mlir::ptr"; +} + +#endif // PTR_ENUMS \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index c63a0b220e50..313c9f8eb09a 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td" include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/IR/OpAsmInterface.td" #endif // PTR_OPS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h index 264a97c80722..4fe1b5a1aa42 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H #define MLIR_DIALECT_PTR_IR_PTRTYPES_H +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td index bc5080c9c6a5..a286397c9a41 100644 --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td @@ -136,6 +136,12 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> { /*methodName=*/"getStackAlignmentIdentifier", /*args=*/(ins "::mlir::MLIRContext *":$context) >, + InterfaceMethod< + /*description=*/"Returns the default memory space identifier.", + /*retTy=*/"::mlir::StringAttr", + /*methodName=*/"getDefaultMemorySpaceIdentifier", + /*args=*/(ins "::mlir::MLIRContext *":$context) + >, // Implementations may override this if they have an efficient lookup // mechanism. InterfaceMethod< @@ -465,6 +471,18 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> { return ::mlir::detail::getDefaultStackAlignment(entry); }] >, + StaticInterfaceMethod< + /*description=*/"Returns the memory space used by the ABI computed " + "using the relevant entries. The data layout object " + "can be used for recursive queries.", + /*retTy=*/"::mlir::Attribute", + /*methodName=*/"getDefaultMemorySpace", + /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::detail::getDefaultMemorySpace(entry); + }] + >, StaticInterfaceMethod< /*description=*/"Returns the value of the property, if the property is " "defined. Otherwise, it returns std::nullopt.", @@ -567,7 +585,9 @@ def DataLayoutTypeInterface : TypeInterface<"DataLayoutTypeInterface"> { /*retTy=*/"bool", /*methodName=*/"areCompatible", /*args=*/(ins "::mlir::DataLayoutEntryListRef":$oldLayout, - "::mlir::DataLayoutEntryListRef":$newLayout), + "::mlir::DataLayoutEntryListRef":$newLayout, + "::mlir::DataLayoutSpecInterface":$newSpec, + "const ::mlir::DataLayoutIdentifiedEntryMap&":$identified), /*methodBody=*/"", /*defaultImplementation=*/[{ return true; }] >, diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 9cf3643c73d3..ec0d5769c776 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -7,7 +7,9 @@ add_mlir_dialect_library( DEPENDS MLIRPtrOpsAttributesIncGen MLIRPtrOpsIncGen - + MLIRPtrOpsEnumsGen + MLIRPtrMemorySpaceInterfacesIncGen + LINK_LIBS PUBLIC MLIRIR diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 7830ffe893df..ff231dae60c2 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -48,6 +48,12 @@ void PtrDialect::initialize() { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" + +#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc" + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index dcc4a1359416..a0ea5e83f646 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -28,7 +28,7 @@ constexpr const static unsigned kDefaultPointerAlignmentBits = 8; /// Searches the data layout for the pointer spec, returns nullptr if it is not /// found. static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type, - Attribute defaultMemorySpace) { + MemorySpaceAttrInterface defaultMemorySpace) { for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; @@ -38,9 +38,11 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type, return spec; } } - // If not found, and this is the pointer to the default memory space, assume - // 64-bit pointers. - if (type.getMemorySpace() == defaultMemorySpace) + // If not found, and this is the pointer to the default memory space or if + // `defaultMemorySpace` is null, assume 64-bit pointers. `defaultMemorySpace` + // might be null if the data layout doesn't define the default memory space. + if (type.getMemorySpace() == defaultMemorySpace || + defaultMemorySpace == nullptr) return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, kDefaultPointerAlignmentBits, kDefaultPointerAlignmentBits, kDefaultPointerSizeBits); @@ -93,44 +95,47 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getAbi() / kBitsInByte; - return dataLayout.getTypeABIAlignment(get(getContext(), defaultMemorySpace)); + return dataLayout.getTypeABIAlignment(get(defaultMemorySpace)); } std::optional PtrType::getIndexBitwidth(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) { return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() : spec.getIndex(); } - return dataLayout.getTypeIndexBitwidth(get(getContext(), defaultMemorySpace)); + return dataLayout.getTypeIndexBitwidth(get(defaultMemorySpace)); } llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return llvm::TypeSize::getFixed(spec.getSize()); // For other memory spaces, use the size of the pointer to the default memory // space. - return dataLayout.getTypeSizeInBits(get(getContext(), defaultMemorySpace)); + return dataLayout.getTypeSizeInBits(get(defaultMemorySpace)); } uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace(); + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getPreferred() / kBitsInByte; - return dataLayout.getTypePreferredAlignment( - get(getContext(), defaultMemorySpace)); + return dataLayout.getTypePreferredAlignment(get(defaultMemorySpace)); } LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, -- Gitee From 3a999884a163fa1a7f384cd43673830bfa58fbdf Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:20:11 +0200 Subject: [PATCH 4/7] [Backport][mlir][ptr] Add the ptradd and type_offset ops, and generic_space attr --- .../mlir/Dialect/Ptr/IR/PtrAttrDefs.td | 28 +++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h | 5 ++ mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td | 11 +++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | 2 + mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 75 +++++++++++++++++++ mlir/include/mlir/IR/Properties.td | 24 ++++++ mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 3 +- mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 46 ++++++++++++ mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 27 +++++++ 9 files changed, 220 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td index e75038f300f1..24ee1851c9a9 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td @@ -10,7 +10,9 @@ #define PTR_ATTRDEFS include "mlir/Dialect/Ptr/IR/PtrDialect.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" // All of the attributes will extend this class. class Ptr_Attr + ]> { + let summary = "Generic memory space"; + let description = [{ + The `generic_space` attribute defines a memory space attribute with the + following properties: + - Load and store operations are always valid, regardless of the type. + - Atomic operations are always valid, regardless of the type. + - Cast operations to `generic_space` are always valid. + + Example: + + ```mlir + #ptr.generic_space + ``` + }]; + let assemblyFormat = ""; +} + + //===----------------------------------------------------------------------===// // SpecAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h index 5ffe23e45fe1..dc0a3ffd4ae3 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -13,7 +13,12 @@ #ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H #define MLIR_DIALECT_PTR_IR_PTRATTRS_H +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "llvm/Support/TypeSize.h" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td index 59d2d0b34222..472891dca5cd 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td @@ -66,4 +66,15 @@ def AtomicOrdering : I64EnumAttr< let cppNamespace = "::mlir::ptr"; } +//===----------------------------------------------------------------------===// +// Ptr add flags enum properties. +//===----------------------------------------------------------------------===// + +def Ptr_PtrAddFlags : I32EnumAttr<"PtrAddFlags", "Pointer add flags", [ + I32EnumAttrCase<"none", 0>, I32EnumAttrCase<"nusw", 1>, I32EnumAttrCase<"nuw", 2>, + I32EnumAttrCase<"inbounds", 3> + ]> { + let cppNamespace = "::mlir::ptr"; +} + #endif // PTR_ENUMS \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h index 6a0c1429c6be..8686cc7d316d 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h @@ -18,6 +18,8 @@ #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #define GET_OP_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 313c9f8eb09a..308c711ee135 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -11,7 +11,82 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td" include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" +include "mlir/Dialect/Ptr/IR/PtrEnums.td" include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/OpAsmInterface.td" +//===----------------------------------------------------------------------===// +// PtrAddOp +//===----------------------------------------------------------------------===// + +def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ + Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface + ]> { + let summary = "Pointer add operation"; + let description = [{ + The `ptr_add` operation adds an integer offset to a pointer to produce a new + pointer. The input and output pointer types are always the same. + + Example: + + ```mlir + %x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32 + %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32 + ``` + }]; + + let arguments = (ins + Ptr_PtrType:$base, + AnySignlessIntegerOrIndex:$offset, + DefaultValuedProperty, "PtrAddFlags::none">:$flags); + let results = (outs Ptr_PtrType:$result); + let assemblyFormat = [{ + ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset) + }]; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// `ViewLikeOp::getViewSource` method. + Value getViewSource() { return getBase(); } + }]; +} + +//===----------------------------------------------------------------------===// +// TypeOffsetOp +//===----------------------------------------------------------------------===// + +def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> { + let summary = "Type offset operation"; + let description = [{ + The `type_offset` operation produces an int or index-typed SSA value + equal to a target-specific constant representing the offset of a single + element of the given type. + + Example: + + ```mlir + // Return the offset between two f32 stored in memory + %0 = ptr.type_offset f32 : index + // Return the offset between two memref descriptors stored in memory + %1 = ptr.type_offset memref<12 x f64> : i32 + ``` + }]; + + let arguments = (ins TypeAttr:$elementType); + let results = (outs AnySignlessIntegerOrIndex:$result); + let builders = [ + OpBuilder<(ins "Type":$elementType)> + ]; + let assemblyFormat = [{ + $elementType attr-dict `:` type($result) + }]; + let extraClassDeclaration = [{ + /// Returns the type offset according to `layout`. If `layout` is `nullopt` + /// the nearest layout the op will be used for the computation. + llvm::TypeSize getTypeSize(std::optional layout = std::nullopt); + }]; +} + + #endif // PTR_OPS diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 0babdbbfa05b..e508dab7e82a 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -166,4 +166,28 @@ class EnumProperty : }]; } +/// Class for giving a property a default value. +/// This doesn't change anything about the property other than giving it a default +/// which can be used by ODS to elide printing. +class DefaultValuedProperty : Property { + let defaultValue = default; + let storageTypeValueOverride = storageDefault; + let baseProperty = p; + // Keep up to date with `Property` above. + let summary = p.summary; + let description = p.description; + let storageType = p.storageType; + let interfaceType = p.interfaceType; + let convertFromStorage = p.convertFromStorage; + let assignToStorage = p.assignToStorage; + let convertToAttribute = p.convertToAttribute; + let convertFromAttribute = p.convertFromAttribute; + let hashProperty = p.hashProperty; + let parser = p.parser; + let optionalParser = p.optionalParser; + let printer = p.printer; + let readFromMlirBytecode = p.readFromMlirBytecode; + let writeToMlirBytecode = p.writeToMlirBytecode; +} + #endif // PROPERTIES diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index ec0d5769c776..8c186594b6ad 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -9,10 +9,11 @@ add_mlir_dialect_library( MLIRPtrOpsIncGen MLIRPtrOpsEnumsGen MLIRPtrMemorySpaceInterfacesIncGen - + LINK_LIBS PUBLIC MLIRIR MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces + MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index f8ce820d0bcb..1770e4febf09 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -18,6 +19,51 @@ using namespace mlir::ptr; constexpr const static unsigned kBitsInByte = 8; +//===----------------------------------------------------------------------===// +// GenericSpaceAttr +//===----------------------------------------------------------------------===// + +LogicalResult GenericSpaceAttr::isValidLoad( + Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidStore( + Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAtomicOp( + ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, + IntegerAttr alignment, function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAtomicXchg( + Type type, ptr::AtomicOrdering successOrdering, + ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref emitError) const { + // TODO: update this method once the `addrspace_cast` op is added to the + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return failure(); +} + +LogicalResult GenericSpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref emitError) const { + // TODO: update this method once the int-cast ops are added to the dialect. + assert(false && "unimplemented, see TODO in the source."); + return failure(); +} + //===----------------------------------------------------------------------===// // SpecAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index ff231dae60c2..c21783011452 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -12,7 +12,9 @@ #include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" @@ -39,6 +41,31 @@ void PtrDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// PtrAddOp +//===----------------------------------------------------------------------===// + +/// Fold: ptradd ptr + 0 -> ptr +OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) { + Attribute attr = adaptor.getOffset(); + if (!attr) + return nullptr; + if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero()) + return getBase(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// TypeOffsetOp +//===----------------------------------------------------------------------===// + +llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional layout) { + if (layout) + return layout->getTypeSize(getElementType()); + DataLayout dl = DataLayout::closest(*this); + return dl.getTypeSize(getElementType()); +} + //===----------------------------------------------------------------------===// // Pointer API. //===----------------------------------------------------------------------===// -- Gitee From bc55ba1f14686342742bff8b105fb0bb671d7c7b Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:20:41 +0200 Subject: [PATCH 5/7] [Backport] [mlir] Add property combinators, initial ODS support --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 15 +- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +- mlir/include/mlir/IR/ODSSupport.h | 46 ++ mlir/include/mlir/IR/Properties.td | 543 +++++++++++++++++- mlir/include/mlir/TableGen/Operator.h | 2 +- mlir/include/mlir/TableGen/Property.h | 53 +- mlir/lib/IR/ODSSupport.cpp | 73 ++- mlir/lib/TableGen/Property.cpp | 59 +- .../test/lib/Dialect/Test/TestFormatUtils.cpp | 16 +- mlir/test/lib/Dialect/Test/TestFormatUtils.h | 3 +- mlir/test/lib/Dialect/Test/TestOps.td | 83 ++- mlir/test/lib/Dialect/Test/TestOpsSyntax.td | 22 + mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 202 ++++++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 292 +++++++--- 14 files changed, 1235 insertions(+), 176 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index d2d1fbaf304b..55ec45dd14a9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -59,22 +59,9 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : LLVM_ArithmeticOpBase], traits)> { - dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags); + dag iofArg = (ins EnumProperty<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); let arguments = !con(commonArgs, iofArg); - let builders = [ - OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, - "IntegerOverflowFlags":$overflowFlags), [{ - $_state.getOrAddProperties().overflowFlags = overflowFlags; - build($_builder, $_state, type, lhs, rhs); - }]>, - OpBuilder<(ins "Value":$lhs, "Value":$rhs, - "IntegerOverflowFlags":$overflowFlags), [{ - $_state.getOrAddProperties().overflowFlags = overflowFlags; - build($_builder, $_state, lhs, rhs); - }]> - ]; - string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); moduleImport.setIntegerOverflowFlags(inst, op); diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 308c711ee135..df4784dd94f8 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -40,7 +40,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ let arguments = (ins Ptr_PtrType:$base, AnySignlessIntegerOrIndex:$offset, - DefaultValuedProperty, "PtrAddFlags::none">:$flags); + DefaultValuedProperty, "PtrAddFlags::none">:$flags); let results = (outs Ptr_PtrType:$result); let assemblyFormat = [{ ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset) diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 70e3f986431e..cafb0b58a759 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -33,6 +33,37 @@ convertFromAttribute(int64_t &storage, Attribute attr, /// Convert the provided int64_t to an IntegerAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, int64_t storage); +/// Convert an IntegerAttr attribute to an int32_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(int32_t &storage, Attribute attr, + function_ref emitError); + +/// Convert the provided int32_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, int32_t storage); + +/// Extract the string from `attr` into `storage`. If `attr` is not a +/// `StringAttr`, return failure and emit an error into the diagnostic from +/// `emitError`. +LogicalResult +convertFromAttribute(std::string &storage, Attribute attr, + function_ref emitError); + +/// Convert the given string into a StringAttr. Note that this takes a reference +/// to the storage of a string property, which is an std::string. +Attribute convertToAttribute(MLIRContext *ctx, const std::string &storage); + +/// Extract the boolean from `attr` into `storage`. If `attr` is not a +/// `BoolAttr`, return failure and emit an error into the diagnostic from +/// `emitError`. +LogicalResult +convertFromAttribute(bool &storage, Attribute attr, + function_ref emitError); + +/// Convert the given string into a BooleanAttr. +Attribute convertToAttribute(MLIRContext *ctx, bool storage); + /// Convert a DenseI64ArrayAttr to the provided storage. It is expected that the /// storage has the same size as the array. An error is returned if the /// attribute isn't a DenseI64ArrayAttr or it does not have the same size. If @@ -49,6 +80,21 @@ LogicalResult convertFromAttribute(MutableArrayRef storage, Attribute attr, function_ref emitError); +/// Convert a DenseI64ArrayAttr to the provided storage, which will be +/// cleared before writing. An error is returned and emitted to the optional +/// `emitError` function if the attribute isn't a DenseI64ArrayAttr. +LogicalResult +convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError); + +/// Convert a DenseI32ArrayAttr to the provided storage, which will be +/// cleared before writing. It is expected that the storage has the same size as +/// the array. An error is returned and emitted to the optional `emitError` +/// function if the attribute isn't a DenseI32ArrayAttr. +LogicalResult +convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError); + /// Convert the provided ArrayRef to a DenseI64ArrayAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, ArrayRef storage); diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index e508dab7e82a..f55a5ab96f77 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -29,7 +29,6 @@ class Property { // // Format: // - `$_storage` will contain the property in the storage type. - // - `$_ctxt` will contain an `MLIRContext *`. code convertFromStorage = "$_storage"; // The call expression to build a property storage from the interface type. @@ -40,24 +39,26 @@ class Property { code assignToStorage = "$_storage = $_value"; // The call expression to convert from the storage type to an attribute. + // The resulting attribute must be non-null in non-error cases. // // Format: // - `$_storage` is the storage type value. // - `$_ctxt` is a `MLIRContext *`. // - // The expression must result in an Attribute. + // The expression must return an `Attribute` and will be used as a function body. code convertToAttribute = [{ - convertToAttribute($_ctxt, $_storage) + return convertToAttribute($_ctxt, $_storage); }]; // The call expression to convert from an Attribute to the storage type. // // Format: - // - `$_storage` is the storage type value. + // - `$_storage` is a reference to a value of the storage type. // - `$_attr` is the attribute. // - `$_diag` is a callback to get a Diagnostic to emit error. // - // The expression must return a LogicalResult + // The expression must return a LogicalResult and will be used as a function body + // or in other similar contexts. code convertFromAttribute = [{ return convertFromAttribute($_storage, $_attr, $_diag); }]; @@ -68,18 +69,68 @@ class Property { // - `$_storage` is the variable to hash. // // The expression should define a llvm::hash_code. - code hashProperty = [{ - llvm::hash_value($_storage); + // If unspecified, defaults to `llvm::hash_value($_storage)`. + // The default is not specified in tablegen because many combinators, like + // ArrayProperty, can fall back to more efficient implementations of + // `hashProperty` when their underlying elements have trivial hashing. + code hashProperty = ""; + + // The body of the parser for a value of this property. + // Format: + // - `$_parser` is the OpAsmParser. + // - `$_storage` is the location into which the value is to be placed if it is + // present. + // - `$_ctxt` is a `MLIRContext *` + // + // This defines the body of a function (typically a lambda) that returns a + // ParseResult. There is an implicit `return success()` at the end of the parser + // code. + // + // When this code executes, `$_storage` will be initialized to the property's + // default value (if any, accounting for the storage type override). + code parser = [{ + auto value = ::mlir::FieldParser<}] # storageType # [{>::parse($_parser); + if (::mlir::failed(value)) + return ::mlir::failure(); + $_storage = std::move(*value); }]; + // The body of the parser for a value of this property as the anchor of an optional + // group. This should parse the property if possible and do nothing if a value of + // the relevant type is not next in the parse stream. + // You are not required to define this parser if it cannot be meaningfully + // implemented. + // This has the same context and substitutions as `parser` except that it is + // required to return an OptionalParseResult. + // + // If the optional parser doesn't parse anything, it should not set + // $_storage, since the parser doesn't know if the default value has been + // overwritten. + code optionalParser = ""; + + // The printer for a value of this property. + // Format: + // - `$_storage` is the storage data. + // - `$_printer` is the OpAsmPrinter instance. + // - `$_ctxt` is a `MLIRContext *` + // + // This may be called in an expression context, so variable declarations must + // be placed within a new scope. + // + // The printer for a property should always print a non-empty value - default value + // printing elision happens outside the context of this printing expression. + code printer = "$_printer << $_storage"; + // The call expression to emit the storage type to bytecode. // // Format: // - `$_storage` is the storage type value. // - `$_writer` is a `DialectBytecodeWriter`. // - `$_ctxt` is a `MLIRContext *`. + // + // This will become the body af a function returning void. code writeToMlirBytecode = [{ - writeToMlirBytecode($_writer, $_storage) + writeToMlirBytecode($_writer, $_storage); }]; // The call expression to read the storage type from bytecode. @@ -88,13 +139,31 @@ class Property { // - `$_storage` is the storage type value. // - `$_reader` is a `DialectBytecodeReader`. // - `$_ctxt` is a `MLIRContext *`. + // + // This will become the body of a function returning LogicalResult. + // There is an implicit `return success()` at the end of this function. + // + // When this code executes, `$_storage` will be initialized to the property's + // default value (if any, accounting for the storage type override). code readFromMlirBytecode = [{ if (::mlir::failed(readFromMlirBytecode($_reader, $_storage))) return ::mlir::failure(); }]; - // Default value for the property. - string defaultValue = ?; + // Base definition for the property. (Will be) used for `OptionalProperty` and + // such cases, analogously to `baseAttr`. + Property baseProperty = ?; + + // Default value for the property within its storage. This should be an expression + // of type `interfaceType` and should be comparable with other types of that + // interface typ with `==`. The empty string means there is no default value. + string defaultValue = ""; + + // If set, the default value the storage of the property should be initilized to. + // This is only needed when the storage and interface types of the property + // are distinct (ex. SmallVector for storage vs. ArrayRef for interfacing), as it + // will fall back to `defaultValue` when unspecified. + string storageTypeValueOverride = ""; } /// Implementation of the Property class's `readFromMlirBytecode` field using @@ -133,12 +202,16 @@ defvar writeMlirBytecodeWithConvertToAttribute = [{ // Primitive property kinds // Any kind of integer stored as properties. -class IntProperty : +class IntProperty : Property { - code writeToMlirBytecode = [{ + let summary = !if(!empty(desc), storageTypeParam, desc); + let optionalParser = [{ + return $_parser.parseOptionalInteger($_storage); + }]; + let writeToMlirBytecode = [{ $_writer.writeVarInt($_storage); }]; - code readFromMlirBytecode = [{ + let readFromMlirBytecode = [{ uint64_t val; if (failed($_reader.readVarInt(val))) return ::mlir::failure(); @@ -146,24 +219,259 @@ class IntProperty : }]; } -class ArrayProperty : - Property { - let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; - let convertFromStorage = "$_storage"; - let assignToStorage = "::llvm::copy($_value, $_storage)"; -} +def I32Property : IntProperty<"int32_t">; +def I64Property : IntProperty<"int64_t">; -class EnumProperty : +class EnumProperty : Property { - code writeToMlirBytecode = [{ + // TODO: take advantage of EnumAttrInfo and the like to make this share nice + // parsing code with EnumAttr. + let writeToMlirBytecode = [{ $_writer.writeVarInt(static_cast($_storage)); }]; - code readFromMlirBytecode = [{ + let readFromMlirBytecode = [{ uint64_t val; if (failed($_reader.readVarInt(val))) return ::mlir::failure(); $_storage = static_cast<}] # storageTypeParam # [{>(val); }]; + let defaultValue = default; +} + +def StringProperty : Property<"std::string", "string"> { + let interfaceType = "::llvm::StringRef"; + let convertFromStorage = "::llvm::StringRef{$_storage}"; + let assignToStorage = "$_storage = $_value.str()"; + let optionalParser = [{ + if (::mlir::failed($_parser.parseOptionalString(&$_storage))) + return std::nullopt; + }]; + let printer = "$_printer.printString($_storage)"; + let readFromMlirBytecode = [{ + StringRef val; + if (::mlir::failed($_reader.readString(val))) + return ::mlir::failure(); + $_storage = val.str(); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedString($_storage); + }]; +} + +def BoolProperty : IntProperty<"bool", "boolean"> { + let printer = [{ $_printer << ($_storage ? "true" : "false") }]; + let readFromMlirBytecode = [{ + return $_reader.readBool($_storage); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage); + }]; +} + +def UnitProperty : Property<"bool", "unit property"> { + let summary = "unit property"; + let description = [{ + A property whose presence or abscence is used as a flag. + + This is stored as a boolean that defaults to false, and is named UnitProperty + by analogy with UnitAttr, which has the more comprehensive rationale and + explains the less typical syntax. + + Note that this attribute does have a syntax for the false case to allow for its + use in contexts where default values shouldn't be elided. + }]; + let defaultValue = "false"; + + let convertToAttribute = [{ + if ($_storage) + return ::mlir::UnitAttr::get($_ctxt); + else + return ::mlir::BoolAttr::get($_ctxt, false); + }]; + let convertFromAttribute = [{ + if (::llvm::isa<::mlir::UnitAttr>($_attr)) { + $_storage = true; + return ::mlir::success(); + } + if (auto boolAttr = ::llvm::dyn_cast<::mlir::BoolAttr>($_attr)) { + $_storage = boolAttr.getValue(); + return ::mlir::success(); + } + return ::mlir::failure(); + }]; + + let parser = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, + {"unit", "unit_absent"}))) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected 'unit' or 'unit_absent'"); + $_storage = (keyword == "unit"); + }]; + + let optionalParser = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, + {"unit", "unit_absent"}))) + return std::nullopt; + $_storage = (keyword == "unit"); + }]; + + let printer = [{ + $_printer << ($_storage ? "unit" : "unit_absent") + }]; + + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage); + }]; + let readFromMlirBytecode = [{ + if (::mlir::failed($_reader.readBool($_storage))) + return ::mlir::failure(); + }]; +} + +//===----------------------------------------------------------------------===// +// Primitive property combinators + +/// Create a variable named `name` of `prop`'s storage type that is initialized +/// to the correct default value, if there is one. +class _makePropStorage { + code ret = prop.storageType # " " # name + # !cond(!not(!empty(prop.storageTypeValueOverride)) : " = " # prop.storageTypeValueOverride, + !not(!empty(prop.defaultValue)) : " = " # prop.defaultValue, + true : "") # ";"; +} + +/// The generic class for arrays of some other property, which is stored as a +/// `SmallVector` of that property. This uses an `ArrayAttr` as its attribute form +/// though subclasses can override this, as is the case with IntArrayAttr below. +/// Those wishing to use a non-default number of SmallVector elements should +/// subclass `ArrayProperty`. +class ArrayProperty, string desc = ""> : + Property<"::llvm::SmallVector<" # elem.storageType # ">", desc> { + let summary = "array of " # elem.summary; + let interfaceType = "::llvm::ArrayRef<" # elem.storageType # ">"; + let convertFromStorage = "::llvm::ArrayRef<" # elem.storageType # ">{$_storage}"; + let assignToStorage = "$_storage.assign($_value.begin(), $_value.end())"; + + let convertFromAttribute = [{ + auto arrayAttr = ::llvm::dyn_cast_if_present<::mlir::ArrayAttr>($_attr); + if (!arrayAttr) + return $_diag() << "expected array attribute"; + for (::mlir::Attribute elemAttr : arrayAttr) { + }] # _makePropStorage.ret # [{ + auto elemRes = [&](Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_attr", "propAttr", + !subst("$_storage", "propStorage", elem.convertFromAttribute)) # [{ + }(elemAttr, elemVal); + if (::mlir::failed(elemRes)) + return ::mlir::failure(); + $_storage.push_back(std::move(elemVal)); + } + return ::mlir::success(); + }]; + + let convertToAttribute = [{ + SmallVector elems; + for (const auto& elemVal : $_storage) { + auto elemAttr = [&](const }] # elem.storageType #[{& propStorage) -> ::mlir::Attribute { + }] # !subst("$_storage", "propStorage", elem.convertToAttribute) # [{ + }(elemVal); + elems.push_back(elemAttr); + } + return ::mlir::ArrayAttr::get($_ctxt, elems); + }]; + + defvar theParserBegin = [{ + auto& storage = $_storage; + auto parseElemFn = [&]() -> ::mlir::ParseResult { + }] # _makePropStorage.ret # [{ + auto elemParse = [&](}] # elem.storageType # [{& propStorage) -> ::mlir::ParseResult { + }] # !subst("$_storage", "propStorage", elem.parser) # [{ + return ::mlir::success(); + }(elemVal); + if (::mlir::failed(elemParse)) + return ::mlir::failure(); + storage.push_back(std::move(elemVal)); + return ::mlir::success(); + }; + }]; + let parser = theParserBegin # [{ + return $_parser.parseCommaSeparatedList( + ::mlir::OpAsmParser::Delimiter::Square, parseElemFn); + }]; + // Hack around the lack of a peek method + let optionalParser = theParserBegin # [{ + auto oldLoc = $_parser.getCurrentLocation(); + auto parseResult = $_parser.parseCommaSeparatedList( + ::mlir::OpAsmParser::Delimiter::OptionalSquare, parseElemFn); + if (::mlir::failed(parseResult)) + return ::mlir::failure(); + auto newLoc = $_parser.getCurrentLocation(); + if (oldLoc == newLoc) + return std::nullopt; + return ::mlir::success(); + }]; + + let printer = [{ [&](){ + $_printer << "["; + auto elemPrinter = [&](const }] # elem.storageType # [{& elemVal) { + }] # !subst("$_storage", "elemVal", elem.printer) #[{; + }; + ::llvm::interleaveComma($_storage, $_printer, elemPrinter); + $_printer << "]"; + }()}]; + + let readFromMlirBytecode = [{ + uint64_t length; + if (::mlir::failed($_reader.readVarInt(length))) + return ::mlir::failure(); + $_storage.reserve(length); + for (uint64_t i = 0; i < length; ++i) { + }]# _makePropStorage.ret # [{ + auto elemRead = [&](}] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", elem.readFromMlirBytecode) # [{; + return ::mlir::success(); + }(elemVal); + if (::mlir::failed(elemRead)) + return ::mlir::failure(); + $_storage.push_back(std::move(elemVal)); + } + }]; + + let writeToMlirBytecode = [{ + $_writer.writeVarInt($_storage.size()); + for (const auto& elemVal : $_storage) { + [&]() { + }] # !subst("$_storage", "elemVal", elem.writeToMlirBytecode) #[{; + }(); + } + }]; + + // There's no hash_value for SmallVector, so we construct the ArrayRef ourselves. + // In the non-trivial case, we define a mapped range to get internal hash + // codes. + let hashProperty = !if(!empty(elem.hashProperty), + [{::llvm::hash_value(::llvm::ArrayRef<}] # elem.storageType # [{>{$_storage})}], + [{[&]() -> ::llvm::hash_code { + auto getElemHash = [](const auto& propStorage) -> ::llvm::hash_code { + return }] # !subst("$_storage", "propStorage", elem.hashProperty) # [{; + }; + auto mapped = ::llvm::map_range($_storage, getElemHash); + return ::llvm::hash_combine_range(mapped.begin(), mapped.end()); + }() + }]); +} + +class IntArrayProperty : + ArrayProperty> { + // Bring back the trivial conversions we don't get in the general case. + let convertFromAttribute = [{ + return convertFromAttribute($_storage, $_attr, $_diag); + }]; + let convertToAttribute = [{ + return convertToAttribute($_ctxt, $_storage); + }]; } /// Class for giving a property a default value. @@ -190,4 +498,193 @@ class DefaultValuedProperty +/// interfaced with as an std::optional.. +/// The syntax is `none` (or empty string if elided) for an absent value or +/// `some<[underlying property]>` when a value is set. +/// +/// As a special exception, if the underlying property has an optional parser and +/// no default value (ex. an integer property), the printer will skip the `some` +/// bracketing and delegate to the optional parser. In that case, the syntax is the +/// syntax of the underlying property, or the keyword `none` in the rare cases that +/// it is needed. This behavior can be disabled by setting `canDelegateParsing` to 0. +class OptionalProperty + : Property<"std::optional<" # p.storageType # ">", "optional " # p.summary> { + + // In the cases where the underlying attribute is plain old data that's passed by + // value, the conversion code is trivial. + defvar hasTrivialStorage = !and(!eq(p.convertFromStorage, "$_storage"), + !eq(p.assignToStorage, "$_storage = $_value"), + !eq(p.storageType, p.interfaceType)); + + defvar delegatesParsing = !and(!empty(p.defaultValue), + !not(!empty(p.optionalParser)), canDelegateParsing); + + let interfaceType = "std::optional<" # p.interfaceType # ">"; + let defaultValue = "std::nullopt"; + + let convertFromStorage = !if(hasTrivialStorage, + p.convertFromStorage, + [{($_storage.has_value() ? std::optional<}] # p.interfaceType # ">{" + # !subst("$_storage", "(*($_storage))", p.convertFromStorage) + # [{} : std::nullopt)}]); + let assignToStorage = !if(hasTrivialStorage, + p.assignToStorage, + [{[&]() { + if (!$_value.has_value()) { + $_storage = std::nullopt; + return; + } + }] # _makePropStorage.ret # [{ + [&](}] # p.storageType # [{& propStorage) { + }] # !subst("$_storage", "propStorage", + !subst("$_value", "(*($_value))", p.assignToStorage)) # [{; + }(presentVal); + $_storage = std::move(presentVal); + }()}]); + + let convertFromAttribute = [{ + auto arrayAttr = ::llvm::dyn_cast<::mlir::ArrayAttr>($_attr); + if (!arrayAttr) + return $_diag() << "expected optional properties to materialize as arrays"; + if (arrayAttr.size() > 1) + return $_diag() << "expected optional properties to become 0- or 1-element arrays"; + if (arrayAttr.empty()) { + $_storage = std::nullopt; + return ::mlir::success(); + } + ::mlir::Attribute presentAttr = arrayAttr[0]; + }] # _makePropStorage.ret # [{ + auto presentRes = [&](Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", + !subst("$_attr", "propAttr", p.convertFromAttribute)) # [{ + }(presentAttr, presentVal); + if (::mlir::failed(presentRes)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + return ::mlir::success(); + }]; + + let convertToAttribute = [{ + if (!$_storage.has_value()) { + return ::mlir::ArrayAttr::get($_ctxt, {}); + } + auto attr = [&]() -> ::mlir::Attribute { + }] # !subst("$_storage", "(*($_storage))", p.convertToAttribute) # [{ + }(); + return ::mlir::ArrayAttr::get($_ctxt, {attr}); + }]; + + defvar delegatedParserBegin = [{ + if (::mlir::succeeded($_parser.parseOptionalKeyword("none"))) { + $_storage = std::nullopt; + return ::mlir::success(); + } + }] #_makePropStorage.ret # [{ + auto delegParseResult = [&](}] # p.storageType # [{& propStorage) -> ::mlir::OptionalParseResult { + }] # !subst("$_storage", "propStorage", p.optionalParser) # [{ + return ::mlir::success(); + }(presentVal); + if (!delegParseResult.has_value()) { + }]; + + defvar delegatedParserEnd = [{ + } + if (delegParseResult.has_value() && ::mlir::failed(*delegParseResult)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + return ::mlir::success(); + }]; + // If we're being explicitly called for our parser, we're expecting to have been + // printede into a context where the default value isn't elided. Therefore, + // not-present from the underlying parser is a failure. + defvar delegatedParser = delegatedParserBegin # [{ + return ::mlir::failure(); + }] # delegatedParserEnd; + defvar delegatedOptionalParser = delegatedParserBegin # [{ + return std::nullopt; + }] # delegatedParserEnd; + + defvar generalParserBegin = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, {"none", "some"}))) { + }]; + defvar generalParserEnd = [{ + } + if (keyword == "none") { + $_storage = std::nullopt; + return ::mlir::success(); + } + if (::mlir::failed($_parser.parseLess())) + return ::mlir::failure(); + }] # _makePropStorage.ret # [{ + auto presentParse = [&](}] # p.storageType # [{& propStorage) -> ::mlir::ParseResult { + }] # !subst("$_storage", "propStorage", p.parser) # [{ + return ::mlir::success(); + }(presentVal); + if (presentParse || $_parser.parseGreater()) + return ::mlir::failure(); + $_storage = std::move(presentVal); + }]; + defvar generalParser = generalParserBegin # [{ + return $_parser.emitError($_parser.getCurrentLocation(), "expected 'none' or 'some'"); + }] # generalParserEnd; + defvar generalOptionalParser = generalParserBegin # [{ + return std::nullopt; + }] # generalParserEnd; + + let parser = !if(delegatesParsing, delegatedParser, generalParser); + let optionalParser = !if(delegatesParsing, + delegatedOptionalParser, generalOptionalParser); + + defvar delegatedPrinter = [{ + [&]() { + if (!$_storage.has_value()) { + $_printer << "none"; + return; + } + }] # !subst("$_storage", "(*($_storage))", p.printer) # [{; + }()}]; + defvar generalPrinter = [{ + [&]() { + if (!$_storage.has_value()) { + $_printer << "none"; + return; + } + $_printer << "some<"; + }] # !subst("$_storage", "(*($_storage))", p.printer) # [{; + $_printer << ">"; + }()}]; + let printer = !if(delegatesParsing, delegatedPrinter, generalPrinter); + + let readFromMlirBytecode = [{ + bool isPresent = false; + if (::mlir::failed($_reader.readBool(isPresent))) + return ::mlir::failure(); + if (!isPresent) { + $_storage = std::nullopt; + return ::mlir::success(); + } + }] # _makePropStorage.ret # [{ + auto presentResult = [&](}] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", p.readFromMlirBytecode) # [{; + return ::mlir::success(); + }(presentVal); + if (::mlir::failed(presentResult)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage.has_value()); + if (!$_storage.has_value()) + return; + }] # !subst("$_storage", "(*($_storage))", p.writeToMlirBytecode); + + let hashProperty = !if(!empty(p.hashProperty), p.hashProperty, + [{ ::llvm::hash_value($_storage.has_value() ? std::optional<::llvm::hash_code>{}] # + !subst("$_storage", "(*($_storage))", p.hashProperty) #[{} : std::nullopt) }]); + assert !or(!not(delegatesParsing), !eq(defaultValue, "std::nullopt")), + "For delegated parsing to be used, the default value must be nullopt. " # + "To use a non-trivial default, set the canDelegateParsing argument to 0"; +} +#endif // PROPERTIES \ No newline at end of file diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index cc5853c044e9..768291a3a726 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -384,7 +384,7 @@ private: SmallVector attributes; /// The properties of the op. - SmallVector properties; + SmallVector properties; /// The arguments of the op (operands and native attributes). SmallVector arguments; diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h index d0d6f4940c7c..702e6756e6a9 100644 --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -35,12 +35,20 @@ class Property { public: explicit Property(const llvm::Record *record); explicit Property(const llvm::DefInit *init); - Property(StringRef storageType, StringRef interfaceType, - StringRef convertFromStorageCall, StringRef assignToStorageCall, - StringRef convertToAttributeCall, StringRef convertFromAttributeCall, + Property(StringRef summary, StringRef description, StringRef storageType, + StringRef interfaceType, StringRef convertFromStorageCall, + StringRef assignToStorageCall, StringRef convertToAttributeCall, + StringRef convertFromAttributeCall, StringRef parserCall, + StringRef optionalParserCall, StringRef printerCall, StringRef readFromMlirBytecodeCall, StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall, - StringRef defaultValue); + StringRef defaultValue, StringRef storageTypeValueOverride); + + // Returns the summary (for error messages) of this property's type. + StringRef getSummary() const { return summary; } + + // Returns the description of this property. + StringRef getDescription() const { return description; } // Returns the storage type. StringRef getStorageType() const { return storageType; } @@ -66,6 +74,19 @@ public: return convertFromAttributeCall; } + // Returns the method call which parses this property from textual MLIR. + StringRef getParserCall() const { return parserCall; } + + // Returns true if this property has defined an optional parser. + bool hasOptionalParser() const { return !optionalParserCall.empty(); } + + // Returns the method call which optionally parses this property from textual + // MLIR. + StringRef getOptionalParserCall() const { return optionalParserCall; } + + // Returns the method call which prints this property to textual MLIR. + StringRef getPrinterCall() const { return printerCall; } + // Returns the method call which reads this property from // bytecode and assign it to the storage. StringRef getReadFromMlirBytecodeCall() const { @@ -87,6 +108,24 @@ public: // Returns the default value for this Property. StringRef getDefaultValue() const { return defaultValue; } + // Returns whether this Property has a default storage-type value that is + // distinct from its default interface-type value. + bool hasStorageTypeValueOverride() const { + return !storageTypeValueOverride.empty(); + } + + StringRef getStorageTypeValueOverride() const { + return storageTypeValueOverride; + } + + // Returns this property's TableGen def-name. + StringRef getPropertyDefName() const; + + // Returns the base-level property that this Property constraint is based on + // or the Property itself otherwise. (Note: there are currently no + // property constraints, this function is added for future-proofing) + Property getBaseProperty() const; + // Returns the TableGen definition this Property was constructed from. const llvm::Record &getDef() const { return *def; } @@ -95,16 +134,22 @@ private: const llvm::Record *def; // Elements describing a Property, in general fetched from the record. + StringRef summary; + StringRef description; StringRef storageType; StringRef interfaceType; StringRef convertFromStorageCall; StringRef assignToStorageCall; StringRef convertToAttributeCall; StringRef convertFromAttributeCall; + StringRef parserCall; + StringRef optionalParserCall; + StringRef printerCall; StringRef readFromMlirBytecodeCall; StringRef writeToMlirBytecodeCall; StringRef hashPropertyCall; StringRef defaultValue; + StringRef storageTypeValueOverride; }; // A struct wrapping an op property and its name together diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index 6e968d62e61c..a55cdd834b39 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -33,6 +33,50 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 64), storage); } +LogicalResult +mlir::convertFromAttribute(int32_t &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getSExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, int32_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 32), storage); +} + +LogicalResult +mlir::convertFromAttribute(std::string &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) + return emitError() + << "expected string property to come from string attribute"; + storage = valueAttr.getValue().str(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, + const std::string &storage) { + return StringAttr::get(ctx, storage); +} + +LogicalResult +mlir::convertFromAttribute(bool &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) + return emitError() + << "expected string property to come from string attribute"; + storage = valueAttr.getValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, bool storage) { + return BoolAttr::get(ctx, storage); +} + template LogicalResult convertDenseArrayFromAttr(MutableArrayRef storage, Attribute attr, @@ -64,7 +108,34 @@ mlir::convertFromAttribute(MutableArrayRef storage, Attribute attr, "DenseI32ArrayAttr"); } +template +LogicalResult +convertDenseArrayFromAttr(SmallVectorImpl &storage, Attribute attr, + function_ref emitError, + StringRef denseArrayTyStr) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected " << denseArrayTyStr << " for key `value`"; + return failure(); + } + storage.resize_for_overwrite(valueAttr.size()); + llvm::copy(valueAttr.asArrayRef(), storage.begin()); + return success(); +} +LogicalResult +mlir::convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError) { + return convertDenseArrayFromAttr(storage, attr, emitError, + "DenseI64ArrayAttr"); +} +LogicalResult +mlir::convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError) { + return convertDenseArrayFromAttr(storage, attr, emitError, + "DenseI32ArrayAttr"); +} + Attribute mlir::convertToAttribute(MLIRContext *ctx, ArrayRef storage) { return DenseI64ArrayAttr::get(ctx, storage); -} +} \ No newline at end of file diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp index e61d2fd2480f..9f4b9ce1a294 100644 --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -33,16 +33,23 @@ static StringRef getValueAsString(const Init *init) { } Property::Property(const Record *def) - : Property(getValueAsString(def->getValueInit("storageType")), - getValueAsString(def->getValueInit("interfaceType")), - getValueAsString(def->getValueInit("convertFromStorage")), - getValueAsString(def->getValueInit("assignToStorage")), - getValueAsString(def->getValueInit("convertToAttribute")), - getValueAsString(def->getValueInit("convertFromAttribute")), - getValueAsString(def->getValueInit("readFromMlirBytecode")), - getValueAsString(def->getValueInit("writeToMlirBytecode")), - getValueAsString(def->getValueInit("hashProperty")), - getValueAsString(def->getValueInit("defaultValue"))) { + : Property( + getValueAsString(def->getValueInit("summary")), + getValueAsString(def->getValueInit("description")), + getValueAsString(def->getValueInit("storageType")), + getValueAsString(def->getValueInit("interfaceType")), + getValueAsString(def->getValueInit("convertFromStorage")), + getValueAsString(def->getValueInit("assignToStorage")), + getValueAsString(def->getValueInit("convertToAttribute")), + getValueAsString(def->getValueInit("convertFromAttribute")), + getValueAsString(def->getValueInit("parser")), + getValueAsString(def->getValueInit("optionalParser")), + getValueAsString(def->getValueInit("printer")), + getValueAsString(def->getValueInit("readFromMlirBytecode")), + getValueAsString(def->getValueInit("writeToMlirBytecode")), + getValueAsString(def->getValueInit("hashProperty")), + getValueAsString(def->getValueInit("defaultValue")), + getValueAsString(def->getValueInit("storageTypeValueOverride"))) { this->def = def; assert((def->isSubClassOf("Property") || def->isSubClassOf("Attr")) && "must be subclass of TableGen 'Property' class"); @@ -50,22 +57,44 @@ Property::Property(const Record *def) Property::Property(const DefInit *init) : Property(init->getDef()) {} -Property::Property(StringRef storageType, StringRef interfaceType, +Property::Property(StringRef summary, StringRef description, + StringRef storageType, StringRef interfaceType, StringRef convertFromStorageCall, StringRef assignToStorageCall, StringRef convertToAttributeCall, - StringRef convertFromAttributeCall, + StringRef convertFromAttributeCall, StringRef parserCall, + StringRef optionalParserCall, StringRef printerCall, StringRef readFromMlirBytecodeCall, StringRef writeToMlirBytecodeCall, - StringRef hashPropertyCall, StringRef defaultValue) - : storageType(storageType), interfaceType(interfaceType), + StringRef hashPropertyCall, StringRef defaultValue, + StringRef storageTypeValueOverride) + : summary(summary), description(description), storageType(storageType), + interfaceType(interfaceType), convertFromStorageCall(convertFromStorageCall), assignToStorageCall(assignToStorageCall), convertToAttributeCall(convertToAttributeCall), convertFromAttributeCall(convertFromAttributeCall), + parserCall(parserCall), optionalParserCall(optionalParserCall), + printerCall(printerCall), readFromMlirBytecodeCall(readFromMlirBytecodeCall), writeToMlirBytecodeCall(writeToMlirBytecodeCall), - hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) { + hashPropertyCall(hashPropertyCall), defaultValue(defaultValue), + storageTypeValueOverride(storageTypeValueOverride) { if (storageType.empty()) storageType = "Property"; } + +StringRef Property::getPropertyDefName() const { + if (def->isAnonymous()) { + return getBaseProperty().def->getName(); + } + return def->getName(); +} + +Property Property::getBaseProperty() const { + if (const auto *defInit = + llvm::dyn_cast(def->getValueInit("baseProperty"))) { + return Property(defInit).getBaseProperty(); + } + return *this; +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp index 6e75dd393228..9ed1b3a47be3 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp @@ -297,11 +297,17 @@ void test::printSwitchCases(OpAsmPrinter &p, Operation *op, // CustomUsingPropertyInCustom //===----------------------------------------------------------------------===// -bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { - return parser.parseLSquare() || parser.parseInteger(value[0]) || - parser.parseComma() || parser.parseInteger(value[1]) || - parser.parseComma() || parser.parseInteger(value[2]) || - parser.parseRSquare(); +bool test::parseUsingPropertyInCustom(OpAsmParser &parser, + SmallVector &value) { + auto elemParser = [&]() { + int64_t v = 0; + if (failed(parser.parseInteger(v))) + return failure(); + value.push_back(v); + return success(); + }; + return failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, + elemParser)); } void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h index 7e9cd834278e..6d4df7d82ffa 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.h +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h @@ -160,7 +160,8 @@ void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op, // CustomUsingPropertyInCustom //===----------------------------------------------------------------------===// -bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]); +bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, + llvm::SmallVector &value); void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer, mlir::Operation *op, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 9450764fcb1d..70579cf5f3e1 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2947,11 +2947,18 @@ def TestVersionedOpC : TEST_Op<"versionedC"> { // Op with a properties struct defined inline. def TestOpWithProperties : TEST_Op<"with_properties"> { - let assemblyFormat = "prop-dict attr-dict"; + let assemblyFormat = [{ + `a` `=` $a `,` + `b` `=` $b `,` + `c` `=` $c `,` + `flag` `=` $flag `,` + `array` `=` $array attr-dict}]; let arguments = (ins - IntProperty<"int64_t">:$a, + I64Property:$a, StrAttr:$b, // Attributes can directly be used here. - ArrayProperty<"int64_t", 4>:$array // example of an array + StringProperty:$c, + BoolProperty:$flag, + IntArrayProperty<"int64_t">:$array // example of an array ); } @@ -2974,7 +2981,7 @@ def TestOpWithPropertiesAndInferredType // Demonstrate how to wrap an existing C++ class named MyPropStruct. def MyStructProperty : Property<"MyPropStruct"> { - let convertToAttribute = "$_storage.asAttribute($_ctxt)"; + let convertToAttribute = "return $_storage.asAttribute($_ctxt);"; let convertFromAttribute = "return MyPropStruct::setFromAttr($_storage, $_attr, $_diag);"; let hashProperty = "$_storage.hash();"; } @@ -2988,14 +2995,14 @@ def TestOpWithWrappedProperties : TEST_Op<"with_wrapped_properties"> { def TestOpUsingPropertyInCustom : TEST_Op<"using_property_in_custom"> { let assemblyFormat = "custom($prop) attr-dict"; - let arguments = (ins ArrayProperty<"int64_t", 3>:$prop); + let arguments = (ins IntArrayProperty<"int64_t">:$prop); } def TestOpUsingPropertyInCustomAndOther : TEST_Op<"using_property_in_custom_and_other"> { let assemblyFormat = "custom($prop) prop-dict attr-dict"; let arguments = (ins - ArrayProperty<"int64_t", 3>:$prop, + IntArrayProperty<"int64_t">:$prop, IntProperty<"int64_t">:$other ); } @@ -3021,7 +3028,7 @@ def TestOpUsingIntPropertyWithWorseBytecode def PropertiesWithCustomPrint : Property<"PropertiesWithCustomPrint"> { let convertToAttribute = [{ - getPropertiesAsAttribute($_ctxt, $_storage) + return getPropertiesAsAttribute($_ctxt, $_storage); }]; let convertFromAttribute = [{ return setPropertiesFromAttribute($_storage, $_attr, $_diag); @@ -3085,7 +3092,7 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> { def VersionedProperties : Property<"VersionedProperties"> { let convertToAttribute = [{ - getPropertiesAsAttribute($_ctxt, $_storage) + return getPropertiesAsAttribute($_ctxt, $_storage); }]; let convertFromAttribute = [{ return setPropertiesFromAttribute($_storage, $_attr, $_diag); @@ -3131,13 +3138,65 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> { } def TestOpWithDefaultValuedProperties : TEST_Op<"with_default_valued_properties"> { - let assemblyFormat = "prop-dict attr-dict"; - let arguments = (ins DefaultValuedAttr:$a); + let assemblyFormat = [{ + ($a^) : (`na`)? + ($b^)? + ($c^)? + ($unit^)? + attr-dict + }]; + let arguments = (ins DefaultValuedAttr:$a, + DefaultValuedProperty:$b, + DefaultValuedProperty, "-1">:$c, + UnitProperty:$unit); } def TestOpWithOptionalProperties : TEST_Op<"with_optional_properties"> { - let assemblyFormat = "prop-dict attr-dict"; - let arguments = (ins OptionalAttr:$a, OptionalAttr:$b); +let assemblyFormat = [{ + (`anAttr` `=` $anAttr^)? + (`simple` `=` $simple^)? + (`nonTrivialStorage` `=` $nonTrivialStorage^)? + (`hasDefault` `=` $hasDefault^)? + (`nested` `=` $nested^)? + (`longSyntax` `=` $longSyntax^)? + (`hasUnit` $hasUnit^)? + (`maybeUnit` `=` $maybeUnit^)? + attr-dict + }]; + let arguments = (ins + OptionalAttr:$anAttr, + OptionalProperty:$simple, + OptionalProperty:$nonTrivialStorage, + // Confirm that properties with default values now default to nullopt and have + // the long syntax. + OptionalProperty>:$hasDefault, + OptionalProperty>:$nested, + OptionalProperty:$longSyntax, + UnitProperty:$hasUnit, + OptionalProperty:$maybeUnit); +} + +def TestOpWithArrayProperties : TEST_Op<"with_array_properties"> { + let assemblyFormat = [{ + `ints` `=` $ints + `strings` `=` $strings + `nested` `=` $nested + `opt` `=` $opt + `explicitOptions` `=` $explicitOptions + `explicitUnits` `=` $explicitUnits + ($hasDefault^ `thats_has_default`)? + attr-dict + }]; + let arguments = (ins + ArrayProperty:$ints, + ArrayProperty:$strings, + ArrayProperty>:$nested, + OptionalProperty>:$opt, + ArrayProperty>:$explicitOptions, + ArrayProperty:$explicitUnits, + DefaultValuedProperty, + "::llvm::ArrayRef{}", "::llvm::SmallVector{}">:$hasDefault + ); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td index 3129085058fd..795b9da95563 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td @@ -86,6 +86,17 @@ def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { }]; } +// Ops related to OIList primitive +def OIListTrivialProperties : TEST_Op<"oilist_with_keywords_only_properties"> { + let arguments = (ins UnitProperty:$keyword, UnitProperty:$otherKeyword, + UnitProperty:$diffNameUnitPropertyKeyword); + let assemblyFormat = [{ + oilist( `keyword` $keyword + | `otherKeyword` $otherKeyword + | `thirdKeyword` $diffNameUnitPropertyKeyword) attr-dict + }]; +} + def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { let arguments = (ins Optional:$arg0, Optional:$arg1, @@ -392,6 +403,17 @@ def FormatOptionalUnitAttrNoElide let assemblyFormat = "($is_optional^)? attr-dict"; } +def FormatOptionalUnitProperty : TEST_Op<"format_optional_unit_property"> { + let arguments = (ins UnitProperty:$is_optional); + let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict"; +} + +def FormatOptionalUnitPropertyNoElide + : TEST_Op<"format_optional_unit_property_no_elide"> { + let arguments = (ins UnitProperty:$is_optional); + let assemblyFormat = "($is_optional^)? attr-dict"; +} + def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> { let arguments = (ins OptionalAttr:$attr); let assemblyFormat = "($attr^)? attr-dict"; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 0fc750c7bbc8..9a1729892058 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -155,6 +155,36 @@ static const char *const valueRangeReturnCode = R"( std::next({0}, valueRange.first + valueRange.second)}; )"; +/// Parse operand/result segment_size property. +/// {0}: Number of elements in the segment array +static const char *const parseTextualSegmentSizeFormat = R"( + size_t i = 0; + auto parseElem = [&]() -> ::mlir::ParseResult { + if (i >= {0}) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected `]` after {0} segment sizes"); + if (failed($_parser.parseInteger($_storage[i]))) + return ::mlir::failure(); + i += 1; + return ::mlir::success(); + }; + if (failed($_parser.parseCommaSeparatedList( + ::mlir::AsmParser::Delimeter::Square, parseElem))) + return failure(); + if (i < {0}) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected {0} segment sizes, found only ") << i; + return success(); +)"; + +static const char *const printTextualSegmentSize = R"( + [&]() { + $_printer << '['; + ::llvm::interleaveComma($_storage, $_printer); + $_printer << ']'; + }() +)"; + /// Read operand/result segment_size from bytecode. static const char *const readBytecodeSegmentSizeNative = R"( if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) @@ -422,8 +452,10 @@ private: // Property std::optional operandSegmentsSize; std::string operandSegmentsSizeStorage; + std::string operandSegmentsSizeParser; std::optional resultSegmentsSize; std::string resultSegmentsSizeStorage; + std::string resultSegmentsSizeParser; // Indices to store the position in the emission order of the operand/result // segment sizes attribute if emitted as part of the properties for legacy @@ -448,31 +480,40 @@ void OpOrAdaptorHelper::computeAttrMetadata() { {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } - auto makeProperty = [&](StringRef storageType) { + auto makeProperty = [&](StringRef storageType, StringRef parserCall) { return Property( + /*summary=*/"", + /*description=*/"", /*storageType=*/storageType, /*interfaceType=*/"::llvm::ArrayRef", /*convertFromStorageCall=*/"$_storage", /*assignToStorageCall=*/ "::llvm::copy($_value, $_storage.begin())", /*convertToAttributeCall=*/ - "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)", + "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);", /*convertFromAttributeCall=*/ "return convertFromAttribute($_storage, $_attr, $_diag);", + /*parserCall=*/parserCall, + /*optionalParserCall=*/"", + /*printerCall=*/printTextualSegmentSize, /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative, /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative, /*hashPropertyCall=*/ "::llvm::hash_combine_range(std::begin($_storage), " "std::end($_storage));", - /*StringRef defaultValue=*/""); + /*StringRef defaultValue=*/"", + /*storageTypeValueOverride=*/""); }; // Include key attributes from several traits as implicitly registered. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { if (op.getDialect().usePropertiesForAttributes()) { operandSegmentsSizeStorage = llvm::formatv("std::array", op.getNumOperands()); - operandSegmentsSize = {"operandSegmentSizes", - makeProperty(operandSegmentsSizeStorage)}; + operandSegmentsSizeParser = + llvm::formatv(parseTextualSegmentSizeFormat, op.getNumOperands()); + operandSegmentsSize = { + "operandSegmentSizes", + makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)}; } else { attrMetadata.insert( {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, @@ -484,8 +525,11 @@ void OpOrAdaptorHelper::computeAttrMetadata() { if (op.getDialect().usePropertiesForAttributes()) { resultSegmentsSizeStorage = llvm::formatv("std::array", op.getNumResults()); - resultSegmentsSize = {"resultSegmentSizes", - makeProperty(resultSegmentsSizeStorage)}; + resultSegmentsSizeParser = + llvm::formatv(parseTextualSegmentSizeFormat, op.getNumResults()); + resultSegmentsSize = { + "resultSegmentSizes", + makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)}; } else { attrMetadata.insert( {resultSegmentAttrName, @@ -572,6 +616,12 @@ private: void genPropertiesSupportForBytecode(ArrayRef attrOrProperties); + // Generates getters for the properties. + void genPropGetters(); + + // Generates seters for the properties. + void genPropSetters(); + // Generates getters for the attributes. void genAttrGetters(); @@ -1041,6 +1091,8 @@ OpEmitter::OpEmitter(const Operator &op, genNamedRegionGetters(); genNamedSuccessorGetters(); genPropertiesSupport(); + genPropGetters(); + genPropSetters(); genAttrGetters(); genAttrSetters(); genOptionalAttrRemovers(); @@ -1198,6 +1250,16 @@ void OpEmitter::genAttrNameGetters() { } } +// Emit the getter for a named property. +// It is templated to be shared between the Op and the adaptor class. +template +static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op, + StringRef name, const Property &prop) { + auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name); + ERROR_IF_PRUNED(method, name, op); + method->body() << formatv(" return getProperties().{0}();", name); +} + // Emit the getter for an attribute with the return type specified. // It is templated to be shared between the Op and the adaptor class. template @@ -1313,7 +1375,7 @@ void OpEmitter::genPropertiesSupport() { )decl"; const char *propFromAttrFmt = R"decl( auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{ + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ {0} }; {2}; @@ -1358,7 +1420,10 @@ void OpEmitter::genPropertiesSupport() { .addSubst("_storage", propertyStorage) .addSubst("_diag", propertyDiag)), name, getAttr); - if (prop.hasDefaultValue()) { + if (prop.hasStorageTypeValueOverride()) { + setPropMethod << formatv(attrGetDefaultFmt, name, + prop.getStorageTypeValueOverride()); + } else if (prop.hasDefaultValue()) { setPropMethod << formatv(attrGetDefaultFmt, name, prop.getDefaultValue()); } else { @@ -1409,8 +1474,10 @@ void OpEmitter::genPropertiesSupport() { const char *propToAttrFmt = R"decl( { const auto &propStorage = prop.{0}; - attrs.push_back(odsBuilder.getNamedAttr("{0}", - {1})); + auto attr = [&]() -> ::mlir::Attribute {{ + {1} + }(); + attrs.push_back(odsBuilder.getNamedAttr("{0}", attr)); } )decl"; for (const auto &attrOrProp : attrOrProperties) { @@ -1458,9 +1525,12 @@ void OpEmitter::genPropertiesSupport() { StringRef name = namedProperty->name; auto &prop = namedProperty->prop; FmtContext fctx; - hashMethod << formatv(propHashFmt, name, - tgfmt(prop.getHashPropertyCall(), - &fctx.addSubst("_storage", propertyStorage))); + if (!prop.getHashPropertyCall().empty()) { + hashMethod << formatv( + propHashFmt, name, + tgfmt(prop.getHashPropertyCall(), + &fctx.addSubst("_storage", propertyStorage))); + } } } hashMethod << " return llvm::hash_combine("; @@ -1468,8 +1538,13 @@ void OpEmitter::genPropertiesSupport() { attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { - hashMethod << "\n hash_" << namedProperty->name << "(prop." - << namedProperty->name << ")"; + if (!namedProperty->prop.getHashPropertyCall().empty()) { + hashMethod << "\n hash_" << namedProperty->name << "(prop." + << namedProperty->name << ")"; + } else { + hashMethod << "\n ::llvm::hash_value(prop." + << namedProperty->name << ")"; + } return; } const auto *namedAttr = @@ -1524,8 +1599,9 @@ void OpEmitter::genPropertiesSupport() { "\"{0}\") return ", resultSegmentAttrName); } - getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) - << ";\n"; + getInherentAttrMethod << "[&]() -> ::mlir::Attribute { " + << tgfmt(prop.getConvertToAttributeCall(), &fctx) + << " }();\n"; if (name == operandSegmentAttrName) { setInherentAttrMethod @@ -1549,13 +1625,15 @@ void OpEmitter::genPropertiesSupport() { )decl", name); if (name == operandSegmentAttrName) { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); + populateInherentAttrsMethod << formatv( + " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", + operandSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); } else { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); + populateInherentAttrsMethod << formatv( + " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", + resultSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); } } getInherentAttrMethod << " return std::nullopt;\n"; @@ -1701,6 +1779,26 @@ void OpEmitter::genPropertiesSupportForBytecode( readPropertiesMethod << " return ::mlir::success();"; } +void OpEmitter::genPropGetters() { + for (const NamedProperty &prop : op.getProperties()) { + std::string name = op.getGetterName(prop.name); + emitPropGetter(opClass, op, name, prop.prop); + } +} + +void OpEmitter::genPropSetters() { + for (const NamedProperty &prop : op.getProperties()) { + std::string name = op.getSetterName(prop.name); + std::string argName = "new" + convertToCamelFromSnakeCase( + prop.name, /*capitalizeFirst=*/true); + auto *method = opClass.addInlineMethod( + "void", name, MethodParameter(prop.prop.getInterfaceType(), argName)); + if (!method) + return; + method->body() << formatv(" getProperties().{0}({1});", name, argName); + } +} + void OpEmitter::genAttrGetters() { FmtContext fctx; fctx.withBuilder("::mlir::Builder((*this)->getContext())"); @@ -2957,6 +3055,12 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, } // Add parameters for all arguments (operands and attributes). + // Track "attr-like" (property and attribute) optional values separate from + // attributes themselves so that the disambiguation code can look at the first + // attribute specifically when determining where to trim the optional-value + // list to avoid ambiguity while preserving the ability of all-property ops to + // use default parameters. + int defaultValuedAttrLikeStartIndex = op.getNumArgs(); int defaultValuedAttrStartIndex = op.getNumArgs(); // Successors and variadic regions go at the end of the parameter list, so no // default arguments are possible. @@ -2967,6 +3071,15 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, for (int i = op.getNumArgs() - 1; i >= 0; --i) { auto *namedAttr = llvm::dyn_cast_if_present(op.getArg(i)); + auto *namedProperty = + llvm::dyn_cast_if_present(op.getArg(i)); + if (namedProperty) { + Property prop = namedProperty->prop; + if (!prop.hasDefaultValue()) + break; + defaultValuedAttrLikeStartIndex = i; + continue; + } if (!namedAttr) break; @@ -2986,6 +3099,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") break; + defaultValuedAttrLikeStartIndex = i; defaultValuedAttrStartIndex = i; } } @@ -3001,8 +3115,10 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if ((attrParamKind == AttrParamKind::WrappedAttr && canUseUnwrappedRawValue(attr)) || (attrParamKind == AttrParamKind::UnwrappedValue && - !canUseUnwrappedRawValue(attr))) + !canUseUnwrappedRawValue(attr))) { ++defaultValuedAttrStartIndex; + defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex; + } } /// Collect any inferred attributes. @@ -3029,8 +3145,16 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, operand->isOptional()); continue; } - if (llvm::isa_and_present(arg)) { - // TODO + if (auto *propArg = llvm::dyn_cast_if_present(arg)) { + const Property &prop = propArg->prop; + StringRef type = prop.getInterfaceType(); + std::string defaultValue; + if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) { + defaultValue = prop.getDefaultValue(); + } + bool isOptional = prop.hasDefaultValue(); + paramList.emplace_back(type, propArg->name, StringRef(defaultValue), + isOptional); continue; } const NamedAttribute &namedAttr = *arg.get(); @@ -3157,6 +3281,15 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( } } + // Push all properties to the result. + for (const auto &namedProp : op.getProperties()) { + // Use the setter from the Properties struct since the conversion from the + // interface type (used in the builder argument) to the storage type (used + // in the state) is not necessarily trivial. + std::string setterName = op.getSetterName(namedProp.name); + body << formatv(" {0}.getOrAddProperties().{1}({2});\n", + builderOpState, setterName, namedProp.name); + } // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; @@ -3996,17 +4129,19 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( // Generate the data member using the storage type. os << " using " << name << "Ty = " << prop.getStorageType() << ";\n" << " " << name << "Ty " << name; - if (prop.hasDefaultValue()) + if (prop.hasStorageTypeValueOverride()) + os << " = " << prop.getStorageTypeValueOverride(); + else if (prop.hasDefaultValue()) os << " = " << prop.getDefaultValue(); comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; // Emit accessors using the interface type. const char *accessorFmt = R"decl(; - {0} get{1}() { + {0} get{1}() const { auto &propStorage = this->{2}; return {3}; } - void set{1}(const {0} &propValue) { + void set{1}({0} propValue) { auto &propStorage = this->{2}; {4}; } @@ -4274,6 +4409,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } + for (auto &namedProp : op.getProperties()) { + std::string name = op.getGetterName(namedProp.name); + emitPropGetter(genericAdaptorBase, op, name, namedProp.prop); + } + for (auto &namedAttr : op.getAttributes()) { const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; @@ -4564,4 +4704,4 @@ static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDefs(records, os); - }); + }); \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index a97d8760842a..2129c4325c0c 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -45,7 +45,7 @@ public: OpVariableElement(const VarT *var) : var(var) {} /// Get the variable. - const VarT *getVar() { return var; } + const VarT *getVar() const { return var; } protected: /// The op variable, e.g. a type or attribute constraint. @@ -64,11 +64,6 @@ struct AttributeVariable return attrType ? attrType->getBuilderCall() : std::nullopt; } - /// Return if this attribute refers to a UnitAttr. - bool isUnitAttr() const { - return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; - } - /// Indicate if this attribute is printed "qualified" (that is it is /// prefixed with the `#dialect.mnemonic`). bool shouldBeQualified() { return shouldBeQualifiedFlag; } @@ -98,6 +93,42 @@ using SuccessorVariable = /// This class represents a variable that refers to a property argument. using PropertyVariable = OpVariableElement; + +/// LLVM RTTI helper for attribute-like variables, that is, attributes or +/// properties. This allows for common handling of attributes and properties in +/// parts of the code that are oblivious to whether something is stored as an +/// attribute or a property. +struct AttributeLikeVariable : public VariableElement { + enum { AttributeLike = 1 << 0 }; + + static bool classof(const VariableElement *ve) { + return ve->getKind() == VariableElement::Attribute || + ve->getKind() == VariableElement::Property; + } + + static bool classof(const FormatElement *fe) { + return isa(fe) && classof(cast(fe)); + } + + /// Returns true if the variable is a UnitAttr or a UnitProperty. + bool isUnit() const { + if (const auto *attr = dyn_cast(this)) + return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; + if (const auto *prop = dyn_cast(this)) { + return prop->getVar()->prop.getBaseProperty().getPropertyDefName() == + "UnitProperty"; + } + llvm_unreachable("Type that wasn't listed in classof()"); + } + + StringRef getName() const { + if (const auto *attr = dyn_cast(this)) + return attr->getVar()->name; + if (const auto *prop = dyn_cast(this)) + return prop->getVar()->name; + llvm_unreachable("Type that wasn't listed in classof()"); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -214,11 +245,11 @@ public: /// If the parsing element is a single UnitAttr element, then it returns the /// attribute variable. Otherwise, returns nullptr. - AttributeVariable * - getUnitAttrParsingElement(ArrayRef pelement) { + AttributeLikeVariable * + getUnitVariableParsingElement(ArrayRef pelement) { if (pelement.size() == 1) { - auto *attrElem = dyn_cast(pelement[0]); - if (attrElem && attrElem->isUnitAttr()) + auto *attrElem = dyn_cast(pelement[0]); + if (attrElem && attrElem->isUnit()) return attrElem; } return nullptr; @@ -488,6 +519,36 @@ const char *const enumAttrParserCode = R"( } )"; +/// The code snippet used to generate a parser call for a property. +/// {0}: The name of the property +/// {1}: The C++ class name of the operation +/// {2}: The property's parser code with appropriate substitutions performed +/// {3}: The description of the expected property for the error message. +const char *const propertyParserCode = R"( + auto {0}PropLoc = parser.getCurrentLocation(); + auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{ + {2} + return ::mlir::success(); + }(result.getOrAddProperties<{1}::Properties>().{0}); + if (failed({0}PropParseResult)) {{ + return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}"); + } +)"; + +/// The code snippet used to generate a parser call for a property. +/// {0}: The name of the property +/// {1}: The C++ class name of the operation +/// {2}: The property's parser code with appropriate substitutions performed +const char *const optionalPropertyParserCode = R"( + auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{ + {2} + return ::mlir::success(); + }(result.getOrAddProperties<{1}::Properties>().{0}); + if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{ + return ::mlir::failure(); + } +)"; + /// The code snippet used to generate a parser call for an operand. /// /// {0}: The name of the operand. @@ -796,9 +857,9 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, // If the anchor is a unit attribute, it won't be parsed directly so elide // it. - auto *anchor = dyn_cast(optional->getAnchor()); + auto *anchor = dyn_cast(optional->getAnchor()); FormatElement *elidedAnchorElement = nullptr; - if (anchor && anchor != elements.front() && anchor->isUnitAttr()) + if (anchor && anchor != elements.front() && anchor->isUnit()) elidedAnchorElement = anchor; for (FormatElement *childElement : elements) if (childElement != elidedAnchorElement) @@ -808,7 +869,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, } else if (auto *oilist = dyn_cast(element)) { for (ArrayRef pelement : oilist->getParsingElements()) { - if (!oilist->getUnitAttrParsingElement(pelement)) + if (!oilist->getUnitVariableParsingElement(pelement)) for (FormatElement *element : pelement) genElementParserStorage(element, op, body); } @@ -1049,7 +1110,6 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", var->name); } - } else if (auto *operand = dyn_cast(param)) { const NamedTypeConstraint *var = operand->getVar(); if (var->isOptional()) { @@ -1137,6 +1197,29 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, validCaseKeywordsStr, errorMessage, attrAssignment); } +// Generate the parser for a property. +static void genPropertyParser(PropertyVariable *propVar, MethodBody &body, + StringRef opCppClassName, + bool requireParse = true) { + StringRef name = propVar->getVar()->name; + const Property &prop = propVar->getVar()->prop; + bool parseOptionally = + prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser(); + FmtContext fmtContext; + fmtContext.addSubst("_parser", "parser"); + fmtContext.addSubst("_ctxt", "parser.getContext()"); + fmtContext.addSubst("_storage", "propStorage"); + + if (parseOptionally) { + body << formatv(optionalPropertyParserCode, name, opCppClassName, + tgfmt(prop.getOptionalParserCall(), &fmtContext)); + } else { + body << formatv(propertyParserCode, name, opCppClassName, + tgfmt(prop.getParserCall(), &fmtContext), + prop.getSummary()); + } +} + // Generate the parser for an attribute. static void genAttrParser(AttributeVariable *attr, MethodBody &body, FmtContext &attrTypeCtx, bool parseAsOptional, @@ -1213,14 +1296,16 @@ if (!dict) { } )decl"; - // TODO: properties might be optional as well. + // {0}: fromAttribute call + // {1}: property name + // {2}: isRequired const char *propFromAttrFmt = R"decl( auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{ + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ {0}; }; auto attr = dict.get("{1}"); -if (!attr) {{ +if (!attr && {2}) {{ emitError() << "expected key entry for {1} in DictionaryAttr to set " "Properties."; return ::mlir::failure(); @@ -1238,13 +1323,14 @@ if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) StringRef name = namedProperty.name; const Property &prop = namedProperty.prop; + bool isRequired = !prop.hasDefaultValue(); FmtContext fctx; body << formatv(propFromAttrFmt, tgfmt(prop.getConvertFromAttributeCall(), &fctx.addSubst("_attr", "propAttr") .addSubst("_storage", "propStorage") .addSubst("_diag", "emitError")), - name); + name, isRequired); } // Generate the setter for any attribute not parsed elsewhere. @@ -1331,20 +1417,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. FormatElement *elidedAnchorElement = nullptr; - auto *anchorAttr = dyn_cast(optional->getAnchor()); - if (anchorAttr && anchorAttr != firstElement && - anchorAttr->isUnitAttr()) { - elidedAnchorElement = anchorAttr; + auto *anchorVar = dyn_cast(optional->getAnchor()); + if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) { + elidedAnchorElement = anchorVar; if (!thenGroup == optional->isInverted()) { - // Add the anchor unit attribute to the operation state. - if (useProperties) { + // Add the anchor unit attribute or property to the operation state + // or set the property to true. + if (isa(anchorVar)) { + body << formatv( + " result.getOrAddProperties<{1}::Properties>().{0} = true;", + anchorVar->getName(), opCppClassName); + } else if (useProperties) { body << formatv( " result.getOrAddProperties<{1}::Properties>().{0} = " "parser.getBuilder().getUnitAttr();", - anchorAttr->getVar()->name, opCppClassName); + anchorVar->getName(), opCppClassName); } else { - body << " result.addAttribute(\"" << anchorAttr->getVar()->name + body << " result.addAttribute(\"" << anchorVar->getName() << "\", parser.getBuilder().getUnitAttr());\n"; } } @@ -1368,6 +1458,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, useProperties, opCppClassName); body << " if (" << attrVar->getVar()->name << "Attr) {\n"; + } else if (auto *propVar = dyn_cast(firstElement)) { + genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); + body << llvm::formatv("if ({0}PropParseResult.has_value() && " + "succeeded(*{0}PropParseResult)) ", + propVar->getVar()->name) + << " {\n"; } else if (auto *literal = dyn_cast(firstElement)) { body << " if (::mlir::succeeded(parser.parseOptional"; genLiteralParser(literal->getSpelling(), body); @@ -1430,15 +1526,19 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, body << ")) {\n"; StringRef lelementName = lelement->getSpelling(); body << formatv(oilistParserCode, lelementName); - if (AttributeVariable *unitAttrElem = - oilist->getUnitAttrParsingElement(pelement)) { - if (useProperties) { + if (AttributeLikeVariable *unitVarElem = + oilist->getUnitVariableParsingElement(pelement)) { + if (isa(unitVarElem)) { + body << formatv( + " result.getOrAddProperties<{1}::Properties>().{0} = true;", + unitVarElem->getName(), opCppClassName); + } else if (useProperties) { body << formatv( " result.getOrAddProperties<{1}::Properties>().{0} = " "parser.getBuilder().getUnitAttr();", - unitAttrElem->getVar()->name, opCppClassName); + unitVarElem->getName(), opCppClassName); } else { - body << " result.addAttribute(\"" << unitAttrElem->getVar()->name + body << " result.addAttribute(\"" << unitVarElem->getName() << "\", UnitAttr::get(parser.getContext()));\n"; } } else { @@ -1468,6 +1568,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, opCppClassName); + } else if (auto *prop = dyn_cast(element)) { + genPropertyParser(prop, body, opCppClassName); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); @@ -1876,6 +1978,38 @@ const char *enumAttrBeginPrinterCode = R"( auto caseValueStr = {1}(caseValue); )"; +/// Generate a check that an optional or default-valued attribute or property +/// has a non-default value. For these purposes, the default value of an +/// optional attribute is its presence, even if the attribute itself has a +/// default value. +static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, + AttributeVariable &attrElement) { + Attribute attr = attrElement.getVar()->attr; + std::string getter = op.getGetterName(attrElement.getVar()->name); + bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue(); + if (optionalAndDefault) + body << "("; + if (attr.isOptional()) + body << getter << "Attr()"; + if (optionalAndDefault) + body << " && "; + if (attr.hasDefaultValue()) { + FmtContext fctx; + fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); + body << getter << "Attr() != " + << tgfmt(attr.getConstBuilderTemplate(), &fctx, + attr.getDefaultValue()); + } + if (optionalAndDefault) + body << ")"; +} + +static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, + PropertyVariable &propElement) { + body << op.getGetterName(propElement.getVar()->name) + << "() != " << propElement.getVar()->prop.getDefaultValue(); +} + /// Generate the printer for the 'prop-dict' directive. static void genPropDictPrinter(OperationFormat &fmt, Operator &op, MethodBody &body) { @@ -1904,6 +2038,15 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op, body << " }\n"; } } + // Similarly, elide default-valued properties. + for (const NamedProperty &prop : op.getProperties()) { + if (prop.prop.hasDefaultValue()) { + body << " if (" << op.getGetterName(prop.name) + << "() == " << prop.prop.getDefaultValue() << ") {"; + body << " elidedProps.push_back(\"" << prop.name << "\");\n"; + body << " }\n"; + } + } body << " _odsPrinter << \" \";\n" << " printProperties(this->getContext(), _odsPrinter, " @@ -2031,7 +2174,6 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element, } else if (auto *property = dyn_cast(element)) { FmtContext ctx; - ctx.addSubst("_ctxt", "getContext()"); const NamedProperty *namedProperty = property->getVar(); ctx.addSubst("_storage", "getProperties()." + namedProperty->name); body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx); @@ -2154,16 +2296,6 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, " }\n"; } -/// Generate a check that a DefaultValuedAttr has a value that is non-default. -static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, - AttributeVariable &attrElement) { - FmtContext fctx; - Attribute attr = attrElement.getVar()->attr; - fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); - body << " && " << op.getGetterName(attrElement.getVar()->name) << "Attr() != " - << tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()); -} - /// Generate the check for the anchor of an optional group. static void genOptionalGroupPrinterAnchor(FormatElement *anchor, const Operator &op, @@ -2190,17 +2322,12 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor, genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) .Case([&](AttributeVariable *element) { - Attribute attr = element->getVar()->attr; - body << op.getGetterName(element->getVar()->name) << "Attr()"; - if (attr.isOptional()) - return; // done - if (attr.hasDefaultValue()) { - // Consider a default-valued attribute as present if it's not the - // default value. - genNonDefaultValueCheck(body, op, *element); - return; - } - llvm_unreachable("attribute must be optional or default-valued"); + // Consider a default-valued attribute as present if it's not the + // default value and an optional one present if it is set. + genNonDefaultValueCheck(body, op, *element); + }) + .Case([&](PropertyVariable *element) { + genNonDefaultValueCheck(body, op, *element); }) .Case([&](CustomDirective *ele) { body << '('; @@ -2276,10 +2403,10 @@ void OperationFormat::genElementPrinter(FormatElement *element, ArrayRef thenElements = optional->getThenElements(); ArrayRef elseElements = optional->getElseElements(); FormatElement *elidedAnchorElement = nullptr; - auto *anchorAttr = dyn_cast(anchor); + auto *anchorAttr = dyn_cast(anchor); if (anchorAttr && anchorAttr != thenElements.front() && (elseElements.empty() || anchorAttr != elseElements.front()) && - anchorAttr->isUnitAttr()) { + anchorAttr->isUnit()) { elidedAnchorElement = anchorAttr; } auto genElementPrinters = [&](ArrayRef elements) { @@ -2319,13 +2446,13 @@ void OperationFormat::genElementPrinter(FormatElement *element, for (VariableElement *var : vars) { TypeSwitch(var) .Case([&](AttributeVariable *attrEle) { - body << " || (" << op.getGetterName(attrEle->getVar()->name) - << "Attr()"; - Attribute attr = attrEle->getVar()->attr; - if (attr.hasDefaultValue()) { - // Don't print default-valued attributes. - genNonDefaultValueCheck(body, op, *attrEle); - } + body << " || ("; + genNonDefaultValueCheck(body, op, *attrEle); + body << ")"; + }) + .Case([&](PropertyVariable *propEle) { + body << " || ("; + genNonDefaultValueCheck(body, op, *propEle); body << ")"; }) .Case([&](OperandVariable *ele) { @@ -2352,7 +2479,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, body << ") {\n"; genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace, lastWasPunctuation); - if (oilist->getUnitAttrParsingElement(pelement) == nullptr) { + if (oilist->getUnitVariableParsingElement(pelement) == nullptr) { for (FormatElement *element : pelement) genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); @@ -2369,7 +2496,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, return; } - // Emit the attribute dictionary. + // Emit the property dictionary. if (isa(element)) { genPropDictPrinter(*this, op, body); lastWasPunctuation = false; @@ -2408,6 +2535,13 @@ void OperationFormat::genElementPrinter(FormatElement *element, else body << "_odsPrinter.printStrippedAttrOrType(" << op.getGetterName(var->name) << "Attr());\n"; + } else if (auto *property = dyn_cast(element)) { + const NamedProperty *var = property->getVar(); + FmtContext fmtContext; + fmtContext.addSubst("_printer", "_odsPrinter"); + fmtContext.addSubst("_ctxt", "getContext()"); + fmtContext.addSubst("_storage", "getProperties()." + var->name); + body << tgfmt(var->prop.getPrinterCall(), &fmtContext) << ";\n"; } else if (auto *operand = dyn_cast(element)) { if (operand->getVar()->isVariadicOfVariadic()) { body << " ::llvm::interleaveComma(" @@ -2737,6 +2871,10 @@ static bool isOptionallyParsed(FormatElement *el) { Attribute attr = attrVar->getVar()->attr; return attr.isOptional() || attr.hasDefaultValue(); } + if (auto *propVar = dyn_cast(el)) { + const Property &prop = propVar->getVar()->prop; + return prop.hasDefaultValue() && prop.hasOptionalParser(); + } if (auto *operandVar = dyn_cast(el)) { const NamedTypeConstraint *operand = operandVar->getVar(); return operand->isOptional() || operand->isVariadic() || @@ -3141,10 +3279,9 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { } if (const NamedProperty *property = findArg(op.getProperties(), name)) { - if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext) + if (ctx == TypeDirectiveContext) return emitError( - loc, "properties currently only supported in `custom` directive"); - + loc, "properties cannot be used as children to a `type` directive"); if (ctx == RefDirectiveContext) { if (!seenProperties.count(property)) return emitError(loc, "property '" + name + @@ -3428,6 +3565,15 @@ LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, "an oilist parsing group"); return success(); }) + // Only optional properties can be within an oilist parsing group. + .Case([&](PropertyVariable *propEle) { + if (!propEle->getVar()->prop.hasDefaultValue()) + return emitError( + loc, + "only default-valued or optional properties can be used in " + "an olist parsing group"); + return success(); + }) // Only optional-like(i.e. variadic) operands can be within an // oilist parsing group. .Case([&](OperandVariable *ele) { @@ -3557,6 +3703,16 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, "can be used to anchor an optional group"); return success(); }) + // All properties can be within the optional group, but only optional + // properties can be the anchor. + .Case([&](PropertyVariable *propEle) { + Property prop = propEle->getVar()->prop; + if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser())) + return emitError(loc, "only properties with default values " + "that can be optionally parsed " + "can be used to anchor an optional group"); + return success(); + }) // Only optional-like(i.e. variadic) operands can be within an optional // group. .Case([&](OperandVariable *ele) { @@ -3649,4 +3805,4 @@ void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { // Generate the printer and parser based on the parsed format. format.genParser(op, opClass); format.genPrinter(op, opClass); -} +} \ No newline at end of file -- Gitee From b326a0893c3084821ad694b58dac40bd620487ac Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 2 Sep 2025 16:21:17 +0200 Subject: [PATCH 6/7] [Backport] Add PtrLikeTypeInterface, removed unnecessary ops --- .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 52 +++++++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 75 ------------------- mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 53 +++++++++++++ mlir/include/mlir/IR/BuiltinTypes.h | 19 ++++- mlir/include/mlir/IR/BuiltinTypes.td | 2 + mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp | 16 ++-- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 25 ------- mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 12 +++ mlir/lib/IR/BuiltinTypes.cpp | 14 ++++ 9 files changed, 159 insertions(+), 109 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 9aa0215e9560..cd87d7474468 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -37,6 +37,7 @@ class Ptr_Type traits = []> def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ MemRefElementTypeInterface, + PtrLikeTypeInterface, DeclareTypeInterfaceMethods ]> { @@ -61,8 +62,59 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ return $_get(memorySpace.getContext(), memorySpace); }]> ]; + let extraClassDeclaration = [{ + // `PtrLikeTypeInterface` interface methods. + /// Returns `Type()` as this pointer type is opaque. + Type getElementType() const { + return Type(); + } + /// Clones the pointer with specified memory space or returns failure + /// if an `elementType` was specified or if the memory space doesn't + /// implement `MemorySpaceAttrInterface`. + FailureOr clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + if (elementType) + return failure(); + if (auto ms = memorySpace.dyn_cast()) + return llvm::cast(get(ms)); + return failure(); + } + /// `!ptr.ptr` types are seen as ptr-like objects with no metadata. + bool hasPtrMetadata() const { + return false; + } + }]; +} + +def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> { + let summary = "Pointer metadata type"; + let description = [{ + The `ptr_metadata` type represents an opaque-view of the metadata associated + with a `ptr-like` object type. + + Note: It's a verification error to construct a `ptr_metadata` type using a + `ptr-like` type with no metadata. + + Example: + + ```mlir + // The metadata associated with a `memref` type. + !ptr.ptr_metadata> + ``` + }]; + let parameters = (ins "PtrLikeTypeInterface":$type); + let assemblyFormat = "`<` $type `>`"; + let builders = [ + TypeBuilderWithInferredContext<(ins + "PtrLikeTypeInterface":$ptrLike), [{ + return $_get(ptrLike.getContext(), ptrLike); + }]> + ]; + let genVerifyDecl = 1; } + + //===----------------------------------------------------------------------===// // Base address operation definition. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index df4784dd94f8..313c9f8eb09a 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -11,82 +11,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td" include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" -include "mlir/Dialect/Ptr/IR/PtrEnums.td" include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/OpAsmInterface.td" -//===----------------------------------------------------------------------===// -// PtrAddOp -//===----------------------------------------------------------------------===// - -def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ - Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface - ]> { - let summary = "Pointer add operation"; - let description = [{ - The `ptr_add` operation adds an integer offset to a pointer to produce a new - pointer. The input and output pointer types are always the same. - - Example: - - ```mlir - %x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32 - %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32 - ``` - }]; - - let arguments = (ins - Ptr_PtrType:$base, - AnySignlessIntegerOrIndex:$offset, - DefaultValuedProperty, "PtrAddFlags::none">:$flags); - let results = (outs Ptr_PtrType:$result); - let assemblyFormat = [{ - ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset) - }]; - let hasFolder = 1; - let extraClassDeclaration = [{ - /// `ViewLikeOp::getViewSource` method. - Value getViewSource() { return getBase(); } - }]; -} - -//===----------------------------------------------------------------------===// -// TypeOffsetOp -//===----------------------------------------------------------------------===// - -def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> { - let summary = "Type offset operation"; - let description = [{ - The `type_offset` operation produces an int or index-typed SSA value - equal to a target-specific constant representing the offset of a single - element of the given type. - - Example: - - ```mlir - // Return the offset between two f32 stored in memory - %0 = ptr.type_offset f32 : index - // Return the offset between two memref descriptors stored in memory - %1 = ptr.type_offset memref<12 x f64> : i32 - ``` - }]; - - let arguments = (ins TypeAttr:$elementType); - let results = (outs AnySignlessIntegerOrIndex:$result); - let builders = [ - OpBuilder<(ins "Type":$elementType)> - ]; - let assemblyFormat = [{ - $elementType attr-dict `:` type($result) - }]; - let extraClassDeclaration = [{ - /// Returns the type offset according to `layout`. If `layout` is `nullopt` - /// the nearest layout the op will be used for the computation. - llvm::TypeSize getTypeSize(std::optional layout = std::nullopt); - }]; -} - - #endif // PTR_OPS diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index db38e2e1bce2..2fca540e5293 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -41,6 +41,59 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { }]; } +//===----------------------------------------------------------------------===// +// PtrLikeTypeInterface +//===----------------------------------------------------------------------===// + +def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + A ptr-like type represents an object storing a memory address. This object + is constituted by: + - A memory address called the base pointer. This pointer is treated as a + bag of bits without any assumed structure. The bit-width of the base + pointer must be a compile-time constant. However, the bit-width may remain + opaque or unavailable during transformations that do not depend on the + base pointer. Finally, it is considered indivisible in the sense that as + a `PtrLikeTypeInterface` value, it has no metadata. + - Optional metadata about the pointer. For example, the size of the memory + region associated with the pointer. + + Furthermore, all ptr-like types have two properties: + - The memory space associated with the address held by the pointer. + - An optional element type. If the element type is not specified, the + pointer is considered opaque. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the memory space of this ptr-like type. + }], + "::mlir::Attribute", "getMemorySpace">, + InterfaceMethod<[{ + Returns the element type of this ptr-like type. Note: this method can + return `::mlir::Type()`, in which case the pointer is considered opaque. + }], + "::mlir::Type", "getElementType">, + InterfaceMethod<[{ + Returns whether this ptr-like type has non-empty metadata. + }], + "bool", "hasPtrMetadata">, + InterfaceMethod<[{ + Returns a clone of this type with the given memory space and element type, + or `failure` if the type cannot be cloned with the specified arguments. + If the pointer is opaque and `elementType` is not `std::nullopt` the + method will return `failure`. + + If no `elementType` is provided and ptr is not opaque, the `elementType` + of this type is used. + }], + "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins + "::mlir::Attribute":$memorySpace, + "::std::optional<::mlir::Type>":$elementType + )> + ]; +} + //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 4250be90ba7f..217bfd9eb4f4 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -141,7 +141,9 @@ public: /// Note: This class attaches the ShapedType trait to act as a mixin to /// provide many useful utility functions. This inheritance has no effect /// on derived memref types. -class BaseMemRefType : public Type, public ShapedType::Trait { +class BaseMemRefType : public Type, + public PtrLikeTypeInterface::Trait, + public ShapedType::Trait { public: using Type::Type; @@ -158,6 +160,13 @@ public: /// provided shape is `std::nullopt`, the current shape of the type is used. BaseMemRefType cloneWith(std::optional> shape, Type elementType) const; + + /// Clone this type with the given memory space and element type. If the + /// provided element type is `std::nullopt`, the current element type of the + /// type is used. + FailureOr + clonePtrWith(Attribute memorySpace, std::optional elementType) const; + // Make sure that base class overloads are visible. using ShapedType::Trait::clone; @@ -183,8 +192,16 @@ public: /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; + /// Returns that this ptr-like object has non-empty ptr metadata. + bool hasPtrMetadata() const { return true; } + /// Allow implicit conversion to ShapedType. operator ShapedType() const { return llvm::cast(*this); } + + /// Allow implicit conversion to PtrLikeTypeInterface. + operator PtrLikeTypeInterface() const { + return llvm::cast(*this); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 0b3532dcc7d4..387e037f9eee 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -423,6 +423,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> { //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -951,6 +952,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp index caebed24d0af..756a441b8ec6 100644 --- a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp +++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp @@ -38,11 +38,10 @@ struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &converter, RewritePatternSet &patterns) const final { - llvm::errs() << "Populating Ptr to LLVM conversion patterns! \n"; ptr::populatePtrToLLVMConversionPatterns(converter, patterns); } }; -} +} // namespace //===----------------------------------------------------------------------===// // API @@ -50,13 +49,16 @@ struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void mlir::ptr::populatePtrToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { - - llvm::errs() << "Adding type conversion! \n"; + converter.addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); // Add type conversions. converter.addConversion([&](ptr::PtrType type) -> Type { - llvm::errs() << "Converting PtrType! \n"; - llvm::errs() << "MemorySpace: " << type.getMemorySpace() << "\n"; std::optional maybeAttr = converter.convertTypeAttribute(type, type.getMemorySpace()); auto memSpace = @@ -66,11 +68,9 @@ void mlir::ptr::populatePtrToLLVMConversionPatterns( return LLVM::LLVMPointerType::get(type.getContext(), memSpace.getValue().getSExtValue()); }); - } void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry ®istry) { - llvm::errs() << "Registering Ptr to LLVM interface! \n"; registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { dialect->addInterfaces(); }); diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c21783011452..061b3feb4d66 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -41,31 +41,6 @@ void PtrDialect::initialize() { >(); } -//===----------------------------------------------------------------------===// -// PtrAddOp -//===----------------------------------------------------------------------===// - -/// Fold: ptradd ptr + 0 -> ptr -OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) { - Attribute attr = adaptor.getOffset(); - if (!attr) - return nullptr; - if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero()) - return getBase(); - return nullptr; -} - -//===----------------------------------------------------------------------===// -// TypeOffsetOp -//===----------------------------------------------------------------------===// - -llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional layout) { - if (layout) - return layout->getTypeSize(getElementType()); - DataLayout dl = DataLayout::closest(*this); - return dl.getTypeSize(getElementType()); -} - //===----------------------------------------------------------------------===// // Pointer API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index a0ea5e83f646..101330073d2d 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, } return success(); } + +//===----------------------------------------------------------------------===// +// Pointer metadata +//===----------------------------------------------------------------------===// + +LogicalResult +PtrMetadataType::verify(function_ref emitError, + PtrLikeTypeInterface type) { + if (!type.hasPtrMetadata()) + return emitError() << "the ptr-like type has no metadata"; + return success(); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index faa944937e00..39627cba244d 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -386,6 +386,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, return builder; } +FailureOr +BaseMemRefType::clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + Type eTy = elementType ? *elementType : getElementType(); + if (llvm::dyn_cast(*this)) + return ::llvm::cast( + UnrankedMemRefType::get(eTy, memorySpace)); + + MemRefType::Builder builder(llvm::cast(*this)); + builder.setElementType(eTy); + builder.setMemorySpace(memorySpace); + return ::llvm::cast(static_cast(builder)); +} + MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape, Type elementType) const { return ::llvm::cast(cloneWith(shape, elementType)); -- Gitee From 92f7f55cc7e01d7ec939a2fb5bbbb5ded925281e Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Sep 2025 10:30:22 +0200 Subject: [PATCH 7/7] fixed test failures --- mlir/test/Dialect/Ptr/layout.mlir | 101 ++++++++++++------ mlir/test/Dialect/Ptr/types.mlir | 21 ++-- mlir/test/IR/properties.mlir | 52 +++++++-- mlir/test/Transforms/test-legalizer.mlir | 4 +- mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 + mlir/test/lib/Dialect/Test/TestAttrDefs.td | 11 ++ mlir/test/lib/Dialect/Test/TestAttributes.cpp | 43 ++++++++ mlir/test/lib/Dialect/Test/TestAttributes.h | 1 + mlir/test/mlir-tblgen/op-format.td | 4 +- 9 files changed, 183 insertions(+), 55 deletions(-) diff --git a/mlir/test/Dialect/Ptr/layout.mlir b/mlir/test/Dialect/Ptr/layout.mlir index 73189a388942..f904e729fcbe 100644 --- a/mlir/test/Dialect/Ptr/layout.mlir +++ b/mlir/test/Dialect/Ptr/layout.mlir @@ -1,56 +1,91 @@ // RUN: mlir-opt --test-data-layout-query --split-input-file --verify-diagnostics %s | FileCheck %s module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry>, - #dlti.dl_entry,#ptr.spec>, - #dlti.dl_entry, #ptr.spec>, - #dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui64>, - #dlti.dl_entry<"dlti.global_memory_space", 2 : ui64>, - #dlti.dl_entry<"dlti.program_memory_space", 3 : ui64>, + #dlti.dl_entry, #ptr.spec>, + #dlti.dl_entry>,#ptr.spec>, + #dlti.dl_entry>, #ptr.spec>, + #dlti.dl_entry<"dlti.default_memory_space", #test.const_memory_space<7>>, + #dlti.dl_entry<"dlti.alloca_memory_space", #test.const_memory_space<5>>, + #dlti.dl_entry<"dlti.global_memory_space", #test.const_memory_space<2>>, + #dlti.dl_entry<"dlti.program_memory_space", #test.const_memory_space<3>>, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64> >} { - // CHECK: @spec + // CHECK-LABEL: @spec func.func @spec() { // CHECK: alignment = 4 - // CHECK: alloca_memory_space = 5 + // CHECK: alloca_memory_space = #test.const_memory_space<5> // CHECK: bitsize = 32 - // CHECK: global_memory_space = 2 + // CHECK: default_memory_space = #test.const_memory_space<7> + // CHECK: global_memory_space = #test.const_memory_space<2> // CHECK: index = 32 // CHECK: preferred = 8 - // CHECK: program_memory_space = 3 + // CHECK: program_memory_space = #test.const_memory_space<3> // CHECK: size = 4 // CHECK: stack_alignment = 128 - "test.data_layout_query"() : () -> !ptr.ptr - // CHECK: alignment = 4 - // CHECK: alloca_memory_space = 5 - // CHECK: bitsize = 32 - // CHECK: global_memory_space = 2 - // CHECK: index = 32 - // CHECK: preferred = 8 - // CHECK: program_memory_space = 3 - // CHECK: size = 4 + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space> + // CHECK: alignment = 1 + // CHECK: alloca_memory_space = #test.const_memory_space<5> + // CHECK: bitsize = 64 + // CHECK: default_memory_space = #test.const_memory_space<7> + // CHECK: global_memory_space = #test.const_memory_space<2> + // CHECK: index = 64 + // CHECK: preferred = 1 + // CHECK: program_memory_space = #test.const_memory_space<3> + // CHECK: size = 8 // CHECK: stack_alignment = 128 - "test.data_layout_query"() : () -> !ptr.ptr<3> + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space<3>> // CHECK: alignment = 8 - // CHECK: alloca_memory_space = 5 + // CHECK: alloca_memory_space = #test.const_memory_space<5> // CHECK: bitsize = 64 - // CHECK: global_memory_space = 2 + // CHECK: default_memory_space = #test.const_memory_space<7> + // CHECK: global_memory_space = #test.const_memory_space<2> // CHECK: index = 64 // CHECK: preferred = 8 - // CHECK: program_memory_space = 3 + // CHECK: program_memory_space = #test.const_memory_space<3> // CHECK: size = 8 // CHECK: stack_alignment = 128 - "test.data_layout_query"() : () -> !ptr.ptr<5> + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space<5>> // CHECK: alignment = 8 - // CHECK: alloca_memory_space = 5 + // CHECK: alloca_memory_space = #test.const_memory_space<5> // CHECK: bitsize = 32 - // CHECK: global_memory_space = 2 + // CHECK: default_memory_space = #test.const_memory_space<7> + // CHECK: global_memory_space = #test.const_memory_space<2> // CHECK: index = 24 // CHECK: preferred = 8 - // CHECK: program_memory_space = 3 + // CHECK: program_memory_space = #test.const_memory_space<3> // CHECK: size = 4 // CHECK: stack_alignment = 128 - "test.data_layout_query"() : () -> !ptr.ptr<4> + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space<4>> + return + } +} + +// ----- + +module attributes { dlti.dl_spec = #dlti.dl_spec< + #dlti.dl_entry, #ptr.spec>, + #dlti.dl_entry<"dlti.default_memory_space", #test.const_memory_space> +>} { + // CHECK-LABEL: @default_memory_space + func.func @default_memory_space() { + // CHECK: alignment = 4 + // CHECK: bitsize = 32 + // CHECK: index = 32 + // CHECK: preferred = 4 + // CHECK: size = 4 + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space> + // CHECK: alignment = 4 + // CHECK: bitsize = 32 + // CHECK: index = 32 + // CHECK: preferred = 4 + // CHECK: size = 4 + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space<1>> + // CHECK: alignment = 4 + // CHECK: bitsize = 32 + // CHECK: index = 32 + // CHECK: preferred = 4 + // CHECK: size = 4 + "test.data_layout_query"() : () -> !ptr.ptr<#test.const_memory_space<2>> return } } @@ -59,7 +94,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // expected-error@+2 {{preferred alignment is expected to be at least as large as ABI alignment}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry> + #dlti.dl_entry, #ptr.spec> >} { func.func @pointer() { return @@ -70,7 +105,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // expected-error@+2 {{size entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry> + #dlti.dl_entry, #ptr.spec> >} { func.func @pointer() { return @@ -82,7 +117,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // expected-error@+2 {{abi entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry> + #dlti.dl_entry, #ptr.spec> >} { func.func @pointer() { return @@ -94,7 +129,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // expected-error@+2 {{preferred entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry> + #dlti.dl_entry, #ptr.spec> >} { func.func @pointer() { return @@ -106,7 +141,7 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< // expected-error@+2 {{index entry must be divisible by 8}} module attributes { dlti.dl_spec = #dlti.dl_spec< - #dlti.dl_entry> + #dlti.dl_entry, #ptr.spec> >} { func.func @pointer() { return diff --git a/mlir/test/Dialect/Ptr/types.mlir b/mlir/test/Dialect/Ptr/types.mlir index 279213bd6fc3..6f4e89eb3e19 100644 --- a/mlir/test/Dialect/Ptr/types.mlir +++ b/mlir/test/Dialect/Ptr/types.mlir @@ -1,17 +1,24 @@ // RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s // CHECK-LABEL: func @ptr_test -// CHECK: (%[[ARG0:.*]]: !ptr.ptr, %[[ARG1:.*]]: !ptr.ptr<1 : i32>) -// CHECK: -> (!ptr.ptr<1 : i32>, !ptr.ptr) -func.func @ptr_test(%arg0: !ptr.ptr, %arg1: !ptr.ptr<1 : i32>) -> (!ptr.ptr<1 : i32>, !ptr.ptr) { - // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<1 : i32>, !ptr.ptr - return %arg1, %arg0 : !ptr.ptr<1 : i32>, !ptr.ptr +// CHECK: (%[[ARG0:.*]]: !ptr.ptr<#test.const_memory_space>, %[[ARG1:.*]]: !ptr.ptr<#test.const_memory_space<1>>) +// CHECK: -> (!ptr.ptr<#test.const_memory_space<1>>, !ptr.ptr<#test.const_memory_space>) +func.func @ptr_test(%arg0: !ptr.ptr<#test.const_memory_space>, %arg1: !ptr.ptr<#test.const_memory_space<1>>) -> (!ptr.ptr<#test.const_memory_space<1>>, !ptr.ptr<#test.const_memory_space>) { + // CHECK: return %[[ARG1]], %[[ARG0]] : !ptr.ptr<#test.const_memory_space<1>>, !ptr.ptr<#test.const_memory_space> + return %arg1, %arg0 : !ptr.ptr<#test.const_memory_space<1>>, !ptr.ptr<#test.const_memory_space> } // ----- // CHECK-LABEL: func @ptr_test -// CHECK: %[[ARG:.*]]: memref -func.func @ptr_test(%arg0: memref) { +// CHECK: %[[ARG:.*]]: memref> +func.func @ptr_test(%arg0: memref>) { + return +} + +// CHECK-LABEL: func @ptr_test_1 +// CHECK: (%[[ARG0:.*]]: !ptr.ptr<#test.const_memory_space>, %[[ARG1:.*]]: !ptr.ptr<#test.const_memory_space<3>>) +func.func @ptr_test_1(%arg0: !ptr.ptr<#test.const_memory_space>, + %arg1: !ptr.ptr<#test.const_memory_space<3>>) { return } diff --git a/mlir/test/IR/properties.mlir b/mlir/test/IR/properties.mlir index 01ea856b0316..418b81dcbb03 100644 --- a/mlir/test/IR/properties.mlir +++ b/mlir/test/IR/properties.mlir @@ -2,10 +2,10 @@ // # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file | mlir-opt -mlir-print-op-generic | FileCheck %s --check-prefix=GENERIC // CHECK: test.with_properties -// CHECK-SAME: <{a = 32 : i64, array = array, b = "foo"}>{{$}} +// CHECK-SAME: a = 32, b = "foo", c = "bar", flag = true, array = [1, 2, 3, 4]{{$}} // GENERIC: "test.with_properties"() -// GENERIC-SAME: <{a = 32 : i64, array = array, b = "foo"}> : () -> () -test.with_properties <{a = 32 : i64, array = array, b = "foo"}> +// GENERIC-SAME: <{a = 32 : i64, array = array, b = "foo", c = "bar", flag = true}> : () -> () +test.with_properties a = 32, b = "foo", c = "bar", flag = true, array = [1, 2, 3, 4] // CHECK: test.with_nice_properties // CHECK-SAME: "foo bar" is -3{{$}} @@ -34,18 +34,48 @@ test.using_property_in_custom [1, 4, 20] // GENERIC-SAME: }> test.using_property_ref_in_custom 1 + 4 = 5 -// CHECK: test.with_default_valued_properties {{$}} +// CHECK: test.with_default_valued_properties na{{$}} // GENERIC: "test.with_default_valued_properties"() -// GENERIC-SAME: <{a = 0 : i32}> -test.with_default_valued_properties <{a = 0 : i32}> +// GENERIC-SAME: <{a = 0 : i32, b = "", c = -1 : i32, unit = false}> : () -> () +test.with_default_valued_properties 0 "" -1 unit_absent + +// CHECK: test.with_default_valued_properties 1 "foo" 0 unit{{$}} +// GENERIC: "test.with_default_valued_properties"() +// GENERIC-SAME: <{a = 1 : i32, b = "foo", c = 0 : i32, unit}> : () -> () +test.with_default_valued_properties 1 "foo" 0 unit // CHECK: test.with_optional_properties -// CHECK-SAME: <{b = 0 : i32}> +// CHECK-SAME: simple = 0 +// GENERIC: "test.with_optional_properties"() +// GENERIC-SAME: <{hasDefault = [], hasUnit = false, longSyntax = [], maybeUnit = [], nested = [], nonTrivialStorage = [], simple = [0]}> : () -> () +test.with_optional_properties simple = 0 + +// CHECK: test.with_optional_properties{{$}} // GENERIC: "test.with_optional_properties"() -// GENERIC-SAME: <{b = 0 : i32}> -test.with_optional_properties <{b = 0 : i32}> +// GENERIC-SAME: simple = [] +test.with_optional_properties -// CHECK: test.with_optional_properties {{$}} +// CHECK: test.with_optional_properties +// CHECK-SAME: anAttr = 0 simple = 1 nonTrivialStorage = "foo" hasDefault = some<0> nested = some<1> longSyntax = some<"bar"> hasUnit maybeUnit = some // GENERIC: "test.with_optional_properties"() -// GENERIC-SAME: : () -> () +// GENERIC-SAME: <{anAttr = 0 : i32, hasDefault = [0], hasUnit, longSyntax = ["bar"], maybeUnit = [unit], nested = {{\[}}[1]], nonTrivialStorage = ["foo"], simple = [1]}> : () -> () test.with_optional_properties + anAttr = 0 + simple = 1 + nonTrivialStorage = "foo" + hasDefault = some<0> + nested = some<1> + longSyntax = some<"bar"> + hasUnit + maybeUnit = some + +// CHECK: test.with_optional_properties +// CHECK-SAME: nested = some +// GENERIC: "test.with_optional_properties"() +// GENERIC-SAME: nested = {{\[}}[]] +test.with_optional_properties nested = some + +// CHECK: test.with_array_properties +// CHECK-SAME: ints = [1, 2] strings = ["a", "b"] nested = {{\[}}[1, 2], [3, 4]] opt = [-1, -2] explicitOptions = [none, 0] explicitUnits = [unit, unit_absent] +// GENERIC: "test.with_array_properties"() +test.with_array_properties ints = [1, 2] strings = ["a", "b"] nested = [[1, 2], [3, 4]] opt = [-1, -2] explicitOptions = [none, 0] explicitUnits = [unit, unit_absent] [] thats_has_default diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 65c947198e06..a52c0e636f0c 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -408,10 +408,10 @@ func.func @test_move_op_before_rollback() { // CHECK-LABEL: func @test_properties_rollback() func.func @test_properties_rollback() { - // CHECK: test.with_properties <{a = 32 : i64, + // CHECK: test.with_properties a = 32, // expected-remark @below{{op 'test.with_properties' is not legalizable}} test.with_properties - <{a = 32 : i64, array = array, b = "foo"}> + a = 32, b = "foo", c = "bar", flag = true, array = [1, 2, 3, 4] {modify_inplace} "test.return"() : () -> () } diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index 967101242e26..65c0121b798b 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -85,6 +85,7 @@ add_mlir_library(MLIRTestDialect MLIRLinalgDialect MLIRLinalgTransforms MLIRLLVMDialect + MLIRPtrDialect MLIRPass MLIRPolynomialDialect MLIRReduce diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index a0a1cd30ed8a..007fbab07556 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -17,6 +17,7 @@ include "TestDialect.td" include "TestEnumDefs.td" include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/Dialect/Utils/StructuredOpsUtils.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" @@ -368,5 +369,15 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> { }]; } +// Test a ptr constant memory space. + + +def TestConstMemorySpaceAttr : Test_Attr<"TestConstMemorySpace", [ + DeclareAttrInterfaceMethods + ]> { + let mnemonic = "const_memory_space"; + let parameters = (ins DefaultValuedParameter<"unsigned", "0">:$addressSpace); + let assemblyFormat = "(`<` $addressSpace^ `>`)?"; +} #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b66dfbfcf089..644e450175b4 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -357,6 +357,49 @@ getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { std::move(parser), std::move(printer)); } +//===----------------------------------------------------------------------===// +// TestConstMemorySpaceAttr +//===----------------------------------------------------------------------===// + +LogicalResult TestConstMemorySpaceAttr::isValidLoad( + Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult TestConstMemorySpaceAttr::isValidStore( + Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return emitError ? (emitError() << "memory space is read-only") : failure(); +} + +LogicalResult TestConstMemorySpaceAttr::isValidAtomicOp( + mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, + IntegerAttr alignment, function_ref emitError) const { + return emitError ? (emitError() << "memory space is read-only") : failure(); +} + +LogicalResult TestConstMemorySpaceAttr::isValidAtomicXchg( + Type type, mlir::ptr::AtomicOrdering successOrdering, + mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + function_ref emitError) const { + return emitError ? (emitError() << "memory space is read-only") : failure(); +} + +LogicalResult TestConstMemorySpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref emitError) const { + return emitError + ? (emitError() << "memory space doesn't allow addrspace casts") + : failure(); +} + +LogicalResult TestConstMemorySpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref emitError) const { + return emitError ? (emitError() << "memory space doesn't allow int-ptr casts") + : failure(); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h index 7099bcf31729..bcbc360758ee 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -18,6 +18,7 @@ #include "TestTraits.h" #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td index 4a19ffb3dfcc..8af4341952f0 100644 --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -73,7 +73,7 @@ def OptionalGroupA : TestFormat_Op<[{ // CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr()) // CHECK: parser.parseKeyword("bar") // CHECK-LABEL: OptionalGroupB::print -// CHECK: if (!getAAttr()) +// CHECK: if (!(getAAttr() && getAAttr() != ((false) ? ::mlir::OpBuilder((*this)->getContext()).getUnitAttr() : nullptr))) // CHECK-NEXT: odsPrinter << ' ' << "foo" // CHECK-NEXT: else // CHECK-NEXT: odsPrinter << ' ' << "bar" @@ -84,7 +84,7 @@ def OptionalGroupB : TestFormat_Op<[{ // Optional group anchored on a default-valued attribute: // CHECK-LABEL: OptionalGroupC::parse -// CHECK: if (getAAttr() && getAAttr() != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) { +// CHECK: if (getAAttr() != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) { // CHECK-NEXT: odsPrinter << ' '; // CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr()); // CHECK-NEXT: } -- Gitee