From 81a54b2184e3484ddfbe3a076e421fae1cacf34e Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Thu, 27 Nov 2025 15:43:24 +0800 Subject: [PATCH] auto_lowered_op own self cache --- .../src/ops/ascend/lowered/auto_lowered_op.cc | 12 +- .../src/ops/ascend/lowered/auto_lowered_op.h | 12 +- inferrt/src/ops/ascend/lowered/kernel_spec.cc | 37 +-- inferrt/src/ops/ascend/lowered/kernel_spec.h | 64 +---- .../ascend/lowered/lowered_kernel_executor.cc | 114 +++++---- .../ascend/lowered/lowered_kernel_executor.h | 18 +- .../ops/ascend/lowered/lowered_op_helper.cc | 34 +-- .../ops/ascend/lowered/lowered_op_helper.h | 9 +- .../src/ops/ascend/lowered/mlir_compiler.cc | 234 +++++------------- .../src/ops/ascend/lowered/mlir_compiler.h | 59 ++--- .../ops/lowered_add/test_lowered_add.py | 46 ---- 11 files changed, 194 insertions(+), 445 deletions(-) diff --git a/inferrt/src/ops/ascend/lowered/auto_lowered_op.cc b/inferrt/src/ops/ascend/lowered/auto_lowered_op.cc index 6a6b217b..ea4ba9ac 100644 --- a/inferrt/src/ops/ascend/lowered/auto_lowered_op.cc +++ b/inferrt/src/ops/ascend/lowered/auto_lowered_op.cc @@ -15,13 +15,19 @@ */ #include "ops/ascend/lowered/auto_lowered_op.h" +#include "ops/ascend/lowered/kernel_spec.h" #include "common/logger.h" namespace mrt::ops { -AutoLoweredOp::AutoLoweredOp(const std::string &specId) : specId_(specId), executor_(nullptr) { - executor_ = std::make_unique(specId); - LOG_OUT << "AutoLoweredOp created for spec: " << specId; +AutoLoweredOp::AutoLoweredOp(const std::string &mlirText) : spec_(nullptr), executor_(nullptr) { + // Create KernelSpec directly owned by this op + spec_ = std::make_unique("auto_lowered_op", mlirText); + + // Create executor with the owned spec + executor_ = std::make_unique(spec_.get()); + + LOG_OUT << "AutoLoweredOp created with MLIR text"; } OpsErrorCode AutoLoweredOp::CalcWorkspace(const std::vector &input, const ir::Value *output, diff --git a/inferrt/src/ops/ascend/lowered/auto_lowered_op.h 
b/inferrt/src/ops/ascend/lowered/auto_lowered_op.h index f15cb8cc..1d45920e 100644 --- a/inferrt/src/ops/ascend/lowered/auto_lowered_op.h +++ b/inferrt/src/ops/ascend/lowered/auto_lowered_op.h @@ -26,13 +26,17 @@ #include "ops/ascend/lowered/lowered_kernel_executor.h" namespace mrt::ops { + +// Forward declaration +struct KernelSpec; + class AutoLoweredOp : public Operator { public: /** - * @brief Construct with kernel specification ID - * @param specId Kernel spec ID registered in KernelRegistry + * @brief Construct with MLIR text + * @param mlirText MLIR code to compile and execute */ - explicit AutoLoweredOp(const std::string &specId); + explicit AutoLoweredOp(const std::string &mlirText); ~AutoLoweredOp() override = default; @@ -69,7 +73,7 @@ class AutoLoweredOp : public Operator { ir::Value *output, void *stream) override; private: - std::string specId_; // Kernel specification ID + std::unique_ptr spec_; // Owned kernel specification std::unique_ptr executor_; // Executor for this kernel }; } // namespace mrt::ops diff --git a/inferrt/src/ops/ascend/lowered/kernel_spec.cc b/inferrt/src/ops/ascend/lowered/kernel_spec.cc index 287da81b..622b064b 100644 --- a/inferrt/src/ops/ascend/lowered/kernel_spec.cc +++ b/inferrt/src/ops/ascend/lowered/kernel_spec.cc @@ -21,12 +21,14 @@ namespace mrt::ops { // Internal default compiler callback -static bool InternalDefaultMlirCompiler(const std::string &mlirInput, std::string &outputSoPath, - std::string &entryName, std::string &tilingPrefix) { - MlirCompiler::CompileResult result = MlirCompiler::Instance().CompileFromText(mlirInput); +static bool InternalDefaultMlirCompiler(const std::string &mlirInput, const std::string &executorId, + std::string &cacheDir, std::string &outputSoPath, std::string &entryName, + std::string &tilingPrefix) { + MlirCompiler::CompileResult result = MlirCompiler::Instance().CompileFromText(mlirInput, executorId); if (!result.success) { return false; } + cacheDir = result.cacheDir; outputSoPath = 
result.soPath; entryName = result.entryName; tilingPrefix = result.tilingPrefix; @@ -37,33 +39,4 @@ static bool InternalDefaultMlirCompiler(const std::string &mlirInput, std::strin KernelSpec::KernelSpec(const std::string &id_, const std::string &mlirText_) : id(id_), mlirText(mlirText_), compiler(InternalDefaultMlirCompiler) {} -KernelRegistry &KernelRegistry::Instance() { - static KernelRegistry instance; - return instance; -} - -bool KernelRegistry::Register(const std::string &specId, const std::string &mlirText) { - std::lock_guard lock(mutex_); - - // try_emplace returns pair, second element is true if insertion happened - auto [it, inserted] = specs_.try_emplace(specId, specId, mlirText); - return inserted; -} - -const KernelSpec *KernelRegistry::Lookup(const std::string &specId) const { - std::lock_guard lock(mutex_); - - auto it = specs_.find(specId); - if (it != specs_.end()) { - return &(it->second); - } - - return nullptr; -} - -bool KernelRegistry::Contains(const std::string &specId) const { - std::lock_guard lock(mutex_); - return specs_.find(specId) != specs_.end(); -} - } // namespace mrt::ops diff --git a/inferrt/src/ops/ascend/lowered/kernel_spec.h b/inferrt/src/ops/ascend/lowered/kernel_spec.h index 62464fea..e436d9b3 100644 --- a/inferrt/src/ops/ascend/lowered/kernel_spec.h +++ b/inferrt/src/ops/ascend/lowered/kernel_spec.h @@ -20,7 +20,6 @@ #include #include #include -#include #include #include "common/visible.h" @@ -30,14 +29,17 @@ namespace mrt::ops { /** * @brief Compiler callback function signature * - * Compiles MLIR text to a .so file and returns the .so path. + * Compiles MLIR text to a .so file and returns compilation results. 
* @param mlirText MLIR text content - * @param outputSoPath Output path for compiled .so file (suggested, can be modified) + * @param executorId Unique identifier for the executor (for per-executor caching) + * @param cacheDir Output cache directory for this compilation (for cleanup) + * @param outputSoPath Output path for compiled .so file * @param entryName Output entry function name (will be filled by compiler) * @param tilingPrefix Output tiling function prefix (will be filled by compiler) * @return true on success, false on failure */ -using MlirCompilerCallback = std::function; /** @@ -61,8 +63,6 @@ struct KernelSpec { std::string entry; // Host API function entry point name std::string tilingPrefix; // Prefix for tiling-related functions (for dynamic shape) - mutable std::mutex compilationMutex_; // Protect compilation state in multithreaded environment - KernelSpec(const std::string &id_, const std::string &mlirText_); /** @@ -118,58 +118,6 @@ struct KernelSpec { } }; -/** - * @brief Registry for managing kernel specifications - * - * This singleton class maintains a registry of all custom kernel specifications. - * It allows registering new kernels and looking up their specifications by ID. - * - * Thread-safe for concurrent registration and lookup. 
- */ -class KernelRegistry { - public: - /** - * @brief Get the singleton instance - * @return Reference to the global KernelRegistry instance - */ - static KernelRegistry &Instance(); - - /** - * @brief Register a kernel specification - * @param specId Unique identifier for the kernel - * @param mlirText MLIR text for the kernel - * @return true if registration successful, false if specId already exists - */ - bool Register(const std::string &specId, const std::string &mlirText); - - /** - * @brief Look up a kernel specification by ID - * @param specId Identifier to look up - * @return Pointer to KernelSpec if found, nullptr otherwise - */ - const KernelSpec *Lookup(const std::string &specId) const; - - /** - * @brief Check if a kernel spec is registered - * @param specId Identifier to check - * @return true if registered, false otherwise - */ - bool Contains(const std::string &specId) const; - - // Disable copy and move - KernelRegistry(const KernelRegistry &) = delete; - KernelRegistry &operator=(const KernelRegistry &) = delete; - KernelRegistry(KernelRegistry &&) = delete; - KernelRegistry &operator=(KernelRegistry &&) = delete; - - private: - KernelRegistry() = default; - ~KernelRegistry() = default; - - mutable std::mutex mutex_; - std::unordered_map specs_; -}; - } // namespace mrt::ops #endif // __OPS_ASCEND_LOWERED_KERNEL_SPEC_H__ diff --git a/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.cc b/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.cc index a5dbdfaa..1fb12754 100644 --- a/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.cc +++ b/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.cc @@ -50,25 +50,46 @@ LoweredKernelCacheEntry::~LoweredKernelCacheEntry() { // LoweredKernelExecutor Implementation // ============================================================ -LoweredKernelExecutor::LoweredKernelExecutor(const std::string &specId) - : specId_(specId), spec_(nullptr), currentEntry_(nullptr) { - // Look up kernel spec from 
registry - spec_ = KernelRegistry::Instance().Lookup(specId); +LoweredKernelExecutor::LoweredKernelExecutor(const KernelSpec *spec) + : spec_(spec), cacheDir_(""), currentEntry_(nullptr) { if (spec_ == nullptr) { - LOG_EXCEPTION << "Kernel spec not found: " << specId; + LOG_EXCEPTION << "KernelSpec pointer is null"; } - LOG_OUT << "LoweredKernelExecutor created for spec: " << specId; + // Generate unique executor instance ID + static std::atomic instanceCounter{0}; + uint64_t instanceId = instanceCounter.fetch_add(1, std::memory_order_relaxed); + std::ostringstream oss; + oss << spec_->id << "_inst" << instanceId; + executorId_ = oss.str(); + + LOG_OUT << "LoweredKernelExecutor created with ID: " << executorId_ << " for spec: " << spec_->id; } LoweredKernelExecutor::~LoweredKernelExecutor() { - // Cache entries will be cleaned up automatically via unique_ptr + // Clean up compilation cache directory (unless keepIntermediateFiles is set) + if (!cacheDir_.empty()) { + const auto &options = MlirCompiler::Instance().GetOptions(); + if (options.keepIntermediateFiles) { + LOG_OUT << "Keeping compilation cache directory for debugging: " << cacheDir_; + } else { + // Remove executor-specific parent directory to avoid empty dir residue + // cacheDir_ format: baseDir/executorId/uniqueId -> remove baseDir/executorId + size_t lastSlash = cacheDir_.find_last_of('/'); + std::string targetDir = (lastSlash != std::string::npos) ? 
cacheDir_.substr(0, lastSlash) : cacheDir_; + + std::string rmCmd = "rm -rf " + targetDir; + int ret = system(rmCmd.c_str()); + (void)ret; + LOG_OUT << "Cleaned up executor cache directory: " << targetDir; + } + } } std::string LoweredKernelExecutor::GenerateCacheKey(const std::vector &inputs, const ir::Value *output) const { std::ostringstream oss; - oss << specId_ << "|inp:" << inputs.size() << "|"; + oss << spec_->id << "|inp:" << inputs.size() << "|"; // Add input shapes with explicit null markers for (size_t i = 0; i < inputs.size(); ++i) { @@ -120,38 +141,40 @@ LoweredCacheEntryPtr LoweredKernelExecutor::LoadKernel() { // Get mutable spec pointer for potential compilation KernelSpec *mutableSpec = const_cast(spec_); - // Lock to prevent race condition during compilation - { - std::lock_guard lock(mutableSpec->compilationMutex_); + // Check if we need to compile MLIR first + if (mutableSpec->NeedsCompilation()) { + LOG_OUT << "Kernel needs compilation from MLIR"; - // Check if we need to compile MLIR first - if (mutableSpec->NeedsCompilation()) { - LOG_OUT << "Kernel needs compilation from MLIR"; + if (mutableSpec->compiler == nullptr) { + LOG_ERROR << "No compiler callback provided for MLIR-based kernel"; + return nullptr; + } - if (mutableSpec->compiler == nullptr) { - LOG_ERROR << "No compiler callback provided for MLIR-based kernel"; - return nullptr; - } + // Call compiler callback + std::string compilationCacheDir, outputSoPath, entryName, tilingPrefix; + bool compileSuccess = + mutableSpec->compiler(mutableSpec->mlirText, executorId_, compilationCacheDir, outputSoPath, entryName, tilingPrefix); - // Call compiler callback - std::string outputSoPath, entryName, tilingPrefix; - bool compileSuccess = mutableSpec->compiler(mutableSpec->mlirText, outputSoPath, entryName, tilingPrefix); + if (!compileSuccess) { + LOG_ERROR << "MLIR compilation failed for spec: " << mutableSpec->id; + return nullptr; + } - if (!compileSuccess) { - LOG_ERROR << "MLIR 
compilation failed for spec: " << mutableSpec->id; - return nullptr; - } + // Store cache directory in executor for cleanup (only once) + cacheDir_ = compilationCacheDir; - // Update spec with compiled results - mutableSpec->kernelLibPath = outputSoPath; - mutableSpec->entry = entryName; - mutableSpec->tilingPrefix = tilingPrefix; + // Update spec with compiled results + mutableSpec->kernelLibPath = outputSoPath; + mutableSpec->entry = entryName; + mutableSpec->tilingPrefix = tilingPrefix; - LOG_OUT << "MLIR compilation successful:"; - LOG_OUT << " - .so path: " << outputSoPath; - LOG_OUT << " - entry: " << entryName; - LOG_OUT << " - tiling prefix: " << tilingPrefix; - } + LOG_OUT << "MLIR compilation successful:"; + LOG_OUT << " - cache dir: " << compilationCacheDir; + LOG_OUT << " - .so path: " << outputSoPath; + LOG_OUT << " - entry: " << entryName; + LOG_OUT << " - tiling prefix: " << tilingPrefix; + } else { + LOG_OUT << "Kernel already compiled, using cached binary: " << mutableSpec->kernelLibPath; } // Check if spec is ready to load @@ -409,15 +432,12 @@ int LoweredKernelExecutor::GetWorkspaceSize(size_t *workspaceSize, const std::ve std::string cacheKey = GenerateCacheKey(inputs, output); // Check cache - { - std::lock_guard lock(cacheMutex_); - auto it = cache_.find(cacheKey); - if (it != cache_.end()) { - LOG_OUT << "Cache hit for key: " << cacheKey; - currentEntry_ = it->second.get(); - *workspaceSize = currentEntry_->workspaceSize; - return 0; - } + auto it = cache_.find(cacheKey); + if (it != cache_.end()) { + LOG_OUT << "Cache hit for key: " << cacheKey; + currentEntry_ = it->second.get(); + *workspaceSize = currentEntry_->workspaceSize; + return 0; } LOG_OUT << "Cache miss for key: " << cacheKey; @@ -443,11 +463,8 @@ int LoweredKernelExecutor::GetWorkspaceSize(size_t *workspaceSize, const std::ve } // Cache the entry - { - std::lock_guard lock(cacheMutex_); - cache_[cacheKey] = std::move(entry); - currentEntry_ = cache_[cacheKey].get(); - } + 
cache_[cacheKey] = std::move(entry); + currentEntry_ = cache_[cacheKey].get(); *workspaceSize = currentEntry_->workspaceSize; LOG_OUT << "Workspace size: " << *workspaceSize; @@ -463,7 +480,6 @@ int LoweredKernelExecutor::Launch(void *workspace, size_t workspaceSize, void *s if (entry == nullptr) { // Try to find in cache std::string cacheKey = GenerateCacheKey(inputs, output); - std::lock_guard lock(cacheMutex_); auto it = cache_.find(cacheKey); if (it != cache_.end()) { entry = it->second.get(); diff --git a/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.h b/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.h index b6caf4b4..06f150de 100644 --- a/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.h +++ b/inferrt/src/ops/ascend/lowered/lowered_kernel_executor.h @@ -22,7 +22,6 @@ #include #include #include -#include #include "ops/ascend/lowered/kernel_spec.h" #include "ir/graph.h" @@ -93,12 +92,18 @@ class LoweredKernelExecutor { public: /** * @brief Construct an executor for a specific kernel - * @param specId Kernel specification ID (registered in KernelRegistry) + * @param spec Pointer to KernelSpec (non-owning, must outlive this executor) */ - explicit LoweredKernelExecutor(const std::string &specId); + explicit LoweredKernelExecutor(const KernelSpec *spec); ~LoweredKernelExecutor(); + /** + * @brief Get unique executor instance ID + * Format: the spec id, then "_inst", then a process-wide instance counter (e.g. auto_lowered_op_inst0) + */ + const std::string &GetExecutorId() const { return executorId_; } + /** * @brief Calculate workspace size required for kernel execution * @@ -182,10 +187,11 @@ class LoweredKernelExecutor { */ void AddTilingArgs(std::vector &args, const LoweredKernelCacheEntry *entry); - std::string specId_; // Kernel specification ID - const KernelSpec *spec_; // Cached spec pointer + const KernelSpec *spec_; // Non-owning pointer to spec + + std::string executorId_; // Unique executor instance ID + std::string cacheDir_; // Cache directory for compiled kernel (for cleanup) - mutable std::mutex cacheMutex_; // 
Protects cache access std::unordered_map cache_; // Cache by shape+config key LoweredKernelCacheEntry* currentEntry_; // Current cache entry for Launch reuse (non-owning) diff --git a/inferrt/src/ops/ascend/lowered/lowered_op_helper.cc b/inferrt/src/ops/ascend/lowered/lowered_op_helper.cc index d736fe8e..82ae932e 100644 --- a/inferrt/src/ops/ascend/lowered/lowered_op_helper.cc +++ b/inferrt/src/ops/ascend/lowered/lowered_op_helper.cc @@ -16,51 +16,21 @@ #include "ops/ascend/lowered/lowered_op_helper.h" -#include -#include -#include - #include "common/logger.h" #include "ops/ascend/lowered/auto_lowered_op.h" -#include "ops/ascend/lowered/kernel_spec.h" namespace mrt::ops { -// Generate kernel id from mlir_text hash -static std::string GenerateKernelId(const std::string &mlir_text) { - std::size_t hash_value = std::hash{}(mlir_text); - return "kernel_" + std::to_string(hash_value); -} - std::unique_ptr LoweredOpHelper::CreateFromMlirText(const std::string &mlir_text) { if (mlir_text.empty()) { LOG_ERROR << "MLIR text is empty"; return nullptr; } - std::string kernel_id = GenerateKernelId(mlir_text); - - if (KernelRegistry::Instance().Contains(kernel_id)) { - LOG_OUT << "Kernel already registered: " << kernel_id << ", reusing"; - try { - return std::make_unique(kernel_id); - } catch (const std::exception &e) { - LOG_ERROR << "Failed to create AutoLoweredOp for " << kernel_id << ": " << e.what(); - return nullptr; - } - } - - if (!KernelRegistry::Instance().Register(kernel_id, mlir_text)) { - LOG_ERROR << "Failed to register kernel spec: " << kernel_id; - return nullptr; - } - - LOG_OUT << "Registered lowered kernel: " << kernel_id; - try { - return std::make_unique(kernel_id); + return std::make_unique(mlir_text); } catch (const std::exception &e) { - LOG_ERROR << "Failed to create AutoLoweredOp for " << kernel_id << ": " << e.what(); + LOG_ERROR << "Failed to create AutoLoweredOp: " << e.what(); return nullptr; } } diff --git 
a/inferrt/src/ops/ascend/lowered/lowered_op_helper.h b/inferrt/src/ops/ascend/lowered/lowered_op_helper.h index 1e2bbae8..afbdb128 100644 --- a/inferrt/src/ops/ascend/lowered/lowered_op_helper.h +++ b/inferrt/src/ops/ascend/lowered/lowered_op_helper.h @@ -37,12 +37,9 @@ class MRT_EXPORT LoweredOpHelper { /** * @brief Create an operator from MLIR text * - * This function handles all internal complexity: - * - Uses mlir_text hash as unique identifier (auto caching) - * - Registers kernel spec to KernelRegistry - * - Creates and returns AutoLoweredOp instance - * - * Same mlir_text will reuse cached compiled kernel automatically. + * This function: + * - Creates an AutoLoweredOp instance with the given MLIR text + * - Each op instance owns its KernelSpec and cache independently * * @param mlir_text MLIR code as string * @return Unique pointer to Operator instance, or nullptr on failure diff --git a/inferrt/src/ops/ascend/lowered/mlir_compiler.cc b/inferrt/src/ops/ascend/lowered/mlir_compiler.cc index 3c307723..4a95c891 100644 --- a/inferrt/src/ops/ascend/lowered/mlir_compiler.cc +++ b/inferrt/src/ops/ascend/lowered/mlir_compiler.cc @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -42,12 +43,10 @@ MlirCompiler::MlirCompiler() { // Initialize options with defaults from environment variables InitializeDefaultOptions(); - // Create cache directory - if (options_.enableCache) { - std::string mkdirCmd = "mkdir -p " + options_.cacheDir; - int ret = system(mkdirCmd.c_str()); - (void)ret; // Ignore return value - } + // Create output directory + std::string mkdirCmd = "mkdir -p " + options_.cacheDir; + int ret = system(mkdirCmd.c_str()); + (void)ret; // Ignore return value } MlirCompiler &MlirCompiler::Instance() { @@ -57,7 +56,7 @@ MlirCompiler &MlirCompiler::Instance() { void MlirCompiler::InitializeDefaultOptions() { // Read cache directory from environment variable - const char *cacheDirEnv = std::getenv("INFERRT_MLIR_CACHE_DIR"); + const char 
*cacheDirEnv = std::getenv("MRT_LOWERED_CACHE_DIR"); if (cacheDirEnv != nullptr) { options_.cacheDir = cacheDirEnv; } else { @@ -78,34 +77,28 @@ void MlirCompiler::InitializeDefaultOptions() { options_.bishengirCompilePath = "bishengir-compile"; } - // Read verbose flag from environment variable - const char *verboseEnv = std::getenv("INFERRT_MLIR_VERBOSE"); - if (verboseEnv != nullptr && (std::string(verboseEnv) == "1" || std::string(verboseEnv) == "true")) { - options_.verbose = true; + // Read keep intermediate files flag from environment variable + const char *keepFilesEnv = std::getenv("MRT_LOWERED_MLIR_KEEP_FILES"); + if (keepFilesEnv != nullptr && (std::string(keepFilesEnv) == "1" || std::string(keepFilesEnv) == "true")) { + options_.keepIntermediateFiles = true; } else { - options_.verbose = false; + options_.keepIntermediateFiles = false; } - // Always enable cache by default - options_.enableCache = true; - LOG_OUT << "MlirCompiler initialized with options:"; LOG_OUT << " cacheDir: " << options_.cacheDir; LOG_OUT << " bishengirCompilePath: " << options_.bishengirCompilePath; - LOG_OUT << " verbose: " << (options_.verbose ? "true" : "false"); - LOG_OUT << " enableCache: " << (options_.enableCache ? "true" : "false"); + LOG_OUT << " keepIntermediateFiles: " << (options_.keepIntermediateFiles ? 
"true" : "false"); } void MlirCompiler::SetOptions(const CompileOptions &options) { std::lock_guard lock(mutex_); options_ = options; - // Create cache directory if needed - if (options_.enableCache) { - std::string mkdirCmd = "mkdir -p " + options_.cacheDir; - int ret = system(mkdirCmd.c_str()); - (void)ret; - } + // Create output directory if needed + std::string mkdirCmd = "mkdir -p " + options_.cacheDir; + int ret = system(mkdirCmd.c_str()); + (void)ret; } std::string MlirCompiler::ComputeHash(const std::string &mlirContent) const { @@ -113,98 +106,17 @@ std::string MlirCompiler::ComputeHash(const std::string &mlirContent) const { reinterpret_cast(mlirContent.data()), mlirContent.size())); std::ostringstream oss; - for (auto byte : hash) { - oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(byte); + // Only use first 16 chars for shorter filenames + for (size_t i = 0; i < 8 && i < hash.size(); ++i) { + oss << std::hex << std::setw(2) << std::setfill('0') << static_cast(hash[i]); } return oss.str(); } -bool MlirCompiler::CheckCache(const std::string &hash, CompileResult &result) { - if (!options_.enableCache) { - return false; - } - - // Check in-memory cache first - { - std::lock_guard lock(mutex_); - auto it = cacheMap_.find(hash); - if (it != cacheMap_.end() && it->second.success) { - // Verify .so file still exists - struct stat st; - if (stat(it->second.soPath.c_str(), &st) == 0) { - result = it->second; - LOG_OUT << "MLIR cache hit (in-memory): " << hash; - return true; - } - } - } - - // Check disk cache - std::string soPath = options_.cacheDir + "/" + hash + ".so"; - std::string metaPath = options_.cacheDir + "/" + hash + ".meta"; - - struct stat st; - if (stat(soPath.c_str(), &st) != 0) { - return false; // .so file doesn't exist - } - - // Read metadata file - std::ifstream metaFile(metaPath); - if (!metaFile.is_open()) { - LOG_OUT << "MLIR cache metadata not found: " << metaPath; - return false; - } - - std::string entryName, 
tilingPrefix; - std::getline(metaFile, entryName); - std::getline(metaFile, tilingPrefix); - metaFile.close(); - - if (entryName.empty()) { - LOG_OUT << "Invalid MLIR cache metadata: " << metaPath; - return false; - } - - result.success = true; - result.soPath = soPath; - result.entryName = entryName; - result.tilingPrefix = tilingPrefix; - - // Update in-memory cache - { - std::lock_guard lock(mutex_); - cacheMap_[hash] = result; - } - - LOG_OUT << "MLIR cache hit (disk): " << hash; - return true; -} - -void MlirCompiler::SaveToCache(const std::string &hash, const CompileResult &result) { - if (!options_.enableCache || !result.success) { - return; - } - - // Save to in-memory cache - { - std::lock_guard lock(mutex_); - cacheMap_[hash] = result; - } - - // Save metadata to disk - std::string metaPath = options_.cacheDir + "/" + hash + ".meta"; - std::ofstream metaFile(metaPath); - if (!metaFile.is_open()) { - LOG_ERROR << "Failed to write MLIR cache metadata: " << metaPath; - return; - } - - metaFile << result.entryName << "\n"; - metaFile << result.tilingPrefix << "\n"; - metaFile.close(); - - LOG_OUT << "MLIR cache saved: " << hash; +std::string MlirCompiler::GenerateUniqueId(const std::string &mlirText) { + // Use only hash for content-based caching + return ComputeHash(mlirText); } bool MlirCompiler::ReadFileContent(const std::string &filePath, std::string &content) { @@ -236,10 +148,6 @@ bool MlirCompiler::WriteFileContent(const std::string &filePath, const std::stri } bool MlirCompiler::ExecuteCommand(const std::string &command, std::string &output, int &exitCode) { - if (options_.verbose) { - LOG_OUT << "Executing command: " << command; - } - // Open pipe to command FILE *pipe = popen((command + " 2>&1").c_str(), "r"); if (pipe == nullptr) { @@ -258,11 +166,6 @@ bool MlirCompiler::ExecuteCommand(const std::string &command, std::string &outpu output = oss.str(); exitCode = pclose(pipe); - if (options_.verbose) { - LOG_OUT << "Command output:\n" << output; - 
LOG_OUT << "Command exit code: " << exitCode; - } - return true; } @@ -340,7 +243,7 @@ bool MlirCompiler::ExtractFunctionNames(const std::string &mlirText, std::string return false; } -MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mlirText) { +MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mlirText, const std::string &executorId) { CompileResult result; if (mlirText.empty()) { @@ -349,15 +252,26 @@ MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mli return result; } - // Compute hash for caching - std::string hash = ComputeHash(mlirText); + // Generate unique ID for this compilation + std::string uniqueId = GenerateUniqueId(mlirText); - // Check cache - if (CheckCache(hash, result)) { + // Create cache directory: cacheDir/executorId/hash (executor-specific caching) + std::string compilationCacheDir; + if (!executorId.empty()) { + compilationCacheDir = options_.cacheDir + "/" + executorId + "/" + uniqueId; + } else { + compilationCacheDir = options_.cacheDir + "/" + uniqueId; + } + std::string mkdirCmd = "mkdir -p " + compilationCacheDir; + int ret = system(mkdirCmd.c_str()); + if (ret != 0) { + result.errorMessage = "Failed to create cache directory: " + compilationCacheDir; + LOG_ERROR << result.errorMessage; return result; } - LOG_OUT << "Compiling MLIR text (hash: " << hash << ")"; + LOG_OUT << "Compiling MLIR text (id: " << uniqueId << ")"; + LOG_OUT << "Cache directory: " << compilationCacheDir; // Extract function names from MLIR text std::string entryName, tilingPrefix; @@ -367,9 +281,9 @@ MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mli return result; } - // Create temporary files - std::string inputFile = options_.cacheDir + "/" + hash + "_input.mlir"; - std::string outputSo = options_.cacheDir + "/" + hash + ".so"; + // Create temporary files in the unique cache directory + std::string inputFile = compilationCacheDir + "/" + uniqueId + 
"_input.mlir"; + std::string outputSo = compilationCacheDir + "/" + uniqueId + ".so"; // Write MLIR text to file if (!WriteFileContent(inputFile, mlirText)) { @@ -378,15 +292,23 @@ MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mli return result; } - // Run bishengir-compile (input is assumed to be Linalg IR) - if (!RunBishengirCompile(inputFile, outputSo)) { + // Run bishengir-opt to convert linalg.generic to named operations + std::string optOutputFile = compilationCacheDir + "/" + uniqueId + "_opt.mlir"; + if (!RunBishengirOpt(inputFile, optOutputFile)) { + result.errorMessage = "bishengir-opt failed"; + LOG_ERROR << result.errorMessage; + return result; + } + + // Run bishengir-compile on the optimized MLIR + if (!RunBishengirCompile(optOutputFile, outputSo)) { result.errorMessage = "bishengir-compile failed"; LOG_ERROR << result.errorMessage; return result; } // bishengir-compile adds "lib" prefix to the output filename - // If outputSo is "path/hash.so", actual file is "path/libhash.so" + // If outputSo is "path/uniqueId.so", actual file is "path/libuniqueId.so" std::string actualSo; size_t lastSlash = outputSo.find_last_of('/'); if (lastSlash != std::string::npos) { @@ -401,51 +323,23 @@ MlirCompiler::CompileResult MlirCompiler::CompileFromText(const std::string &mli LOG_OUT << "Entry function: " << result.entryName; LOG_OUT << "Tiling prefix: " << result.tilingPrefix; - // Clean up temporary input file (keep .so) - unlink(inputFile.c_str()); + // Clean up intermediate files if not keeping them + if (!options_.keepIntermediateFiles) { + unlink(inputFile.c_str()); + unlink(optOutputFile.c_str()); + } else { + LOG_OUT << "Intermediate files kept for debugging:"; + LOG_OUT << " Input: " << inputFile; + LOG_OUT << " Optimized: " << optOutputFile; + } // Fill result result.success = true; - result.soPath = actualSo; // Use the actual .so path with "lib" prefix - - // Save to cache - SaveToCache(hash, result); + result.cacheDir = 
compilationCacheDir; // Return the cache directory path + result.soPath = actualSo; // Use the actual .so path with "lib" prefix LOG_OUT << "MLIR compilation successful: " << actualSo; return result; } -void MlirCompiler::ClearCache() { - std::lock_guard lock(mutex_); - - cacheMap_.clear(); - - // Remove all files in cache directory - std::string cmd = "rm -rf " + options_.cacheDir + "/*"; - int ret = system(cmd.c_str()); - (void)ret; - - LOG_OUT << "MLIR cache cleared"; -} - -void MlirCompiler::GetCacheStats(size_t *totalEntries, size_t *cacheSizeBytes) const { - std::lock_guard lock(mutex_); - - if (totalEntries != nullptr) { - *totalEntries = cacheMap_.size(); - } - - if (cacheSizeBytes != nullptr) { - *cacheSizeBytes = 0; - - // Calculate total size of .so files - for (const auto &entry : cacheMap_) { - struct stat st; - if (stat(entry.second.soPath.c_str(), &st) == 0) { - *cacheSizeBytes += st.st_size; - } - } - } -} - } // namespace mrt::ops diff --git a/inferrt/src/ops/ascend/lowered/mlir_compiler.h b/inferrt/src/ops/ascend/lowered/mlir_compiler.h index 97603359..1bdf842c 100644 --- a/inferrt/src/ops/ascend/lowered/mlir_compiler.h +++ b/inferrt/src/ops/ascend/lowered/mlir_compiler.h @@ -34,7 +34,7 @@ namespace mrt::ops { * 2. 
BiSheng IR → .so file (via backend compiler) * * Features: - * - Compilation caching based on MLIR content hash + * - No result cache: every call recompiles into a directory derived from executorId and the MLIR hash * - Thread-safe compilation * - Configurable output directory * - Automatic entry/tiling function name extraction @@ -46,9 +46,8 @@ class MlirCompiler { */ struct CompileOptions { std::string cacheDir = - ".mrt_lowered_cache"; // Cache directory for compiled kernels (default: current_dir/.mrt_lowered_cache) + ".mrt_lowered_cache"; // Directory for compiled kernels and intermediate files std::string bishengirCompilePath = "bishengir-compile"; // Path to bishengir-compile tool - bool enableCache = true; // Enable compilation cache bool verbose = false; // Enable verbose logging CompileOptions() = default; @@ -59,6 +58,7 @@ class MlirCompiler { */ struct CompileResult { bool success = false; // Compilation success flag + std::string cacheDir; // Output directory of this compilation (cacheDir/executorId/hash) std::string soPath; // Output .so file path std::string entryName; // Host API entry function name std::string tilingPrefix; // Tiling function prefix @@ -89,31 +89,20 @@ class MlirCompiler { * @brief Compile MLIR text to .so file * * Pipeline: - * 1. Hash the MLIR text for cache lookup - * 2. Check if cached .so exists and is valid - * 3. If not cached: - * a. Write MLIR text to temporary file - * b. Run bishengir-compile to generate .so - * c. Extract entry/tiling function names - * d. Cache the result - * 4. Return compilation result + * 1. Derive the output directory and file names from executorId and the MLIR hash + * 2. Extract entry/tiling function names from the MLIR text + * 3. Write MLIR text to a temporary file + * 4. Run bishengir-opt, then bishengir-compile, to generate the .so + * 5. Return compilation result + * + * Note: No previous result is reused; every call recompiles. The same mlirText with the same + * executorId writes into the same output directory (cacheDir/executorId/hash).
* * @param mlirText MLIR text (Linalg dialect) + * @param executorId Unique identifier for the executor (for per-executor caching) * @return Compilation result with .so path and entry names */ - CompileResult CompileFromText(const std::string &mlirText); - /** - * @brief Clear compilation cache - * Removes all cached .so files and metadata - */ - void ClearCache(); - - /** - * @brief Get cache statistics - * @param totalEntries Output: total number of cached kernels - * @param cacheSizeBytes Output: total cache size in bytes - */ - void GetCacheStats(size_t *totalEntries, size_t *cacheSizeBytes) const; + CompileResult CompileFromText(const std::string &mlirText, const std::string &executorId = ""); // Disable copy and move MlirCompiler(const MlirCompiler &) = delete; @@ -138,24 +127,17 @@ class MlirCompiler { /** * @brief Compute hash of MLIR content * @param mlirContent MLIR text - * @return Hash string (hex) + * @return Hash string (hex, first 16 chars) */ std::string ComputeHash(const std::string &mlirContent) const; /** - * @brief Check if cached .so exists and is valid - * @param hash MLIR content hash - * @param result Output: cached compilation result - * @return true if cache hit and valid - */ - bool CheckCache(const std::string &hash, CompileResult &result); - - /** - * @brief Save compilation result to cache - * @param hash MLIR content hash - * @param result Compilation result to cache + * @brief Generate unique ID for .so filename + * Format: __ + * @param mlirText MLIR text to hash + * @return Unique identifier string */ - void SaveToCache(const std::string &hash, const CompileResult &result); + std::string GenerateUniqueId(const std::string &mlirText); /** * @brief Run bishengir-compile to generate .so @@ -200,8 +182,7 @@ class MlirCompiler { bool ExecuteCommand(const std::string &command, std::string &output, int &exitCode); CompileOptions options_; - mutable std::mutex mutex_; - std::unordered_map cacheMap_; // In-memory cache + mutable std::mutex 
mutex_; // Protects counter in GenerateUniqueId }; } // namespace mrt::ops diff --git a/tests/st/inferrt/ops/lowered_add/test_lowered_add.py b/tests/st/inferrt/ops/lowered_add/test_lowered_add.py index e0ebe456..1b1e54b5 100644 --- a/tests/st/inferrt/ops/lowered_add/test_lowered_add.py +++ b/tests/st/inferrt/ops/lowered_add/test_lowered_add.py @@ -104,49 +104,3 @@ def test_lowered_bias_add_dynamic_shape(): f"Shape [{M}x{N}] failed with max_diff={torch.max(torch.abs(result - expected)).item()}" print("All dynamic shape tests passed.") - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_lowered_custom_op_compilation_cache(): - """ - Feature: Test MLIR compilation cache - Description: Second execution should hit cache and be faster - Expectation: No recompilation on second run - """ - import time - - script_dir = os.path.dirname(os.path.abspath(__file__)) - mlir_path = os.path.join(script_dir, "add_dyn.mlir") - os.environ["LOWERED_BIAS_ADD_MLIR_PATH"] = mlir_path - os.environ["TEST_DIR"] = script_dir - - op_source = os.path.join(script_dir, "lowered_add_custom_op.cc") - mrt.ops.load(name="lowered_bias_add", sources=[op_source], backend="Ascend", - extra_ldflags=["-lops_ascend_lowered"]) - - @torch.library.custom_op("mrt::lowered_bias_add", mutates_args=()) - def lowered_bias_add_op(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: - raise NotImplementedError() - - @torch.library.register_fake("mrt::lowered_bias_add") - def _(x, bias): - return x - - bias_add_compiled = torch.compile( - lambda x, b: torch.ops.mrt.lowered_bias_add(x, b), - backend=backend - ) - - x = torch.randn(1, 6144, dtype=torch.float16).npu() - bias = torch.randn(1, 6144, dtype=torch.float16).npu() - - start_time = time.time() - result1 = bias_add_compiled(x, bias) - first_time = time.time() - start_time - - start_time = time.time() - result2 = bias_add_compiled(x, bias) - second_time = time.time() - start_time - - 
assert torch.equal(result1, result2), "Results differ between executions" - print("Compilation cache test passed.") -- Gitee