diff --git a/torch_npu/csrc/aten/ops/op_api/op_api_common.cpp b/torch_npu/csrc/aten/ops/op_api/op_api_common.cpp
index 4e19510e5f7a61aaa20fab12634e812175392cf2..6430f7a5838b91055ba71549793ebe275f93997d 100644
--- a/torch_npu/csrc/aten/ops/op_api/op_api_common.cpp
+++ b/torch_npu/csrc/aten/ops/op_api/op_api_common.cpp
@@ -225,26 +225,52 @@ uint64_t MurmurHash(const void *key, const int len, const uint32_t seed = 0xdead
     uint64_t k2 = 0;
     // because the size of a block is 16, different offsets are calculated for tail blocks
     // for different sizes
-    switch(len & 15)
-    {
-        case 15: k2 ^= ((uint64_t)tail[14]) << 48;
-        case 14: k2 ^= ((uint64_t)tail[13]) << 40;
-        case 13: k2 ^= ((uint64_t)tail[12]) << 32;
-        case 12: k2 ^= ((uint64_t)tail[11]) << 24;
-        case 11: k2 ^= ((uint64_t)tail[10]) << 16;
-        case 10: k2 ^= ((uint64_t)tail[ 9]) << 8;
-        case 9: k2 ^= ((uint64_t)tail[ 8]) << 0;
-            k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
-
-        case 8: k1 ^= ((uint64_t)tail[ 7]) << 56;
-        case 7: k1 ^= ((uint64_t)tail[ 6]) << 48;
-        case 6: k1 ^= ((uint64_t)tail[ 5]) << 40;
-        case 5: k1 ^= ((uint64_t)tail[ 4]) << 32;
-        case 4: k1 ^= ((uint64_t)tail[ 3]) << 24;
-        case 3: k1 ^= ((uint64_t)tail[ 2]) << 16;
-        case 2: k1 ^= ((uint64_t)tail[ 1]) << 8;
-        case 1: k1 ^= ((uint64_t)tail[ 0]) << 0;
-            k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
+    switch(len & 15) {
+        case 15:
+            k2 ^= ((uint64_t)tail[14]) << 48;
+            [[fallthrough]];
+        case 14:
+            k2 ^= ((uint64_t)tail[13]) << 40;
+            [[fallthrough]];
+        case 13:
+            k2 ^= ((uint64_t)tail[12]) << 32;
+            [[fallthrough]];
+        case 12:
+            k2 ^= ((uint64_t)tail[11]) << 24;
+            [[fallthrough]];
+        case 11:
+            k2 ^= ((uint64_t)tail[10]) << 16;
+            [[fallthrough]];
+        case 10:
+            k2 ^= ((uint64_t)tail[9]) << 8;
+            [[fallthrough]];
+        case 9: k2 ^= ((uint64_t)tail[8]) << 0;
+            k2 *= c2; k2 = ROTL64(k2, 33); k2 *= c1; h2 ^= k2;
+            [[fallthrough]];
+        case 8:
+            k1 ^= ((uint64_t)tail[7]) << 56;
+            [[fallthrough]];
+        case 7:
+            k1 ^= ((uint64_t)tail[6]) << 48;
+            [[fallthrough]];
+        case 6:
+            k1 ^= ((uint64_t)tail[5]) << 40;
+            [[fallthrough]];
+        case 5:
+            k1 ^= ((uint64_t)tail[4]) << 32;
+            [[fallthrough]];
+        case 4:
+            k1 ^= ((uint64_t)tail[3]) << 24;
+            [[fallthrough]];
+        case 3:
+            k1 ^= ((uint64_t)tail[2]) << 16;
+            [[fallthrough]];
+        case 2:
+            k1 ^= ((uint64_t)tail[1]) << 8;
+            [[fallthrough]];
+        case 1: k1 ^= ((uint64_t)tail[0]) << 0;
+            k1 *= c1; k1 = ROTL64(k1, 31); k1 *= c2; h1 ^= k1;
+            break;  // last label: [[fallthrough]] is only valid directly before a case/default label
     };
 
     h1 ^= len;
diff --git a/torch_npu/csrc/aten/ops/op_api/op_api_common.h b/torch_npu/csrc/aten/ops/op_api/op_api_common.h
index 063a1021a387292780f8dc8466625242948db6f7..3cb9193420561bf8af35c62ec1da1f7341cd8536 100644
--- a/torch_npu/csrc/aten/ops/op_api/op_api_common.h
+++ b/torch_npu/csrc/aten/ops/op_api/op_api_common.h
@@ -406,6 +406,71 @@ typedef void(*InitPTACacheThreadLocal) ();
 typedef void(*SetPTAHashKey) (uint64_t);
 typedef bool(*CanUsePTACache) (const char *);
 
+/**
+ * Try to serve an aclnn call from the PTA executor cache.
+ *
+ * Builds a hash from the API name and `args`, looks up an executor for that hash via
+ * PTAGetExecCache and, if one is returned, runs the second-phase function through an
+ * OpCommand custom handler.
+ *
+ * @param acl_stream stream passed to the second-phase function.
+ * @param aclnn_api  API name: hashed into the key, used as the OpCommand name and in the
+ *                   failure message.
+ * @param phrase2    address of the second-phase function, invoked as
+ *                   int(void*, uint64_t, aclOpExecutor*, const aclrtStream).
+ * @param args       operator arguments added to the hash buffer after the API name.
+ * @return true when a cached executor was found and the command was run; false when a
+ *         cache symbol is missing, CanUsePTACache declines the API, or the lookup misses.
+ */
+template <typename ...Args>
+bool HitCache(aclrtStream acl_stream, const char *aclnn_api, void *phrase2, Args && ...args) {
+    static const auto ptaGetExecCacheAddr = GetOpApiFuncAddr("PTAGetExecCache");
+    static const auto initPTACacheThreadLocalAddr = GetOpApiFuncAddr("InitPTACacheThreadLocal");
+    static const auto setPTAHashKeyAddr = GetOpApiFuncAddr("SetPTAHashKey");
+    static const auto canUsePTACacheAddr = GetOpApiFuncAddr("CanUsePTACache");
+    PTAGetExecCache ptaGetExecCacheFunc = reinterpret_cast<PTAGetExecCache>(ptaGetExecCacheAddr);
+    InitPTACacheThreadLocal initPTACacheThreadLocalFunc = reinterpret_cast<InitPTACacheThreadLocal>(initPTACacheThreadLocalAddr);
+    SetPTAHashKey setPTAHashKeyFunc = reinterpret_cast<SetPTAHashKey>(setPTAHashKeyAddr);
+    CanUsePTACache canUsePTACacheFunc = reinterpret_cast<CanUsePTACache>(canUsePTACacheAddr);
+    bool has_func = ptaGetExecCacheFunc && initPTACacheThreadLocalFunc && setPTAHashKeyFunc;
+    bool can_use = canUsePTACacheFunc && canUsePTACacheFunc(aclnn_api);
+    if (!has_func || !can_use) {
+        return false;
+    }
+    uint64_t workspace_size = 0;
+    uint64_t *workspace_size_addr = &workspace_size;
+    initPTACacheThreadLocalFunc();
+    g_hashOffset = 0;
+    AddParamToBuf(std::string(aclnn_api), args...);
+    uint64_t hash = CalcHashId();
+    setPTAHashKeyFunc(hash);
+    aclOpExecutor *executor = ptaGetExecCacheFunc(hash, workspace_size_addr);
+    if (executor == nullptr) {
+        return false;
+    }
+    void *workspace_addr = nullptr;
+    if (workspace_size != 0) {
+        // NOTE(review): workspace_tensor goes out of scope at the end of this block while
+        // workspace_addr is still used by the handler below - confirm the storage stays valid.
+        auto workspace_tensor = at_npu::native::CalcuOpUtil::UnsafeEmptyWorkspace(workspace_size);
+        workspace_addr = workspace_tensor.storage().data();
+    }
+    // aclnn_api is captured as a raw pointer so the failure message can name the API.
+    // NOTE(review): assumes it outlives the handler; the calling macro passes a string literal.
+    auto acl_call = [workspace_addr, workspace_size, acl_stream, executor, phrase2, aclnn_api] () -> int {
+        typedef int(*OpApiFunc)(void*, uint64_t, aclOpExecutor*, const aclrtStream);
+        OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(phrase2);
+        auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream);
+        TORCH_CHECK(api_ret == 0, "call ", aclnn_api, " failed, detail:", aclGetRecentErrMsg());
+        return api_ret;
+    };
+    at_npu::native::OpCommand cmd;
+    cmd.Name(aclnn_api);
+    cmd.SetCustomHandler(acl_call);
+    cmd.Run();
+    return true;
+}
+
 /**
  * 异步调用npu执行, 无返回值.
  */
@@ -416,10 +481,6 @@ typedef bool(*CanUsePTACache) (const char *);
         static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \
         static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \
         static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \
-        static const auto ptaGetExecCacheAddr = GetOpApiFuncAddr("PTAGetExecCache"); \
-        static const auto initPTACacheThreadLocalAddr = GetOpApiFuncAddr("InitPTACacheThreadLocal"); \
-        static const auto setPTAHashKeyAddr = GetOpApiFuncAddr("SetPTAHashKey"); \
-        static const auto canUsePTACacheAddr = GetOpApiFuncAddr("CanUsePTACache"); \
         TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \
                     #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", \
                     GetOpApiLibName(), "not found."); \
@@ -430,38 +491,8 @@ typedef bool(*CanUsePTACache) (const char *);
         aclOpExecutor **executor_addr = &executor; \
         InitHugeMemThreadLocal initMemFunc = reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \
         UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
-        PTAGetExecCache ptaGetExecCacheFunc = reinterpret_cast<PTAGetExecCache>(ptaGetExecCacheAddr); \
-        InitPTACacheThreadLocal initPTACacheThreadLocalFunc = reinterpret_cast<InitPTACacheThreadLocal>(initPTACacheThreadLocalAddr); \
-        SetPTAHashKey setPTAHashKeyFunc = reinterpret_cast<SetPTAHashKey>(setPTAHashKeyAddr); \
-        CanUsePTACache canUsePTACacheFunc = reinterpret_cast<CanUsePTACache>(canUsePTACacheAddr); \
-        bool has_func = ptaGetExecCacheFunc && initPTACacheThreadLocalFunc && setPTAHashKeyFunc; \
-        bool can_use = canUsePTACacheFunc && canUsePTACacheFunc(#aclnn_api); \
-        if (has_func && can_use) { \
-            initPTACacheThreadLocalFunc(); \
-            g_hashOffset = 0; \
-            AddParamToBuf(std::string(#aclnn_api), __VA_ARGS__); \
-            uint64_t hashId = CalcHashId(); \
-            setPTAHashKeyFunc(hashId); \
-            executor = ptaGetExecCacheFunc(hashId, workspace_size_addr); \
-            if (executor != nullptr) { \
-                void *workspace_addr = nullptr; \
-                if (workspace_size != 0) { \
-                    auto workspace_tensor = CalcuOpUtil::UnsafeEmptyWorkspace(workspace_size); \
-                    workspace_addr = workspace_tensor.storage().data(); \
-                } \
-                auto acl_call = [workspace_addr, workspace_size, acl_stream, executor] () -> int { \
-                    typedef int(*OpApiFunc)(void*, uint64_t, aclOpExecutor*, const aclrtStream); \
-                    OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
-                    auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
-                    TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \
-                    return api_ret; \
-                }; \
-                at_npu::native::OpCommand cmd; \
-                cmd.Name(#aclnn_api); \
-                cmd.SetCustomHandler(acl_call); \
-                cmd.Run(); \
-                break; \
-            } \
+        if (HitCache(acl_stream, #aclnn_api, opApiFuncAddr, __VA_ARGS__)) { \
+            break; \
         } \
         if (initMemFunc) { \
             initMemFunc(nullptr, false); \