From 277e413ffca6e714cb12acacd997aef153e2623e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 24 Jun 2024 19:56:58 +0800 Subject: [PATCH 01/37] =?UTF-8?q?=E3=80=90fix=E3=80=91=E3=80=90=E5=8A=A8?= =?UTF-8?q?=E6=80=81=E6=89=A9=E5=AE=B9=E3=80=91=E5=88=A0=E9=99=A4=E6=89=A9?= =?UTF-8?q?=E5=AE=B9=E7=9A=84=E5=86=97=E4=BD=99=E4=BB=A3=E7=A0=81=EF=BC=8C?= =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=89=A9=E5=AE=B9=E9=A2=9D=E5=A4=96=E7=94=B3?= =?UTF-8?q?=E8=AF=B7=E7=A9=BA=E9=97=B4=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/emb_table/emb_table.cpp | 163 ------------------------- src/core/emb_table/emb_table.h | 93 -------------- src/core/key_process/key_process.cpp | 39 ------ src/core/key_process/key_process.h | 4 - src/tests/emb_table/emb_table_test.cpp | 135 -------------------- 5 files changed, 434 deletions(-) delete mode 100644 src/core/emb_table/emb_table.cpp delete mode 100644 src/core/emb_table/emb_table.h delete mode 100644 src/tests/emb_table/emb_table_test.cpp diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp deleted file mode 100644 index 914cf535..00000000 --- a/src/core/emb_table/emb_table.cpp +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include "acl/acl_base.h" -#include "utils/common.h" -#include "initializer/initializer.h" -#include "emb_table/emb_table.h" - - -using namespace std; -using namespace MxRec; -using namespace tensorflow; - -void EmbTable::Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed) -{ -#ifndef GTEST - this->rankInfo = rInfo; - this->seed = initSeed; - this->embInfo = eInfo; - LOG_INFO("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); - // 计算embedding table需要分配的内存块数 - auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); - if (ret != ACL_ERROR_NONE) { - LOG_ERROR("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret); - throw AclError(); - } - embSize = embInfo.extEmbeddingSize; - blockSize = BLOCK_EMB_COUNT * embSize; - for (int i = 0; i < INIT_BLOCK_COUNT; ++i) { - // 申请新的内存块 - void *newBlock = nullptr; - aclError ec = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); - if (ec != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ec); - throw AclError(); - } - // 申请内存初始化 - RandomInit(newBlock); - // 将新的内存块加入内存链表 - memoryList.push_back(newBlock); - SplitMemoryBlock(newBlock); - } - totalCapacity = static_cast(memoryList.size()) * BLOCK_EMB_COUNT; - LOG_INFO("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); -#endif -} - -EmbTable::~EmbTable() -{ -#ifndef GTEST - for (void *block : memoryList) { - // 释放内存块 - aclError ret = aclrtFree(block); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtFree failed, ret={}", ret); - } - block = nullptr; - } -#endif -} - -// 从embeddingList获取一个可用的emb地址 -int64_t EmbTable::GetEmbAddress() -{ - int64_t ret = -1; -#ifndef GTEST - if (embeddingList.empty()) { - PrintStatus(); - LOG_DEBUG("GetEmbAddress, embedding_list size: empty! Add block!"); - void *addBlock = nullptr; - aclError ret = aclrtMalloc(&addBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ret); - throw AclError(); - } - RandomInit(addBlock); - // 将新的内存块加入内存list - memoryList.push_back(addBlock); - SplitMemoryBlock(addBlock); - totalCapacity += BLOCK_EMB_COUNT; - } - float *embAddr = embeddingList.front(); - embeddingList.pop_front(); - usedCapacity++; - ret = reinterpret_cast(embAddr); -#endif - return ret; -} - -void EmbTable::RandomInit(void* newBlock) -{ -#ifndef GTEST - LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, embInfo.initializeInfos.size()); - vector devEmb(blockSize); - for (const auto& initializeInfo: as_const(embInfo.initializeInfos)) { - LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name.c_str()); - for (int i = 0; i < BLOCK_EMB_COUNT; i++) { - initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); - } - } - LOG_INFO("Device GenerateEmbData End, seed:{}", seed); - ExecuteAclMemcpy(newBlock, devEmb); -#endif -} - -void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) const -{ -#ifndef GTEST - aclError ret = aclrtMemcpy( - newBlock, blockSize * sizeof(float), devEmb.data(), blockSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMemcpy failed, ret={}", ret); - throw AclError(); - } -#endif -} - - -void EmbTable::SplitMemoryBlock(void *newBlock) -{ -#ifndef GTEST - if (embSize == 0) { - throw std::runtime_error("SplitMemoryBlock by embSize=0!"); - } - for (int i = 0; i < BLOCK_EMB_COUNT; i++) { - float *embPtr = static_cast(newBlock) + i * embSize; - embeddingList.push_back(embPtr); - } -#endif -} - -void EmbTable::PrintStatus() const -{ - // 输出embedding table的总容量和未使用的使用容量 - LOG_INFO("Total capacity:{}, Unused capacity:{}", - totalCapacity * embSize, totalCapacity * embSize - usedCapacity * embSize); -} - -int64_t EmbTable::GetTableSize() const -{ - return static_cast(usedCapacity); -} - -int64_t EmbTable::GetTableCapacity() const -{ - return static_cast(totalCapacity); -} diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h deleted file mode 100644 index 2d30818c..00000000 --- a/src/core/emb_table/emb_table.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#ifndef MX_REC_EMB_TABLE_H -#define MX_REC_EMB_TABLE_H - -#include -#include -#include -#include - -#include "utils/common.h" - -namespace MxRec { - - using namespace std; - - class EmbTable { - public: - EmbTable() = default; - - void Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed = 0); - - ~EmbTable(); - - // 从embeddingList获取获取一个可用的emb地址 - int64_t GetEmbAddress(); - - // 打印emb表使用情况 - void PrintStatus() const; - - int64_t GetTableSize() const; - - int64_t GetTableCapacity() const; - - EmbTable(const EmbTable&) = delete; - - EmbTable(EmbTable&&) = delete; - - EmbTable& operator=(const EmbTable&) = delete; - - EmbTable& operator=(EmbTable&&) = delete; - - void ExecuteAclMemcpy(void* newBlock, vector devEmb) const; - - GTEST_PRIVATE: - constexpr static int BLOCK_EMB_COUNT = 100000; - constexpr static int INIT_BLOCK_COUNT = 5; - constexpr static int TEST_EMB_SIZE = 12; - EmbInfo embInfo; - RankInfo rankInfo; - size_t blockSize = 1; - int embSize = 1; - size_t totalCapacity = 1; - size_t usedCapacity = 0; - int seed = 0; - // embedding地址的列表 - list embeddingList; - // 内存块列表 - vector memoryList; - - void RandomInit(void* newBlock); - - // embSize由embInfo得出 - void SplitMemoryBlock(void* newBlock); - - // 内部类,抛出内存不足异常 - class OutOfMemoryError : public runtime_error { - public: - OutOfMemoryError() : runtime_error("Out of memory!") {} - }; - - // 内部类,抛出acl异常 - class AclError : public runtime_error { - public: - AclError() : runtime_error("Acl failed!") {} - }; - }; -} - -#endif // MX_REC_EMB_TABLE_MANAGER_H \ No newline at end of file diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b5dc962e..74dfafa5 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -57,11 +57,6 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos embInfos[info.name] = info; scInfo[info.name] = info.sendCount; InitHotEmbTotCount(info, rInfo); - if (rankInfo.useDynamicExpansion) { - // 动态扩容 - embeddingTableMap[info.name].Init(info, rInfo, seed); - LOG_INFO(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); - } } LOG_INFO(KEY_PROCESS "hot emb count info:{}", MapToString(hotEmbTotCount)); @@ -1114,40 +1109,6 @@ void KeyProcess::Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channe embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); } -void KeyProcess::Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel) -{ - TimeCost key2OffsetTC; - EASY_FUNCTION(profiler::colors::Blue600) - std::lock_guard lk(mut); // lock for PROCESS_THREAD - auto& key2Offset = keyOffsetMap[embName]; - auto& maxOffsetTmp = maxOffset[embName]; - auto& curEmbTable = embeddingTableMap[embName]; // empty when not use dynamic expansion - for (long& key : splitKey) { - if (key == -1) { - key = 0; - continue; - } - const auto& iter = key2Offset.find(key); - if (iter != key2Offset.end()) { - key = iter->second; - } else { - // 新值 - if (channel == TRAIN_CHANNEL_ID) { -#ifndef GTEST - int64_t addr = curEmbTable.GetEmbAddress(); - key2Offset[key] = addr; - key = addr; -#endif - maxOffsetTmp++; - continue; - } - key = 0; - } - } - LOG_DEBUG("current expansion emb:{}, usage:{}/{}, key2OffsetTC({} ms)", - embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); -} - /* * 构建恢复向量,以便从去重后的emb向量/key恢复回batch对应的emb向量 * 输入接收到emb块的偏移blockOffset,batch内每个key在块内的偏移restoreVec diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 589fc2a5..82a3205b 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -28,7 +28,6 @@ See the License for the specific language governing permissions and #include "ock_ctr_common/include/factory.h" #include "utils/common.h" -#include "emb_table/emb_table.h" #include "feature_admit_and_evict.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" #include "utils/singleton.h" @@ -196,7 +195,6 @@ namespace MxRec { map> evictPosMap {}; map> hotKey {}; map hotEmbTotCount; - map embeddingTableMap {}; ock::ctr::FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; @@ -251,8 +249,6 @@ namespace MxRec { void Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel); - void Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel); - unique_ptr GetBatchData(int channel, int commId) const; void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp deleted file mode 100644 index b26b4487..00000000 --- a/src/tests/emb_table/emb_table_test.cpp +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include "utils/common.h" -#include "emb_table/emb_table.h" - -using namespace std; -using namespace MxRec; -using namespace testing; -using namespace tensorflow; - -class EmbTableTest : public testing::Test { -protected: - void SetUp() - { - // 设置测试用的EmbInfo - embInfo.extEmbeddingSize = embTable.TEST_EMB_SIZE; - LOG_INFO("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}", - embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT); - rankInfo.rankId = 0; - rankInfo.rankSize = 1; - rankInfo.localRankSize = 1; - rankInfo.useStatic = true; - rankInfo.localRankId = 0; - rankInfo.isDDR = true; - rankInfo.ctrlSteps = { 1, -1 }; - rankInfo.deviceId = 0; - // 初始化EmbeddingTable -#ifndef GTEST - LOG_INFO("rank {} running", rankInfo.deviceId); - aclInit(nullptr); -#endif - } - - EmbTable embTable; - EmbInfo embInfo; - RankInfo rankInfo; - aclrtContext context; - - void TearDown() { - } -}; - -// 测试初始化是否正常 -TEST_F(EmbTableTest, Init) -{ -#ifndef GTEST - // 测试初始化是否出现异常 - EXPECT_NO_THROW(embTable.Init(embInfo, rankInfo, 0)); - LOG_INFO("embTable Init succeed!"); - ASSERT_EQ(embTable.rankInfo.g_rankId, rankInfo.g_rankId); - ASSERT_EQ(embTable.rankInfo.rankSize, rankInfo.rankSize); - ASSERT_EQ(embTable.rankInfo.localRankSize, rankInfo.localRankSize); - ASSERT_EQ(embTable.rankInfo.useStatic, rankInfo.useStatic); - ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId); - // 测试容量是否正常 - LOG_INFO("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); - EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT * embTable.BLOCK_EMB_COUNT); -#endif -} - -// 测试embedding list为空时的情况 -TEST_F(EmbTableTest, GetEmbAddressEmptyList) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - while (!embTable.embeddingList.empty()) { - float *embAddr = reinterpret_cast(embTable.GetEmbAddress()); - EXPECT_NE(embAddr, nullptr); - } - ASSERT_EQ(embTable.embeddingList.size(), 0); - - float *curAddr = nullptr; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr= reinterpret_cast(embTable.GetEmbAddress()); - }); - EXPECT_NE(curAddr, nullptr); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); -#endif -} - -// 测试正常情况 -TEST_F(EmbTableTest, GetEmbAddressNormal) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - ASSERT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); - float *curAddr = nullptr; - int totalCapacityBefore = embTable.totalCapacity; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr = reinterpret_cast(embTable.GetEmbAddress()); - }); - EXPECT_NE(curAddr, nullptr); - EXPECT_EQ(embTable.totalCapacity, totalCapacityBefore); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); -#endif -} - -// 测试将一个emb地址放入embeddingList中,是否成功 -TEST_F(EmbTableTest, PutEmbAddress) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - int64_t curAddr; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr = embTable.GetEmbAddress(); - }); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); - embTable.PutEmbAddress(curAddr); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore); - EXPECT_EQ(curAddr, reinterpret_cast(embTable.embeddingList.back())); -#endif -} -- Gitee From 3b9fbb550f6ca5b78f3e6adfbe4220ea98c7afb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 24 Jun 2024 20:24:09 +0800 Subject: [PATCH 02/37] =?UTF-8?q?=E3=80=90fix=E3=80=91capacity=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E9=80=82=E9=85=8D=E6=96=B0ddr=E3=80=81ssd?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/core/emb/sparse_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py index 071f4506..39af9d60 100644 --- a/mx_rec/core/emb/sparse_embedding.py +++ b/mx_rec/core/emb/sparse_embedding.py @@ -77,9 +77,9 @@ class ExternalStorageSparseEmbedding(SparseEmbedding): def capacity(self) -> int: # DDR if not self._ssd_vocabulary_size: - return self._device_vocabulary_size + self._host_vocabulary_size + return self._host_vocabulary_size # SSD - return self._device_vocabulary_size + self._host_vocabulary_size + self._ssd_vocabulary_size + return self._host_vocabulary_size + self._ssd_vocabulary_size def _set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor], -- Gitee From 533bc2be9043d0e29fbf962a5a3f781cea3250f4 Mon Sep 17 00:00:00 2001 From: LiJiang Date: Fri, 28 Jun 2024 14:31:16 +0800 Subject: [PATCH 03/37] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=86=97=E4=BD=99?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=EF=BC=9B=E4=BF=AE=E6=94=B9=E9=94=99=E8=AF=AF?= =?UTF-8?q?log?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp | 2 +- cust_op/fused_lazy_adam/op_host/lazy_adam.cpp | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp index e9711379..3b9b51fe 100644 --- a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp @@ -322,7 +322,7 @@ namespace AclnnLazyAdam { ERROR_LOG("Execute Operator failed. error code is %d", static_cast(ret)); return false; } - INFO_LOG("Execute aclnnAddCustom success"); + INFO_LOG("Execute aclnnLazyAdam success"); ret = aclrtSynchronizeStreamWithTimeout(stream, STREAM_TIMEOUT); if (ret != SUCCESS) { diff --git a/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp b/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp index fb7f86b3..2c288729 100644 --- a/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp +++ b/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp @@ -54,8 +54,6 @@ static ge::graphStatus LazyAdamTilingFunc(gert::TilingContext* context) ge::DataType indicesDtype = context->GetInputDesc(1)->GetDataType(); int indicesDtypeSize = ge::GetSizeByDataType(indicesDtype); - tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); auto attrs = context->GetAttrs(); float beta1 = *attrs->GetAttrPointer(0); -- Gitee From 550a302c91620bbb771ac980ac564aa5c4f467a8 Mon Sep 17 00:00:00 2001 From: steepcurve Date: Mon, 1 Jul 2024 15:14:44 +0800 Subject: [PATCH 04/37] fix: `StringFormat` use cases --- src/core/checkpoint/checkpoint.cpp | 142 ++-- src/core/utils/common.h | 1132 ++++++++++++++-------------- src/ops_tf/hybrid_dataset_ops.cpp | 4 +- 3 files changed, 653 insertions(+), 625 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index abd3a10e..bc7501bb 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include -#include +#include "checkpoint.h" + #include #include +#include +#include +#include + +#include #include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" -#include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" #include "ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h" -#include "utils/time_cost.h" -#include "utils/common.h" +#include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" #include "file_system/file_system_handler.h" - -#include "checkpoint.h" +#include "utils/common.h" +#include "utils/time_cost.h" using namespace std; using namespace MxRec; @@ -89,11 +90,17 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) void Checkpoint::SetDataHandler(const vector& featureTypes) { - map> setCkptMap{ - {CkptFeatureType::FEAT_ADMIT_N_EVICT, [this] { dataHandlers.push_back(make_unique()); }}, - {CkptFeatureType::DDR_KEY_FREQ_MAP, [this] { dataHandlers.push_back(make_unique()); }}, - {CkptFeatureType::KEY_COUNT_MAP, [this] { dataHandlers.push_back(make_unique()); }} - }; + map> setCkptMap{{CkptFeatureType::FEAT_ADMIT_N_EVICT, + [this] { + dataHandlers.push_back(make_unique()); + }}, + {CkptFeatureType::DDR_KEY_FREQ_MAP, + [this] { + dataHandlers.push_back(make_unique()); + }}, + {CkptFeatureType::KEY_COUNT_MAP, [this] { + dataHandlers.push_back(make_unique()); + }}}; for (const auto& featureType : featureTypes) { setCkptMap.at(featureType)(); @@ -104,8 +111,8 @@ void Checkpoint::SaveProcess(CkptData& ckptData) { for (const auto& dataHandler : dataHandlers) { dataHandler->SetProcessData(ckptData); - vector embNames { dataHandler->GetEmbNames() }; - vector saveDataTypes { dataHandler->GetDataTypes() }; + vector embNames{dataHandler->GetEmbNames()}; + vector saveDataTypes{dataHandler->GetDataTypes()}; MakeUpperLayerSaveDir(); MakeDataLayerSaveDir(embNames, saveDataTypes, dataHandler); SaveDataset(embNames, saveDataTypes, dataHandler); @@ -118,17 +125,16 @@ void Checkpoint::MakeUpperLayerSaveDir() MakeSaveDir(innerDirPath); } -void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, - const vector& saveDataTypes, +void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler) { for (const auto& embName : embNames) { - auto dataDir { innerDirPath + dirSeparator + embName }; + auto dataDir{innerDirPath + dirSeparator + embName}; MakeSaveDir(dataDir); for (const auto& saveDataType : saveDataTypes) { - auto dataDirName { dataHandler->GetDataDirName(saveDataType) }; - auto datasetPath { dataDir + dirSeparator + dataDirName }; + auto dataDirName{dataHandler->GetDataDirName(saveDataType)}; + auto datasetPath{dataDir + dirSeparator + dataDirName}; MakeSaveDir(datasetPath); } } @@ -146,7 +152,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) const Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName) { EmbSizeInfo embSizeInfo; - for (const auto &embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { if (embInfo.name == embName) { embSizeInfo.embSize = embInfo.embeddingSize; embSizeInfo.extEmbSize = embInfo.extEmbeddingSize; @@ -158,29 +164,28 @@ Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName) bool Checkpoint::CheckEmbNames(const string& embName) { - for (const auto &embInfo: mgmtEmbInfo) { - if (embInfo.name == embName && embInfo.isSave) { + for (const auto& embInfo : mgmtEmbInfo) { + if (embInfo.name == embName && embInfo.isSave) { return true; } } return false; } -void Checkpoint::SaveDataset(const vector& embNames, - const vector& saveDataTypes, +void Checkpoint::SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler) { - for (const auto& embName: embNames) { + for (const auto& embName : embNames) { if (!CheckEmbNames(embName)) { continue; } auto dataDir{innerDirPath + dirSeparator + embName}; - for (const auto& saveDataType: saveDataTypes) { - auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; - auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; + for (const auto& saveDataType : saveDataTypes) { + auto datasetPath{dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType)}; + auto datasetDir{datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType}; LOG_DEBUG("====Start getting data from handler to: {}", datasetDir); - auto transData { dataHandler->GetDataset(saveDataType, embName) }; + auto transData{dataHandler->GetDataset(saveDataType, embName)}; LOG_DEBUG("====Start saving data to: {}", datasetDir); WriteStream(transData, datasetDir, transData.datasetSize, saveDataType); @@ -197,36 +202,36 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si ssize_t writeBytesNum; if (int32TransSet.find(dataType) != int32TransSet.end()) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.int32Arr.data()), dataSize); + writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int32Arr.data()), dataSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.int64Arr.data()), dataSize); + writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int64Arr.data()), dataSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.attribute.data()), dataSize); + writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.attribute.data()), dataSize); } else { throw runtime_error("unknown CkptDataType"); } if (writeBytesNum == -1) { - throw runtime_error(StringFormat("Error: Save data failed. data type: %d. " - "An error occurred while writing file: %s.", dataType, dataDir.c_str())); + throw runtime_error(StringFormat("Error: Save data failed. data type: %s. " + "An error occurred while writing file: %s.", + CkptDataTypeName(dataType).c_str(), dataDir.c_str())); } if (writeBytesNum != dataSize) { - throw runtime_error(StringFormat("Error: Save data failed. data type: %d ." + throw runtime_error(StringFormat("Error: Save data failed. data type: %s. " "Expected to write %d bytes, but actually write %d bytes to file %s.", - dataType, dataSize, writeBytesNum, dataDir.c_str())); + CkptDataTypeName(dataType).c_str(), dataSize, writeBytesNum, dataDir.c_str())); } } - void Checkpoint::LoadProcess(CkptData& ckptData) { for (const auto& dataHandler : dataHandlers) { - vector embNames {}; - vector dirNames { dataHandler->GetDirNames() }; - vector saveDataTypes { dataHandler->GetDataTypes() }; + vector embNames{}; + vector dirNames{dataHandler->GetDirNames()}; + vector saveDataTypes{dataHandler->GetDataTypes()}; innerDirPath = processPath; if (find(dirNames.begin(), dirNames.end(), ssdSymbol) != dirNames.end()) { embNames = GetTableLayerLoadDir(); @@ -238,7 +243,6 @@ void Checkpoint::LoadProcess(CkptData& ckptData) } } - vector Checkpoint::GetEmbedTableNames() { vector loadTableNames; @@ -262,22 +266,20 @@ vector Checkpoint::GetTableLayerLoadDir() return loadTableDir; } -void Checkpoint::LoadDataset(const vector& embNames, - const vector& saveDataTypes, - const unique_ptr& dataHandler, - CkptData& ckptData) +void Checkpoint::LoadDataset(const vector& embNames, const vector& saveDataTypes, + const unique_ptr& dataHandler, CkptData& ckptData) { for (const auto& embName : embNames) { - auto dataDir { innerDirPath + dirSeparator + embName }; + auto dataDir{innerDirPath + dirSeparator + embName}; for (const auto& saveDataType : saveDataTypes) { - auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; + auto datasetPath{dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType)}; - auto datasetDir { datasetPath + dirSeparator + "slice" + dataFileType }; - auto attributeDir { datasetPath + dirSeparator + "slice" + attribFileType }; + auto datasetDir{datasetPath + dirSeparator + "slice" + dataFileType}; + auto attributeDir{datasetPath + dirSeparator + "slice" + attribFileType}; CkptTransData transData; LOG_DEBUG("====Start reading data from: {}", attributeDir); - auto dataElmtBytes { dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE) }; + auto dataElmtBytes{dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE)}; ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes); dataElmtBytes = dataHandler->GetDataElmtBytes(saveDataType); @@ -290,7 +292,7 @@ void Checkpoint::LoadDataset(const vector& embNames, } LOG_DEBUG("====Start loading data from: {} to data handler.", attributeDir); - if ((saveDataType == CkptDataType::EMB_INFO)) { + if ((saveDataType == CkptDataType::EMB_INFO)) { dataHandler->SetDatasetForLoadEmb(saveDataType, embName, transData, ckptData); } else { dataHandler->SetDataset(saveDataType, embName, transData); @@ -299,14 +301,12 @@ void Checkpoint::LoadDataset(const vector& embNames, } } -void Checkpoint::ReadStream(CkptTransData& transData, - const string& dataDir, - CkptDataType dataType, +void Checkpoint::ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes) { if (dataElmtBytes == 0) { LOG_WARN("dataElmtBytes is 0, don't handle [/ %] operation"); - return ; + return; } if (fileSystemPtr == nullptr) { @@ -315,7 +315,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, } size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); - auto resizeSize { datasetSize / dataElmtBytes }; + auto resizeSize{datasetSize / dataElmtBytes}; SetTransDataSize(transData, resizeSize, dataType); if (datasetSize % dataElmtBytes > 0) { @@ -328,31 +328,29 @@ void Checkpoint::ReadStream(CkptTransData& transData, } else if (int64TransSet.find(dataType) != int64TransSet.end()) { readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int64Arr.data()), datasetSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); } else { throw runtime_error("unknown CkptDataType"); } if (readBytesNum == -1) { - throw runtime_error(StringFormat("Error: Load data failed. data type: %d ." - "An error occurred while reading file: %s.", dataType, dataDir.c_str())); + throw runtime_error(StringFormat("Error: Load data failed. data type: %s. " + "An error occurred while reading file: %s.", + CkptDataTypeName(dataType).c_str(), dataDir.c_str())); } if (readBytesNum != datasetSize) { - throw runtime_error(StringFormat("Error: Load data failed. data type: %d ." + throw runtime_error(StringFormat("Error: Load data failed. data type: %s. " "Expected to read %d bytes, but actually read %d bytes to file %s.", - dataType, datasetSize, readBytesNum, dataDir.c_str())); + CkptDataTypeName(dataType).c_str(), datasetSize, readBytesNum, dataDir.c_str())); } } -void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, - const string& dataDir, - uint32_t dataElmtBytes, - CkptData& ckptData, - string embName) const +void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, + CkptData& ckptData, string embName) const { if (dataElmtBytes == 0) { LOG_ERROR("dataElmtBytes is 0, don't handle [/ %] operation"); - return ; + return; } if (fileSystemPtr == nullptr) { diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 9a39e7ac..f8ff4565 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -17,608 +17,638 @@ See the License for the specific language governing permissions and #define COMMON_H #include -#include -#include -#include + +#include #include +#include +#include #include +#include #include -#include -#include -#include "tensorflow/core/framework/tensor.h" -#include "absl/container/flat_hash_map.h" -#include "securec.h" -#include "utils/logger.h" -#include "utils/config.h" +#include -#include "initializer/initializer.h" +#include "absl/container/flat_hash_map.h" #include "initializer/constant_initializer/constant_initializer.h" -#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h" +#include "initializer/initializer.h" #include "initializer/random_normal_initializer/random_normal_initializer.h" -#include "ock_ctr_common/include/factory.h" +#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h" #include "ock_ctr_common/include/embedding_cache.h" +#include "ock_ctr_common/include/factory.h" +#include "securec.h" +#include "tensorflow/core/framework/tensor.h" +#include "utils/config.h" +#include "utils/logger.h" #if defined(BUILD_WITH_EASY_PROFILER) - #include - #include +#include +#include #else - #define EASY_FUNCTION(...) - #define EASY_VALUE(...) - #define EASY_BLOCK(...) - #define EASY_END_BLOCK - #define EASY_PROFILER_ENABLE - #define EASY_PROFILER_DISABLE +#define EASY_FUNCTION(...) +#define EASY_VALUE(...) +#define EASY_BLOCK(...) +#define EASY_END_BLOCK +#define EASY_PROFILER_ENABLE +#define EASY_PROFILER_DISABLE #endif namespace MxRec { #define INFO_PTR shared_ptr #define MGMT_CPY_THREADS 4 #define PROFILING - using namespace tensorflow; - extern ock::ctr::FactoryPtr factory; - constexpr int TRAIN_CHANNEL_ID = 0; - constexpr int EVAL_CHANNEL_ID = 1; - - constexpr int MAX_CHANNEL_NUM = 2; - constexpr int MAX_KEY_PROCESS_THREAD = 10; - constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; - constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; - constexpr int KEY_PROCESS_THREAD = 6; - constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; - constexpr size_t MAX_VOCABULARY_SIZE = 1e10; - constexpr int SSD_SIZE_INDEX = 2; - constexpr int MAX_FILE_NUM = 1000; - constexpr int EMBEDDING_THREAD_NUM = 2; - // for GLOG - struct GlogConfig { - static bool gStatOn; - static int gGlogLevel; - static string gRankId; - }; - - constexpr int GLOG_MAX_BUF_SIZE = 1024; - constexpr int GLOG_TIME_WIDTH_2 = 2; - constexpr int GLOG_TIME_WIDTH_6 = 6; - constexpr char GLOG_STAT_FLAG[] = "statOn"; - - // unique related config - constexpr int UNIQUE_BUCKET = 6; - constexpr int MIN_UNIQUE_THREAD_NUM = 1; - - // validate file - constexpr long long FILE_MAX_SIZE = 1LL << 40; - constexpr int FILE_MIN_SIZE = 0; - constexpr size_t BUFFER_SIZE{1024 * 1024 * 64}; - constexpr size_t MAP_BYTE_SIZE{static_cast(10) * 1024 * 1024 * 1024}; +using namespace tensorflow; +extern ock::ctr::FactoryPtr factory; +constexpr int TRAIN_CHANNEL_ID = 0; +constexpr int EVAL_CHANNEL_ID = 1; + +constexpr int MAX_CHANNEL_NUM = 2; +constexpr int MAX_KEY_PROCESS_THREAD = 10; +constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; +constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; +constexpr int KEY_PROCESS_THREAD = 6; +constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; +constexpr size_t MAX_VOCABULARY_SIZE = 1e10; +constexpr int SSD_SIZE_INDEX = 2; +constexpr int MAX_FILE_NUM = 1000; +constexpr int EMBEDDING_THREAD_NUM = 2; +// for GLOG +struct GlogConfig { + static bool gStatOn; + static int gGlogLevel; + static string gRankId; +}; + +constexpr int GLOG_MAX_BUF_SIZE = 1024; +constexpr int GLOG_TIME_WIDTH_2 = 2; +constexpr int GLOG_TIME_WIDTH_6 = 6; +constexpr char GLOG_STAT_FLAG[] = "statOn"; + +// unique related config +constexpr int UNIQUE_BUCKET = 6; +constexpr int MIN_UNIQUE_THREAD_NUM = 1; + +// validate file +constexpr long long FILE_MAX_SIZE = 1LL << 40; +constexpr int FILE_MIN_SIZE = 0; +constexpr size_t BUFFER_SIZE{1024 * 1024 * 64}; +constexpr size_t MAP_BYTE_SIZE{static_cast(10) * 1024 * 1024 * 1024}; #ifdef GTEST - constexpr int KEY_PROCESS_TIMEOUT = 3; +constexpr int KEY_PROCESS_TIMEOUT = 3; #else - constexpr int KEY_PROCESS_TIMEOUT = 120; +constexpr int KEY_PROCESS_TIMEOUT = 120; #endif - constexpr int GET_BATCH_TIMEOUT = 300; - constexpr int EOS_TIMEOUT = 30; - - constexpr size_t DEFAULT_RANDOM_SEED = 10086; - constexpr int64_t INVALID_KEY_VALUE = -1; - constexpr int32_t INVALID_INDEX_VALUE = -1; - constexpr int ALLTOALLVC_ALIGN = 128; - constexpr int PROFILING_START_BATCH_ID = 100; - constexpr int PROFILING_END_BATCH_ID = 200; - constexpr int MGMT_THREAD_BIND = 48; - constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; - constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; - constexpr float HOT_EMB_CACHE_PCT = static_cast(1. / 3); // hot emb cache percent - - const string COMBINE_HISTORY_NAME = "combine_table_history"; - const string SAVE_SPARSE_PATH_PREFIX = "sparse"; - - using emb_key_t = int64_t; - using emb_cache_key_t = uint64_t; - using freq_num_t = int64_t; - using EmbNameT= std::string; - using KeysT = std::vector; - using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector - using UinqueKeyT = std::tuple>; - using RestoreVecSecT = std::tuple>; - using TensorInfoT = std::tuple>>::iterator>; - - namespace HybridOption { - const unsigned int USE_STATIC = 0x001; - const unsigned int USE_DYNAMIC_EXPANSION = 0x001 << 1; - const unsigned int USE_SUM_SAME_ID_GRADIENTS = 0x001 << 2; - }; - - string GetChipName(int devID); - int GetThreadNumEnv(); - - namespace UBSize { - const int ASCEND910_PREMIUM_A = 262144; - const int ASCEND910_PRO_B = 262144; - const int ASCEND910_B2 = 196608; - const int ASCEND910_B1 = 196608; - const int ASCEND910_B3 = 196608; - const int ASCEND910_B4 = 196608; - const int ASCEND910_C1 = 196608; - const int ASCEND910_C2 = 196608; - const int ASCEND910_C3 = 196608; - const int ASCEND920_A = 196608; - const int ASCEND910_PRO_A = 262144; - const int ASCEND910_B = 262144; - const int ASCEND910_A = 262144; - const int ASCEND910_B2C = 196608; - }; - - inline int GetUBSize(int devID) - { - const std::map chipUbSizeList = {{"910A", UBSize::ASCEND910_A}, - {"910B", UBSize::ASCEND910_B}, - {"920A", UBSize::ASCEND920_A}, - {"910B1", UBSize::ASCEND910_B1}, - {"910B2", UBSize::ASCEND910_B2}, - {"910B3", UBSize::ASCEND910_B3}, - {"910B4", UBSize::ASCEND910_B4}, - {"910B2C", UBSize::ASCEND910_B2C}, - {"910C1", UBSize::ASCEND910_C1}, - {"910C2", UBSize::ASCEND910_C1}, - {"910C3", UBSize::ASCEND910_C3} - }; - auto it = chipUbSizeList.find(GetChipName(devID)); - if (it != chipUbSizeList.end()) { - return it->second; - } - - throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); +constexpr int GET_BATCH_TIMEOUT = 300; +constexpr int EOS_TIMEOUT = 30; + +constexpr size_t DEFAULT_RANDOM_SEED = 10086; +constexpr int64_t INVALID_KEY_VALUE = -1; +constexpr int32_t INVALID_INDEX_VALUE = -1; +constexpr int ALLTOALLVC_ALIGN = 128; +constexpr int PROFILING_START_BATCH_ID = 100; +constexpr int PROFILING_END_BATCH_ID = 200; +constexpr int MGMT_THREAD_BIND = 48; +constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; +constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; +constexpr float HOT_EMB_CACHE_PCT = static_cast(1. / 3); // hot emb cache percent + +const string COMBINE_HISTORY_NAME = "combine_table_history"; +const string SAVE_SPARSE_PATH_PREFIX = "sparse"; + +using emb_key_t = int64_t; +using emb_cache_key_t = uint64_t; +using freq_num_t = int64_t; +using EmbNameT = std::string; +using KeysT = std::vector; +using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector +using UinqueKeyT = std::tuple>; +using RestoreVecSecT = std::tuple>; +using TensorInfoT = std::tuple>>::iterator>; + +namespace HybridOption { +const unsigned int USE_STATIC = 0x001; +const unsigned int USE_DYNAMIC_EXPANSION = 0x001 << 1; +const unsigned int USE_SUM_SAME_ID_GRADIENTS = 0x001 << 2; +}; // namespace HybridOption + +string GetChipName(int devID); +int GetThreadNumEnv(); + +namespace UBSize { +const int ASCEND910_PREMIUM_A = 262144; +const int ASCEND910_PRO_B = 262144; +const int ASCEND910_B2 = 196608; +const int ASCEND910_B1 = 196608; +const int ASCEND910_B3 = 196608; +const int ASCEND910_B4 = 196608; +const int ASCEND910_C1 = 196608; +const int ASCEND910_C2 = 196608; +const int ASCEND910_C3 = 196608; +const int ASCEND920_A = 196608; +const int ASCEND910_PRO_A = 262144; +const int ASCEND910_B = 262144; +const int ASCEND910_A = 262144; +const int ASCEND910_B2C = 196608; +}; // namespace UBSize + +inline int GetUBSize(int devID) +{ + const std::map chipUbSizeList = { + {"910A", UBSize::ASCEND910_A}, {"910B", UBSize::ASCEND910_B}, {"920A", UBSize::ASCEND920_A}, + {"910B1", UBSize::ASCEND910_B1}, {"910B2", UBSize::ASCEND910_B2}, {"910B3", UBSize::ASCEND910_B3}, + {"910B4", UBSize::ASCEND910_B4}, {"910B2C", UBSize::ASCEND910_B2C}, {"910C1", UBSize::ASCEND910_C1}, + {"910C2", UBSize::ASCEND910_C1}, {"910C3", UBSize::ASCEND910_C3}}; + auto it = chipUbSizeList.find(GetChipName(devID)); + if (it != chipUbSizeList.end()) { + return it->second; } - template - struct Batch { - size_t Size() const - { - return sample.size(); - } + throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); +} - std::string UnParse() const - { - std::string s; - constexpr size_t maxDispLen = 20; - int maxLen = static_cast(std::min(sample.size(), maxDispLen)); - for (int i = 0; i < maxLen; i++) { - s += std::to_string(sample[i]) + " "; - } - return s; - } - - std::vector sample; - std::string name; - size_t batchSize; - int batchId; - int channel = 0; - time_t timestamp { -1 }; - }; - - struct BatchTask { - vector splits; - vector embNames; - size_t batchSize; - int batchQueueId; - int batchId; - int channelId; - time_t timestamp { -1 }; - const void *tensor; - }; - - using EmbBatchT = Batch; - using BatchTaskT = BatchTask; - - struct DDRParam { - vector tmpDataOut; - vector offsetsOut; - DDRParam(vector tmpData, vector offset) - { - tmpDataOut = tmpData; - offsetsOut = offset; - } - }; - - struct RankInfo { - RankInfo() = default; - - RankInfo(int rankId, int deviceId, int localRankSize, int option, const std::vector& ctrlSteps); - RankInfo(int localRankSize, int option, const std::vector& maxStep); - - int rankId {}; - int deviceId {}; - int rankSize {}; - int localRankId {}; - int localRankSize {}; - bool useStatic { false }; - uint32_t option {}; - bool isDDR { false }; - bool isSSDEnabled { false }; - bool useDynamicExpansion {false}; - bool useSumSameIdGradients {true}; - std::vector ctrlSteps; // 包含4个步数: train_steps, eval_steps, save_steps, max_train_steps - }; - - struct EmbBaseInfo { - int batchId; - int channelId; - string name; - }; - - enum TensorIndex : uint32_t { - TENSOR_INDEX_0, - TENSOR_INDEX_1, - TENSOR_INDEX_2, - TENSOR_INDEX_3, - TENSOR_INDEX_4, - TENSOR_INDEX_5, - TENSOR_INDEX_6, - TENSOR_INDEX_7, - TENSOR_INDEX_8 - }; - - enum TupleIndex : uint32_t { - TUPLE_INDEX_0 = 0, - TUPLE_INDEX_1, - TUPLE_INDEX_2, - TUPLE_INDEX_3, - TUPLE_INDEX_4, - TUPLE_INDEX_5, - TUPLE_INDEX_6, - TUPLE_INDEX_7 - }; - - struct RandomInfo { - RandomInfo() = default; - - RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax); - - int start; - int len; - float constantVal; - float randomMin; - float randomMax; - }; - - struct EmbeddingSizeInfo { - size_t embeddingSize = 0; - size_t extendEmbSize = 0; - EmbeddingSizeInfo() = default; - EmbeddingSizeInfo(size_t embSize, size_t extendSize) - : embeddingSize(embSize), extendEmbSize(extendSize) {} - }; - - struct OptimizerInfo { - OptimizerInfo() = default; - OptimizerInfo(std::string name, vector params) - { - optimName = name; - optimParams = std::move(params); - } - - std::string optimName; - vector optimParams; - }; - - struct ThresholdValue { - ThresholdValue() = default; - ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef, bool isSum) - { - tableName = name; - countThreshold = countThre; - timeThreshold = timeThre; - faaeCoefficient = faaeCoef; - isEnableSum = isSum; - } - - EmbNameT tableName { "" }; // embName - int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 - int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 - int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 - bool isEnableSum {true}; // 配置false,该表在准入时,count计数不会累加 - }; - - struct FeatureItemInfo { - FeatureItemInfo() = default; - FeatureItemInfo(uint32_t cnt, time_t lastT) - : count(cnt), lastTime(lastT) - {} - - uint32_t count { 0 }; - time_t lastTime { 0 }; - }; - - using HistoryRecords = absl::flat_hash_map>; - struct AdmitAndEvictData { - HistoryRecords historyRecords; // embName ---> {id, FeatureItemInfo} 映射 - absl::flat_hash_map timestamps; // 用于特征准入&淘汰的时间戳 - }; - - void SetLog(int rank); - - template - string StringFormat(const string& format, Args ... args) +template +struct Batch { + size_t Size() const { - auto size = static_cast(GLOG_MAX_BUF_SIZE); - auto buf = std::make_unique(size); - memset_s(buf.get(), size, 0, size); - int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...); - if (nChar == -1) { - throw invalid_argument("StringFormat failed"); - } - return string(buf.get(), buf.get() + nChar); + return sample.size(); } - // use environment variable GLOG_v to decide if showing debug log. - // default 0, debug message will not display. - // 1 for debug, 2 for trace - constexpr int GLOG_DEBUG = 1; - constexpr int GLOG_TRACE = 2; - - template - std::string VectorToString(const std::vector& vec) + std::string UnParse() const { - constexpr size_t maxDispLen = 20; // max display number - int maxLen = static_cast(std::min(vec.size(), maxDispLen)); - - std::stringstream ss; - ss << "["; - for (size_t i = 0; i < maxLen; ++i) { - ss << vec[i]; - if (i != vec.size() - 1) { - ss << ", "; - } + std::string s; + constexpr size_t maxDispLen = 20; + int maxLen = static_cast(std::min(sample.size(), maxDispLen)); + for (int i = 0; i < maxLen; i++) { + s += std::to_string(sample[i]) + " "; } - ss << "]"; - return ss.str(); + return s; } - std::string FloatPtrToLimitStr(float* ptr, const size_t& prtSize); - - template - std::string MapToString(const std::map& map) + std::vector sample; + std::string name; + size_t batchSize; + int batchId; + int channel = 0; + time_t timestamp{-1}; +}; + +struct BatchTask { + vector splits; + vector embNames; + size_t batchSize; + int batchQueueId; + int batchId; + int channelId; + time_t timestamp{-1}; + const void* tensor; +}; + +using EmbBatchT = Batch; +using BatchTaskT = BatchTask; + +struct DDRParam { + vector tmpDataOut; + vector offsetsOut; + DDRParam(vector tmpData, vector offset) { - std::stringstream ss; - ss << "{"; - for (auto it = map.begin(); it != map.end(); ++it) { - ss << it->first << ": " << it->second; - if (std::next(it) != map.end()) { - ss << ", "; - } - } - ss << "}"; - return ss.str(); + tmpDataOut = tmpData; + offsetsOut = offset; } - - template - std::string MapToString(const absl::flat_hash_map& map) +}; + +struct RankInfo { + RankInfo() = default; + + RankInfo(int rankId, int deviceId, int localRankSize, int option, const std::vector& ctrlSteps); + RankInfo(int localRankSize, int option, const std::vector& maxStep); + + int rankId{}; + int deviceId{}; + int rankSize{}; + int localRankId{}; + int localRankSize{}; + bool useStatic{false}; + uint32_t option{}; + bool isDDR{false}; + bool isSSDEnabled{false}; + bool useDynamicExpansion{false}; + bool useSumSameIdGradients{true}; + std::vector ctrlSteps; // 包含4个步数: train_steps, eval_steps, save_steps, max_train_steps +}; + +struct EmbBaseInfo { + int batchId; + int channelId; + string name; +}; + +enum TensorIndex : uint32_t { + TENSOR_INDEX_0, + TENSOR_INDEX_1, + TENSOR_INDEX_2, + TENSOR_INDEX_3, + TENSOR_INDEX_4, + TENSOR_INDEX_5, + TENSOR_INDEX_6, + TENSOR_INDEX_7, + TENSOR_INDEX_8 +}; + +enum TupleIndex : uint32_t { + TUPLE_INDEX_0 = 0, + TUPLE_INDEX_1, + TUPLE_INDEX_2, + TUPLE_INDEX_3, + TUPLE_INDEX_4, + TUPLE_INDEX_5, + TUPLE_INDEX_6, + TUPLE_INDEX_7 +}; + +struct RandomInfo { + RandomInfo() = default; + + RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax); + + int start; + int len; + float constantVal; + float randomMin; + float randomMax; +}; + +struct EmbeddingSizeInfo { + size_t embeddingSize = 0; + size_t extendEmbSize = 0; + EmbeddingSizeInfo() = default; + EmbeddingSizeInfo(size_t embSize, size_t extendSize) : embeddingSize(embSize), extendEmbSize(extendSize) {} +}; + +struct OptimizerInfo { + OptimizerInfo() = default; + OptimizerInfo(std::string name, vector params) { - std::stringstream ss; - ss << "{"; - for (auto it = map.begin(); it != map.end(); ++it) { - ss << it->first << ": " << it->second; - if (std::next(it) != map.end()) { - ss << ", "; - } - } - ss << "}"; - return ss.str(); + optimName = name; + optimParams = std::move(params); } - void ValidateReadFile(const string& dataDir, size_t datasetSize); + std::string optimName; + vector optimParams; +}; - template - inline Tensor Vec2TensorI32(const std::vector& data) +struct ThresholdValue { + ThresholdValue() = default; + ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef, bool isSum) { - Tensor tmpTensor(tensorflow::DT_INT32, { static_cast(data.size()) }); - auto tmpData = tmpTensor.flat(); - for (int j = 0; j < static_cast(data.size()); j++) { - tmpData(j) = static_cast(data[j]); - } - return tmpTensor; + tableName = name; + countThreshold = countThre; + timeThreshold = timeThre; + faaeCoefficient = faaeCoef; + isEnableSum = isSum; } - template - inline Tensor Vec2TensorI64(const std::vector& data) - { - Tensor tmpTensor(tensorflow::DT_INT64, { static_cast(data.size()) }); - auto tmpData = tmpTensor.flat(); - for (int j = 0; j < static_cast(data.size()); j++) { - tmpData(j) = static_cast(data[j]); + EmbNameT tableName{""}; // embName + int countThreshold{ + -1}; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 + int timeThreshold{-1}; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 + int faaeCoefficient{1}; // 配置后,该表在准入时,count计数会乘以该系数 + bool isEnableSum{true}; // 配置false,该表在准入时,count计数不会累加 +}; + +struct FeatureItemInfo { + FeatureItemInfo() = default; + FeatureItemInfo(uint32_t cnt, time_t lastT) : count(cnt), lastTime(lastT) {} + + uint32_t count{0}; + time_t lastTime{0}; +}; + +using HistoryRecords = absl::flat_hash_map>; +struct AdmitAndEvictData { + HistoryRecords historyRecords; // embName ---> {id, FeatureItemInfo} 映射 + absl::flat_hash_map timestamps; // 用于特征准入&淘汰的时间戳 +}; + +void SetLog(int rank); + +template +string StringFormat(const string& format, Args... args) +{ + auto size = static_cast(GLOG_MAX_BUF_SIZE); + auto buf = std::make_unique(size); + memset_s(buf.get(), size, 0, size); + int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args...); + if (nChar == -1) { + throw invalid_argument("StringFormat failed"); + } + return string(buf.get(), buf.get() + nChar); +} + +// use environment variable GLOG_v to decide if showing debug log. +// default 0, debug message will not display. +// 1 for debug, 2 for trace +constexpr int GLOG_DEBUG = 1; +constexpr int GLOG_TRACE = 2; + +template +std::string VectorToString(const std::vector& vec) +{ + constexpr size_t maxDispLen = 20; // max display number + int maxLen = static_cast(std::min(vec.size(), maxDispLen)); + + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < maxLen; ++i) { + ss << vec[i]; + if (i != vec.size() - 1) { + ss << ", "; } - return tmpTensor; } - - struct EmbInfoParams { - std::string name; - int sendCount; - int embeddingSize; - int extEmbeddingSize; - bool isSave; - bool isGrad; - EmbInfoParams() = default; - - EmbInfoParams(const std::string& name, - int sendCount, - int embeddingSize, - int extEmbeddingSize, - bool isSave, - bool isGrad) - : name(name), - sendCount(sendCount), - embeddingSize(embeddingSize), - extEmbeddingSize(extEmbeddingSize), - isSave(isSave), - isGrad(isGrad) - { + ss << "]"; + return ss.str(); +} + +std::string FloatPtrToLimitStr(float* ptr, const size_t& prtSize); + +template +std::string MapToString(const std::map& map) +{ + std::stringstream ss; + ss << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + ss << it->first << ": " << it->second; + if (std::next(it) != map.end()) { + ss << ", "; } - }; - - struct EmbInfo { - EmbInfo() = default; - - EmbInfo(const EmbInfoParams& embInfoParams, - std::vector vocabsize, - std::vector initializeInfos, - std::vector ssdDataPath) - : name(embInfoParams.name), - sendCount(embInfoParams.sendCount), - embeddingSize(embInfoParams.embeddingSize), - extEmbeddingSize(embInfoParams.extEmbeddingSize), - isSave(embInfoParams.isSave), - isGrad(embInfoParams.isGrad), - devVocabSize(vocabsize[0]), - hostVocabSize(vocabsize[1]), - ssdVocabSize(vocabsize[SSD_SIZE_INDEX]), - initializeInfos(std::move(initializeInfos)), - ssdDataPath(std::move(ssdDataPath)) - { + } + ss << "}"; + return ss.str(); +} + +template +std::string MapToString(const absl::flat_hash_map& map) +{ + std::stringstream ss; + ss << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + ss << it->first << ": " << it->second; + if (std::next(it) != map.end()) { + ss << ", "; } + } + ss << "}"; + return ss.str(); +} + +void ValidateReadFile(const string& dataDir, size_t datasetSize); + +template +inline Tensor Vec2TensorI32(const std::vector& data) +{ + Tensor tmpTensor(tensorflow::DT_INT32, {static_cast(data.size())}); + auto tmpData = tmpTensor.flat(); + for (int j = 0; j < static_cast(data.size()); j++) { + tmpData(j) = static_cast(data[j]); + } + return tmpTensor; +} + +template +inline Tensor Vec2TensorI64(const std::vector& data) +{ + Tensor tmpTensor(tensorflow::DT_INT64, {static_cast(data.size())}); + auto tmpData = tmpTensor.flat(); + for (int j = 0; j < static_cast(data.size()); j++) { + tmpData(j) = static_cast(data[j]); + } + return tmpTensor; +} + +struct EmbInfoParams { + std::string name; + int sendCount; + int embeddingSize; + int extEmbeddingSize; + bool isSave; + bool isGrad; + EmbInfoParams() = default; + + EmbInfoParams(const std::string& name, int sendCount, int embeddingSize, int extEmbeddingSize, bool isSave, + bool isGrad) + : name(name), + sendCount(sendCount), + embeddingSize(embeddingSize), + extEmbeddingSize(extEmbeddingSize), + isSave(isSave), + isGrad(isGrad) + { + } +}; + +struct EmbInfo { + EmbInfo() = default; + + EmbInfo(const EmbInfoParams& embInfoParams, std::vector vocabsize, + std::vector initializeInfos, std::vector ssdDataPath) + : name(embInfoParams.name), + sendCount(embInfoParams.sendCount), + embeddingSize(embInfoParams.embeddingSize), + extEmbeddingSize(embInfoParams.extEmbeddingSize), + isSave(embInfoParams.isSave), + isGrad(embInfoParams.isGrad), + devVocabSize(vocabsize[0]), + hostVocabSize(vocabsize[1]), + ssdVocabSize(vocabsize[SSD_SIZE_INDEX]), + initializeInfos(std::move(initializeInfos)), + ssdDataPath(std::move(ssdDataPath)) + { + } - std::string name; - int sendCount; - int embeddingSize; - int extEmbeddingSize; - bool isSave; - bool isGrad; - size_t devVocabSize; - size_t hostVocabSize; - size_t ssdVocabSize; - std::vector initializeInfos; - std::vector ssdDataPath; - }; - - struct HostEmbTable { - EmbInfo hostEmbInfo; - std::vector> embData; - }; - - struct All2AllInfo { - KeysT keyRecv; - vector scAll; - vector countRecv; - All2AllInfo() = default; - All2AllInfo(KeysT keyRecv, vector scAll, vector countRecv) - : keyRecv(keyRecv), scAll(scAll), countRecv(countRecv) {} - }; - - struct UniqueInfo { - vector restore; - vector hotPos; - All2AllInfo all2AllInfo; - UniqueInfo() = default; - UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo) - : restore(restore), hotPos(hotPos), all2AllInfo(all2AllInfo) {} - }; - - struct KeySendInfo { - KeysT keySend; - vector keyCount; - }; - - using EmbMemT = absl::flat_hash_map; - using OffsetMemT = std::map; - using KeyOffsetMemT = std::map>; - using KeyCountMemT = std::map>; - using Table2ThreshMemT = absl::flat_hash_map; - using trans_serialize_t = uint8_t; - using OffsetMapT = std::map>; - using OffsetT = std::vector; - using AllKeyOffsetMapT = std::map>; - using KeyFreqMemT = unordered_map>; - using EmbLocalTableT = EmbCache::EmbCacheManager; - - enum class CkptFeatureType { - HOST_EMB = 0, - EMB_HASHMAP = 1, - MAX_OFFSET = 2, - KEY_OFFSET_MAP = 3, - FEAT_ADMIT_N_EVICT = 4, - DDR_KEY_FREQ_MAP = 5, - EXCLUDE_DDR_KEY_FREQ_MAP = 6, - KEY_COUNT_MAP = 7, - EMB_LOCAL_TABLE = 8 - }; - - struct CkptData { - EmbMemT* hostEmbs = nullptr; - OffsetMemT maxOffset; - KeyOffsetMemT keyOffsetMap; - OffsetMapT offsetMap; - OffsetMapT* offsetMapPtr = &offsetMap; - KeyCountMemT keyCountMap; - Table2ThreshMemT table2Thresh; - AdmitAndEvictData histRec; - KeyFreqMemT ddrKeyFreqMaps; - KeyFreqMemT excludeDDRKeyFreqMaps; - }; - - struct CkptTransData { - std::vector int64Arr; - std::vector addressArr; - std::vector int32Arr; - std::vector transDataset; // may all use this to transfer data - std::vector attribute; // may need to use other form for attributes - size_t datasetSize; - size_t attributeSize; - }; - - enum class CkptDataType { - EMB_INFO = 0, - EMB_DATA = 1, - EMB_HASHMAP = 2, - DEV_OFFSET = 3, - EMB_CURR_STAT = 4, - NDDR_OFFSET = 5, - NDDR_FEATMAP = 6, - TABLE_2_THRESH = 7, - HIST_REC = 8, - ATTRIBUTE = 9, - DDR_FREQ_MAP = 10, - EXCLUDE_FREQ_MAP = 11, - EVICT_POS = 12, - KEY_COUNT_MAP = 13 - }; - - enum CTRLogLevel { // can't use enum class due to compatibility for AccCTR - DEBUG = 0, - INFO, - WARN, - ERROR, - }; - - static void CTRLog(int level, const char *msg) + std::string name; + int sendCount; + int embeddingSize; + int extEmbeddingSize; + bool isSave; + bool isGrad; + size_t devVocabSize; + size_t hostVocabSize; + size_t ssdVocabSize; + std::vector initializeInfos; + std::vector ssdDataPath; +}; + +struct HostEmbTable { + EmbInfo hostEmbInfo; + std::vector> embData; +}; + +struct All2AllInfo { + KeysT keyRecv; + vector scAll; + vector countRecv; + All2AllInfo() = default; + All2AllInfo(KeysT keyRecv, vector scAll, vector countRecv) + : keyRecv(keyRecv), + scAll(scAll), + countRecv(countRecv) { - switch (level) { - case CTRLogLevel::DEBUG: - LOG_DEBUG(msg); - break; - case CTRLogLevel::INFO: - LOG_INFO(msg); - break; - case CTRLogLevel::WARN: - LOG_WARN(msg); - break; - case CTRLogLevel::ERROR: - LOG_ERROR(msg); - break; - default: - break; - } } +}; + +struct UniqueInfo { + vector restore; + vector hotPos; + All2AllInfo all2AllInfo; + UniqueInfo() = default; + UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo) + : restore(restore), + hotPos(hotPos), + all2AllInfo(all2AllInfo) + { + } +}; + +struct KeySendInfo { + KeysT keySend; + vector keyCount; +}; + +using EmbMemT = absl::flat_hash_map; +using OffsetMemT = std::map; +using KeyOffsetMemT = std::map>; +using KeyCountMemT = std::map>; +using Table2ThreshMemT = absl::flat_hash_map; +using trans_serialize_t = uint8_t; +using OffsetMapT = std::map>; +using OffsetT = std::vector; +using AllKeyOffsetMapT = std::map>; +using KeyFreqMemT = unordered_map>; +using EmbLocalTableT = EmbCache::EmbCacheManager; + +enum class CkptFeatureType { + HOST_EMB = 0, + EMB_HASHMAP = 1, + MAX_OFFSET = 2, + KEY_OFFSET_MAP = 3, + FEAT_ADMIT_N_EVICT = 4, + DDR_KEY_FREQ_MAP = 5, + EXCLUDE_DDR_KEY_FREQ_MAP = 6, + KEY_COUNT_MAP = 7, + EMB_LOCAL_TABLE = 8 +}; + +struct CkptData { + EmbMemT* hostEmbs = nullptr; + OffsetMemT maxOffset; + KeyOffsetMemT keyOffsetMap; + OffsetMapT offsetMap; + OffsetMapT* offsetMapPtr = &offsetMap; + KeyCountMemT keyCountMap; + Table2ThreshMemT table2Thresh; + AdmitAndEvictData histRec; + KeyFreqMemT ddrKeyFreqMaps; + KeyFreqMemT excludeDDRKeyFreqMaps; +}; + +struct CkptTransData { + std::vector int64Arr; + std::vector addressArr; + std::vector int32Arr; + std::vector transDataset; // may all use this to transfer data + std::vector attribute; // may need to use other form for attributes + size_t datasetSize; + size_t attributeSize; +}; + +enum class CkptDataType { + EMB_INFO = 0, + EMB_DATA = 1, + EMB_HASHMAP = 2, + DEV_OFFSET = 3, + EMB_CURR_STAT = 4, + NDDR_OFFSET = 5, + NDDR_FEATMAP = 6, + TABLE_2_THRESH = 7, + HIST_REC = 8, + ATTRIBUTE = 9, + DDR_FREQ_MAP = 10, + EXCLUDE_FREQ_MAP = 11, + EVICT_POS = 12, + KEY_COUNT_MAP = 13 +}; + +static std::string CkptDataTypeName(CkptDataType type) +{ + switch (type) { + case CkptDataType::EMB_INFO: + return "EMB_INFO"; + case CkptDataType::EMB_DATA: + return "EMB_DATA"; + case CkptDataType::EMB_HASHMAP: + return "EMB_HASHMAP"; + case CkptDataType::DEV_OFFSET: + return "DEV_OFFSET"; + case CkptDataType::EMB_CURR_STAT: + return "EMB_CURR_STAT"; + case CkptDataType::NDDR_OFFSET: + return "NDDR_OFFSET"; + case CkptDataType::NDDR_FEATMAP: + return "NDDR_FEATMAP"; + case CkptDataType::TABLE_2_THRESH: + return "TABLE_2_THRESH"; + case CkptDataType::HIST_REC: + return "HIST_REC"; + case CkptDataType::ATTRIBUTE: + return "ATTRIBUTE"; + case CkptDataType::DDR_FREQ_MAP: + return "DDR_FREQ_MAP"; + case CkptDataType::EXCLUDE_FREQ_MAP: + return "EXCLUDE_FREQ_MAP"; + case CkptDataType::EVICT_POS: + return "EVICT_POS"; + case CkptDataType::KEY_COUNT_MAP: + return "KEY_COUNT_MAP"; + default: + return "UNKNOWN"; + } +} + +enum CTRLogLevel { // can't use enum class due to compatibility for AccCTR + DEBUG = 0, + INFO, + WARN, + ERROR, +}; + +static void CTRLog(int level, const char* msg) +{ + switch (level) { + case CTRLogLevel::DEBUG: + LOG_DEBUG(msg); + break; + case CTRLogLevel::INFO: + LOG_INFO(msg); + break; + case CTRLogLevel::WARN: + LOG_WARN(msg); + break; + case CTRLogLevel::ERROR: + LOG_ERROR(msg); + break; + default: + break; + } +} - ostream& operator<<(ostream& ss, MxRec::CkptDataType type); - bool CheckFilePermission(const string& filePath); +ostream& operator<<(ostream& ss, MxRec::CkptDataType type); +bool CheckFilePermission(const string& filePath); - int GetStepFromPath(const string& loadPath); -} // end namespace MxRec +int GetStepFromPath(const string& loadPath); +} // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " #define STAT_INFO "[StatInfo] " #ifdef GTEST - #define GTEST_PRIVATE public +#define GTEST_PRIVATE public #else - #define GTEST_PRIVATE private +#define GTEST_PRIVATE private #endif #endif diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 2eee8531..98fca961 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -403,7 +403,7 @@ namespace MxRec { out(0) = batchId; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - LOG_DEBUG(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); + LOG_DEBUG(StringFormat("skip excess batch after %d/%d", batchId, maxStep)); return; } } @@ -658,4 +658,4 @@ namespace tensorflow { .SetIsStateful() .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_KERNEL_BUILDER(Name("LazyAdam").Device(DEVICE_CPU), MxRec::CustOps); -} \ No newline at end of file +} -- Gitee From f6340067b4f2475582615b8cf77cf38baeabffe3 Mon Sep 17 00:00:00 2001 From: steepcurve Date: Mon, 1 Jul 2024 08:28:14 +0000 Subject: [PATCH 05/37] fix: cleancode. Signed-off-by: steepcurve --- src/core/checkpoint/checkpoint.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index bc7501bb..469e209e 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -90,17 +90,18 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) void Checkpoint::SetDataHandler(const vector& featureTypes) { - map> setCkptMap{{CkptFeatureType::FEAT_ADMIT_N_EVICT, - [this] { - dataHandlers.push_back(make_unique()); - }}, - {CkptFeatureType::DDR_KEY_FREQ_MAP, - [this] { - dataHandlers.push_back(make_unique()); - }}, - {CkptFeatureType::KEY_COUNT_MAP, [this] { - dataHandlers.push_back(make_unique()); - }}}; + auto featAdmitNEvictHandler = [this] { + dataHandlers.push_back(make_unique()); + }; + auto ddrKeyFreqMapHandler = [this] { + dataHandlers.push_back(make_unique()); + }; + auto keyCountMapHandler = [this] { + dataHandlers.push_back(make_unique()); + }; + map> setCkptMap{{CkptFeatureType::FEAT_ADMIT_N_EVICT, featAdmitNEvictHandler}, + {CkptFeatureType::DDR_KEY_FREQ_MAP, ddrKeyFreqMapHandler}, + {CkptFeatureType::KEY_COUNT_MAP, keyCountMapHandler}}; for (const auto& featureType : featureTypes) { setCkptMap.at(featureType)(); @@ -341,7 +342,8 @@ void Checkpoint::ReadStream(CkptTransData& transData, const string& dataDir, Ckp if (readBytesNum != datasetSize) { throw runtime_error(StringFormat("Error: Load data failed. data type: %s. " "Expected to read %d bytes, but actually read %d bytes to file %s.", - CkptDataTypeName(dataType).c_str(), datasetSize, readBytesNum, dataDir.c_str())); + CkptDataTypeName(dataType).c_str(), datasetSize, readBytesNum, + dataDir.c_str())); } } -- Gitee From c145cc40abe4a77ab850169ebd09576d27261c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 14:31:25 +0800 Subject: [PATCH 06/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8DSS?= =?UTF-8?q?D=E6=A8=A1=E5=BC=8F=E7=B2=BE=E5=BA=A6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cache_manager/cache_manager.cpp | 8 +-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 59 +++++++------------ src/core/hybrid_mgmt/hybrid_mgmt.h | 7 +-- 3 files changed, 28 insertions(+), 46 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index 8a6187a1..3017cf8e 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -72,16 +72,16 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, return H_THREAD_NUM_ERROR; } - uint32_t reserve = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; - if (!offsetMappers[embCacheInfo.tableName].Initialize(reserve, embCacheInfo.maxCacheSize)) { + uint32_t reserveDevice = embCacheInfo.maxCacheSize / VOCAB_CACHE_RATIO; + if (!offsetMappers[embCacheInfo.tableName].Initialize(reserveDevice, embCacheInfo.maxCacheSize)) { offsetMappers[embCacheInfo.tableName].UnInitialize(); offsetMappers.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; } EmbPoolParam embPoolParam{prefillBufferSize, refillThreadNum}; - - if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserve, initializerInfos, embPoolParam)) { + uint32_t reserveHost = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; + if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserveHost, initializerInfos, embPoolParam)) { offsetMappers.erase(embCacheInfo.tableName); embTables.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index fda54d9d..9e195419 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -918,28 +918,27 @@ void HybridMgmt::SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo EmbeddingMgmt::Instance()->SetOptimizerInfo(embName, optimInfo); } -void HybridMgmt::LookUpAddrs(const string &embName, int extEmbeddingSize) +// L3Storage +void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo &info) { - int id = 0; - uint64_t memSize = extEmbeddingSize * sizeof(float); + uint64_t memSize = info.extEmbeddingSize * sizeof(float); const std::string hbmSwapKeyQueName = "HBMSwapKeyQue"; const std::string ddrSwapKeyQueName = "DDRSwapKeyQue"; - auto lookUpFunc = [this, memSize, embName, id]( + auto lookUpFunc = [this, memSize, info]( std::map>> &fromQue, std::map>> &toQue, const string &swapStr, const string &fromQueName ) { - std::vector keys = fromQue[embName + swapStr].WaitAndPop(); + std::vector keys = fromQue[info.name + swapStr].WaitAndPop(); if (!isRunning) { return; } std::vector addrs; TimeCost lookupAddrsTC; - int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + int rc = embCache->EmbeddingLookupAddrs(info.name, keys, addrs); if (rc != H_OK) { - lookupAddrSuccess = false; LOG_ERROR("lookUpAddrs, table:{}, fromQue: {}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}", - embName, fromQueName, swapStr, keys.size(), addrs.size(), id); + info.name, fromQueName, swapStr, keys.size(), addrs.size(), info.batchId); throw runtime_error("EmbeddingLookupAddrs failed! error code:" + std::to_string(rc)); } if (&fromQue == &DDRSwapKeyQue && swapStr == SWAP_OUT_STR) { @@ -947,31 +946,28 @@ void HybridMgmt::LookUpAddrs(const string &embName, int extEmbeddingSize) auto *newAddr = (float*)malloc(memSize); rc = memcpy_s(newAddr, memSize, addr, memSize); if (rc != 0) { - lookupAddrSuccess = false; throw runtime_error("memcpy_s failed! error code:" + std::to_string(rc)); } addr = newAddr; } - rc = embCache->EmbeddingRemove(embName, keys); + rc = embCache->EmbeddingRemove(info.name, keys); if (rc != H_OK) { - lookupAddrSuccess = false; throw runtime_error("EmbeddingRemove failed! error code:" + std::to_string(rc)); } } LOG_DEBUG("table:{}, fromQue:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", - embName, fromQueName, swapStr, keys.size(), addrs.size(), id, lookupAddrsTC.ElapsedMS()); - toQue[embName + swapStr].Pushv(addrs); + info.name, fromQueName, swapStr, keys.size(), addrs.size(), info.batchId, lookupAddrsTC.ElapsedMS()); + toQue[info.name + swapStr].Pushv(addrs); }; - while (isRunning && lookupAddrSuccess) { - lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_OUT_STR, ddrSwapKeyQueName); - lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_IN_STR, ddrSwapKeyQueName); - lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_IN_STR, hbmSwapKeyQueName); - lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_OUT_STR, hbmSwapKeyQueName); - id++; - lookUpSwapInAddrsPushId[embName]++; - } + + lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_OUT_STR, ddrSwapKeyQueName); + lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_IN_STR, ddrSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_IN_STR, hbmSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_OUT_STR, hbmSwapKeyQueName); + lookUpSwapInAddrsPushId[info.name]++; } +// DDR void HybridMgmt::LookUpSwapAddrs(const string &embName, const string &swapStr) { int id = 0; @@ -1146,6 +1142,9 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons .extEmbeddingSize=embInfo.extEmbeddingSize, .name=embInfo.name }; + // host swap out need to be executed before lookup + LookUpAndRemoveAddrs(info); + float* ptr = nullptr; vector swapOutAddrs; int64_t dims0 = 0; @@ -1226,8 +1225,6 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa HandleEndBatchCase(info, swapInPos); - CheckLookupAddrSuccessL3Storage(); - if (info.channelId == TRAIN_CHANNEL_ID) { alreadyTrainOnce = true; } @@ -1295,8 +1292,6 @@ void HybridMgmt::InitDataPipelineForL3Storage(const string &embName, int extEmbe DDRSwapAddrsQue[embName + SWAP_IN_STR]; // 初始化lookup线程 - lookUpThreads.emplace_back( - std::async(std::launch::async, [=] { LookUpAddrs(embName, extEmbeddingSize); })); LOG_DEBUG("data pipeline for L3Storage init"); } @@ -1321,8 +1316,9 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) embInfo.name, embInfo.hostVocabSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); EmbCache::EmbCacheInfo embCacheInfo(embInfo.name, embInfo.hostVocabSize, embInfo.embeddingSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); + size_t prefill = std::max(embInfo.hostVocabSize/10, 2 * embInfo.devVocabSize); int ret = embCache->CreateCacheForTable( - embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, embInfo.hostVocabSize, EMBEDDING_THREAD_NUM); + embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, EMBEDDING_THREAD_NUM); if (ret != H_OK) { throw runtime_error(embInfo.name + "create cache for table failed, error code: " + std::to_string(ret)); } @@ -1355,9 +1351,6 @@ void HybridMgmt::JoinEmbeddingCacheThread() for (auto& t : EmbeddingReceiveAndUpdateThreadPool) { t.join(); } - for (auto& t : lookUpThreads) { - t.wait(); - } for (auto& t : lookUpSwapInAddrsThreads) { t.wait(); } @@ -2175,14 +2168,6 @@ void HybridMgmt::CheckLookupAddrSuccessDDR() } } -void HybridMgmt::CheckLookupAddrSuccessL3Storage() -{ - if (!lookupAddrSuccess) { - for (auto& t : lookUpThreads) { - t.get(); - } - } -} void HybridMgmt::GetSwapPairsAndKey2Offset(const EmbBaseInfo &info, vector &uniqueKeys, pair, vector> &swapInKoPair, diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 83299da3..0654be91 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -157,7 +157,6 @@ namespace MxRec { std::vector EmbeddingReceiveAndUpdateThreadPool; std::vector> lookUpSwapOutAddrsThreads; std::vector> lookUpSwapInAddrsThreads; - std::vector> lookUpThreads; std::map>> HBMSwapKeyQue; std::map>> SwapOut2L3StorageKeyQue; @@ -190,9 +189,9 @@ namespace MxRec { void EvictL3StorageKeys(const string& embName, const vector& keys) const; - void LookUpAddrs(const string &embName, int extEmbeddingSize); + void LookUpAndRemoveAddrs(const EmbTaskInfo &info); // L3Storage, synchronous - void LookUpSwapAddrs(const std::string &embName, const std::string &swapStr); + void LookUpSwapAddrs(const std::string &embName, const std::string &swapStr); // DDR, asynchronous void EmbeddingTask(); @@ -312,8 +311,6 @@ namespace MxRec { void CheckLookupAddrSuccessDDR(); - void CheckLookupAddrSuccessL3Storage(); - void GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector &uniqueKeys, std::pair, vector>& swapInKoPair, std::pair, vector>& swapOutKoPair); -- Gitee From 12b6f9f608f43a9d3e0f981f531dc2c72021478f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 14:48:57 +0800 Subject: [PATCH 07/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8DSS?= =?UTF-8?q?D=E6=A8=A1=E5=BC=8F=E7=B2=BE=E5=BA=A6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cache_manager/cache_manager.cpp | 7 +- src/core/emb_table/embedding_ddr.cpp | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 65 +++++++++---------- src/core/hybrid_mgmt/hybrid_mgmt.h | 7 +- src/core/l3_storage/cache_manager.h | 4 +- 5 files changed, 41 insertions(+), 44 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index 3017cf8e..c6cc1bbd 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -72,16 +72,15 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, return H_THREAD_NUM_ERROR; } - uint32_t reserveDevice = embCacheInfo.maxCacheSize / VOCAB_CACHE_RATIO; - if (!offsetMappers[embCacheInfo.tableName].Initialize(reserveDevice, embCacheInfo.maxCacheSize)) { + uint32_t reserve = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; + if (!offsetMappers[embCacheInfo.tableName].Initialize(reserve, embCacheInfo.maxCacheSize)) { offsetMappers[embCacheInfo.tableName].UnInitialize(); offsetMappers.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; } EmbPoolParam embPoolParam{prefillBufferSize, refillThreadNum}; - uint32_t reserveHost = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; - if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserveHost, initializerInfos, embPoolParam)) { + if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserve, initializerInfos, embPoolParam)) { offsetMappers.erase(embCacheInfo.tableName); embTables.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp index ca706c73..151e372c 100644 --- a/src/core/emb_table/embedding_ddr.cpp +++ b/src/core/emb_table/embedding_ddr.cpp @@ -235,7 +235,7 @@ void EmbeddingDDR::SyncLatestEmbedding() } } else { // 在保存之前先更新ddr和ssd的embedding - SwapOutInfo info; + HBMSwapOutInfo info; cacheManager_->ProcessSwapOutKeys(name, swapOutKeys, info); vector swapOutAddrs; rc = embCache->EmbeddingLookupAddrs(name, info.swapOutDDRKeys, swapOutAddrs); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 9e195419..01beb358 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -962,8 +962,8 @@ void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo &info) lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_OUT_STR, ddrSwapKeyQueName); lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_IN_STR, ddrSwapKeyQueName); - lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_IN_STR, hbmSwapKeyQueName); - lookUpFunc(HBMSwapKeyQue, tableToQueueLookup, SWAP_OUT_STR, hbmSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, HBMSwapAddrsQue, SWAP_IN_STR, hbmSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, HBMSwapAddrsQue, SWAP_OUT_STR, hbmSwapKeyQueName); lookUpSwapInAddrsPushId[info.name]++; } @@ -987,7 +987,7 @@ void HybridMgmt::LookUpSwapAddrs(const string &embName, const string &swapStr) LOG_DEBUG( "table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", embName, swapStr, keys.size(), addrs.size(), id, lookupAddrsTC.ElapsedMS()); - tableToQueueLookup[swapName].Pushv(addrs); + HBMSwapAddrsQue[swapName].Pushv(addrs); if (swapStr==SWAP_IN_STR) { lookUpSwapInAddrsPushId[embName]++; LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", @@ -1258,8 +1258,8 @@ void HybridMgmt::InitDataPipelineForDDR(const string &embName) // 初始化公共队列 HBMSwapKeyQue[embName+SWAP_IN_STR]; HBMSwapKeyQue[embName+SWAP_OUT_STR]; - tableToQueueLookup[embName+SWAP_IN_STR]; - tableToQueueLookup[embName+SWAP_OUT_STR]; + HBMSwapAddrsQue[embName + SWAP_IN_STR]; + HBMSwapAddrsQue[embName + SWAP_OUT_STR]; // 初始化lookup线程 lookUpSwapInAddrsPushId[embName]; // 此处初始化,避免多线程竞争导致计数错误 @@ -1276,13 +1276,13 @@ void HybridMgmt::InitDataPipelineForL3Storage(const string &embName, int extEmbe // 初始化公共队列 HBMSwapKeyQue[embName+SWAP_IN_STR]; HBMSwapKeyQue[embName+SWAP_OUT_STR]; - tableToQueueLookup[embName+SWAP_IN_STR]; - tableToQueueLookup[embName+SWAP_OUT_STR]; + HBMSwapAddrsQue[embName + SWAP_IN_STR]; + HBMSwapAddrsQue[embName + SWAP_OUT_STR]; HBMSwapKeyQue[embName + ADDR_STR]; - SwapOut2L3StorageKeyQue[embName + SWAP_IN_STR]; - SwapOut2L3StorageKeyQue[embName + ADDR_STR]; - SwapOut2L3StorageKeyQue[embName + SWAP_OUT_STR]; + HBMSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; + HBMSwapKeyForL3StorageQue[embName + ADDR_STR]; + HBMSwapKeyForL3StorageQue[embName + SWAP_OUT_STR]; DDRSwapKeyQue[embName + SWAP_OUT_STR]; DDRSwapKeyQue[embName + SWAP_IN_STR]; @@ -1316,9 +1316,8 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) embInfo.name, embInfo.hostVocabSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); EmbCache::EmbCacheInfo embCacheInfo(embInfo.name, embInfo.hostVocabSize, embInfo.embeddingSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); - size_t prefill = std::max(embInfo.hostVocabSize/10, 2 * embInfo.devVocabSize); int ret = embCache->CreateCacheForTable( - embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, EMBEDDING_THREAD_NUM); + embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, embInfo.hostVocabSize, EMBEDDING_THREAD_NUM); if (ret != H_OK) { throw runtime_error(embInfo.name + "create cache for table failed, error code: " + std::to_string(ret)); } @@ -1327,13 +1326,13 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { - for (auto &p : tableToQueueLookup) { + for (auto &p : HBMSwapAddrsQue) { p.second.DestroyQueue(); } for (auto &p : HBMSwapKeyQue) { p.second.DestroyQueue(); } - for (auto &p : SwapOut2L3StorageKeyQue) { + for (auto &p : HBMSwapKeyForL3StorageQue) { p.second.DestroyQueue(); } for (auto &p : DDRSwapKeyQue) { @@ -1439,7 +1438,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto } TimeCost EmbeddingRecvTC = TimeCost(); - swapOutAddrs = tableToQueueLookup[info.name+SWAP_OUT_STR].WaitAndPop(); + swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return false; } @@ -1617,7 +1616,7 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo &info, float *&ptr, } TimeCost EmbeddingRecvTC = TimeCost(); // finish时会pop空vector,因此需要额外判定isRunning - swapOutAddrs = tableToQueueLookup[info.name+SWAP_OUT_STR].WaitAndPop(); + swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return false; } @@ -1681,8 +1680,8 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float *embPtr // L3Storage更新 TimeCost L3StorageUpdateTC = TimeCost(); - std::vector swapOutL3StorageAddrOffs = SwapOut2L3StorageKeyQue[info.name + ADDR_STR].WaitAndPop(); - std::vector swapOutL3StorageKeys = SwapOut2L3StorageKeyQue[info.name + SWAP_OUT_STR].WaitAndPop(); + std::vector swapOutL3StorageAddrOffs = HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].WaitAndPop(); + std::vector swapOutL3StorageKeys = HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return; } @@ -1874,8 +1873,8 @@ void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKoPair.first); // HBM->L3Storage - SwapOut2L3StorageKeyQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutL3StorageKeys); - SwapOut2L3StorageKeyQue[info.name + ADDR_STR].Pushv(emptySwapOutL3StorageAddrOff); + HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutL3StorageKeys); + HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(emptySwapOutL3StorageAddrOff); } void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, @@ -1888,18 +1887,18 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, LOG_DEBUG("ProcessSwapInKeysTC(ms):{} ", ProcessSwapInKeysTC.ElapsedMS()); TimeCost ProcessSwapOutKeysTC; - SwapOutInfo swapInfo; - cacheManager->ProcessSwapOutKeys(info.name, swapOutKeys, swapInfo); + HBMSwapOutInfo hbmSwapInfo; + cacheManager->ProcessSwapOutKeys(info.name, swapOutKeys, hbmSwapInfo); LOG_DEBUG("ProcessSwapOutKeysTC(ms):{} ", ProcessSwapOutKeysTC.ElapsedMS()); LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapOutDDRKeys:{}, swapOutDDRAddrOffs:{}, " - "swapOutL3StorageKeys:{}, swapOutL3StorageAddrOff:{}", - info.name, info.batchId, info.channelId, swapInfo.swapOutDDRKeys.size(), - swapInfo.swapOutDDRAddrOffs.size(), swapInfo.swapOutL3StorageKeys.size(), - swapInfo.swapOutL3StorageAddrOffs.size()); - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDRToL3StorageKeys:{}, L3StorageToDDRKeys:{}", + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swap out, HBM2DDR Keys:{}, HBM2DDR AddrOffs:{}, " + "HBM2L3Storage Keys:{}, HBM2L3Storage AddrOff:{}", + info.name, info.batchId, info.channelId, hbmSwapInfo.swapOutDDRKeys.size(), + hbmSwapInfo.swapOutDDRAddrOffs.size(), hbmSwapInfo.swapOutL3StorageKeys.size(), + hbmSwapInfo.swapOutL3StorageAddrOffs.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDR2L3Storage Keys:{}, L3Storage2DDR Keys:{}", info.name, info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); auto DDRToL3StorageKeysForL3S = DDRToL3StorageKeys; @@ -1912,18 +1911,18 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, DDRSwapKeyForL3StorageQue[info.name + SWAP_IN_STR].Pushv(L3StorageToDDRKeysForL3S); // HBM<->DDR - HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(swapInfo.swapOutDDRKeys); - HBMSwapKeyQue[info.name + ADDR_STR].Pushv(swapInfo.swapOutDDRAddrOffs); + HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutDDRKeys); + HBMSwapKeyQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutDDRAddrOffs); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); // HBM->L3Storage - SwapOut2L3StorageKeyQue[info.name + SWAP_OUT_STR].Pushv(swapInfo.swapOutL3StorageKeys); - SwapOut2L3StorageKeyQue[info.name + ADDR_STR].Pushv(swapInfo.swapOutL3StorageAddrOffs); + HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutL3StorageKeys); + HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); } bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo &info, vector &h2dEmb) { - std::vector swapInAddrs = tableToQueueLookup[info.name+SWAP_IN_STR].WaitAndPop(); + std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); if (!isRunning) { return false; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 0654be91..f5897861 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -159,21 +159,20 @@ namespace MxRec { std::vector> lookUpSwapInAddrsThreads; std::map>> HBMSwapKeyQue; - std::map>> SwapOut2L3StorageKeyQue; + std::map>> HBMSwapKeyForL3StorageQue; std::map>> DDRSwapKeyQue; std::map>> DDRSwapKeyForL3StorageQue; + std::map>> HBMSwapAddrsQue; std::map>> DDRSwapAddrsQue; std::mutex evictMut; std::map> trainKeysSet; - const string SWAP_IN_STR = "SwapIn"; const string SWAP_OUT_STR = "SwapOut"; - const string ADDR_STR = "Addr"; + const string ADDR_STR = "Addr"; ock::ctr::EmbCacheManagerPtr embCache = nullptr; - std::map>> tableToQueueLookup; std::map> lastSwapInPosMap {}; std::map>> trainTestSwitchInfoStore {}; std::atomic lookupAddrSuccess {true}; diff --git a/src/core/l3_storage/cache_manager.h b/src/core/l3_storage/cache_manager.h index 3f5b0a22..79335788 100644 --- a/src/core/l3_storage/cache_manager.h +++ b/src/core/l3_storage/cache_manager.h @@ -40,7 +40,7 @@ namespace MxRec { absl::flat_hash_map& keyOffsetMap; }; - struct SwapOutInfo { + struct HBMSwapOutInfo { vector swapOutDDRKeys; vector swapOutDDRAddrOffs; vector swapOutL3StorageKeys; @@ -89,7 +89,7 @@ namespace MxRec { void PutKey(const string& embTableName, const emb_key_t& key, RecordType type); void ProcessSwapOutKeys(const string& tableName, const vector& swapOutKeys, - SwapOutInfo& info); + HBMSwapOutInfo& info); void ProcessSwapInKeys(const string& tableName, const vector& swapInKeys, vector& DDRToL3StorageKeys, -- Gitee From ddefbd55694512acf3a2213c94cd491ed3058077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 15:03:10 +0800 Subject: [PATCH 08/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8Dpr?= =?UTF-8?q?efill=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 3 ++- src/core/utils/common.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index fda54d9d..e4e30f64 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1321,8 +1321,9 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) embInfo.name, embInfo.hostVocabSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); EmbCache::EmbCacheInfo embCacheInfo(embInfo.name, embInfo.hostVocabSize, embInfo.embeddingSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); + size_t prefill = std::max(embInfo.hostVocabSize/HOST_TO_PREFILL_RATIO, embInfo.devVocabSize); int ret = embCache->CreateCacheForTable( - embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, embInfo.hostVocabSize, EMBEDDING_THREAD_NUM); + embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, EMBEDDING_THREAD_NUM); if (ret != H_OK) { throw runtime_error(embInfo.name + "create cache for table failed, error code: " + std::to_string(ret)); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 9a39e7ac..c020bbc5 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -69,6 +69,7 @@ namespace MxRec { constexpr int SSD_SIZE_INDEX = 2; constexpr int MAX_FILE_NUM = 1000; constexpr int EMBEDDING_THREAD_NUM = 2; + constexpr int HOST_TO_PREFILL_RATIO = 10; // for GLOG struct GlogConfig { static bool gStatOn; -- Gitee From afa1b548e1d91dd6bd5f8b9a2b4e61af04019440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 15:09:34 +0800 Subject: [PATCH 09/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8Dre?= =?UTF-8?q?serve=E5=9C=A8dev=E4=BE=A7=E6=B5=AA=E8=B4=B9=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../embedding_cache/cache_manager/cache_manager.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index 8a6187a1..a9fac9f6 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -64,7 +64,8 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, } if ((prefillBufferSize < 1) || (prefillBufferSize > embCacheInfo.vocabSize)) { - ExternalLogger::PrintLog(LogLevel::ERROR, "prefillBufferSize has to be between [1, hostVocabSize]"); + ExternalLogger::PrintLog(LogLevel::ERROR, "prefillBufferSize: " + std::to_string(prefillBufferSize) + + "has to be between [1, hostVocabSize]"); return H_PREFILL_BUFFER_SIZE_INVALID; } @@ -72,16 +73,16 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, return H_THREAD_NUM_ERROR; } - uint32_t reserve = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; - if (!offsetMappers[embCacheInfo.tableName].Initialize(reserve, embCacheInfo.maxCacheSize)) { + uint32_t reserveDevice = embCacheInfo.maxCacheSize / VOCAB_CACHE_RATIO; + if (!offsetMappers[embCacheInfo.tableName].Initialize(reserveDevice, embCacheInfo.maxCacheSize)) { offsetMappers[embCacheInfo.tableName].UnInitialize(); offsetMappers.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; } EmbPoolParam embPoolParam{prefillBufferSize, refillThreadNum}; - - if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserve, initializerInfos, embPoolParam)) { + uint32_t reserveHost = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; + if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserveHost, initializerInfos, embPoolParam)) { offsetMappers.erase(embCacheInfo.tableName); embTables.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; -- Gitee From fe3982f8995eb03e428f03866c941b8569f86785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 15:19:09 +0800 Subject: [PATCH 10/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8Dre?= =?UTF-8?q?serve=E5=9C=A8dev=E4=BE=A7=E6=B5=AA=E8=B4=B9=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/l3_storage/cache_manager.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/l3_storage/cache_manager.cpp b/src/core/l3_storage/cache_manager.cpp index 75d73b2d..ee3d7bc5 100644 --- a/src/core/l3_storage/cache_manager.cpp +++ b/src/core/l3_storage/cache_manager.cpp @@ -181,7 +181,7 @@ int64_t CacheManager::GetTableUsage(const string& tableName) } void CacheManager::ProcessSwapOutKeys(const string& tableName, const vector& swapOutKeys, - SwapOutInfo& info) + HBMSwapOutInfo& info) { auto& swapOutDDRKeys = info.swapOutDDRKeys; auto& swapOutDDRAddrOffs = info.swapOutDDRAddrOffs; -- Gitee From 486e3e9f7159de52b78bb4313d027e7d48d6918b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 15:37:52 +0800 Subject: [PATCH 11/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8DSS?= =?UTF-8?q?D=E7=B2=BE=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 01beb358..6969c27d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1142,8 +1142,6 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons .extEmbeddingSize=embInfo.extEmbeddingSize, .name=embInfo.name }; - // host swap out need to be executed before lookup - LookUpAndRemoveAddrs(info); float* ptr = nullptr; vector swapOutAddrs; @@ -1614,6 +1612,9 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo &info, float *&ptr, if (!isRunning) { return false; } + // DDR swap out key need to be removed + LookUpAndRemoveAddrs(info); + TimeCost EmbeddingRecvTC = TimeCost(); // finish时会pop空vector,因此需要额外判定isRunning swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); -- Gitee From 474ca30dc51783a59d16803324e25cf8aa2d0395 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 3 Jul 2024 16:04:13 +0800 Subject: [PATCH 12/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=89=A9=E5=AE=B9=E5=86=97=E4=BD=99=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E4=BF=AE=E5=A4=8Dtest?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tests/emb_table/embedding_ddr_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tests/emb_table/embedding_ddr_test.cpp b/src/tests/emb_table/embedding_ddr_test.cpp index 60ec5af6..097167f6 100644 --- a/src/tests/emb_table/embedding_ddr_test.cpp +++ b/src/tests/emb_table/embedding_ddr_test.cpp @@ -22,7 +22,6 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_ddr.h" using namespace std; -- Gitee From e6c501c12489146c7887c52519f739922367f780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Thu, 4 Jul 2024 09:26:28 +0800 Subject: [PATCH 13/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91issure?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/embedding_cache/cache_manager/cache_manager.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index a9fac9f6..76e90abc 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -64,8 +64,8 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, } if ((prefillBufferSize < 1) || (prefillBufferSize > embCacheInfo.vocabSize)) { - ExternalLogger::PrintLog(LogLevel::ERROR, "prefillBufferSize: " + std::to_string(prefillBufferSize) + - "has to be between [1, hostVocabSize]"); + ExternalLogger::PrintLog(LogLevel::ERROR, "PrefillBufferSize: " + std::to_string(prefillBufferSize) + + " has to be between [1, hostVocabSize]."); return H_PREFILL_BUFFER_SIZE_INVALID; } -- Gitee From 637ef26d445886b98aee01f64b13a40081c60c04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Thu, 4 Jul 2024 09:54:03 +0800 Subject: [PATCH 14/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91delete?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tests/emb_table/embedding_mgmt_test.cpp | 1 - src/tests/emb_table/embedding_static_test.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/src/tests/emb_table/embedding_mgmt_test.cpp b/src/tests/emb_table/embedding_mgmt_test.cpp index 055cf5c5..81a354bf 100644 --- a/src/tests/emb_table/embedding_mgmt_test.cpp +++ b/src/tests/emb_table/embedding_mgmt_test.cpp @@ -22,7 +22,6 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_mgmt.h" using namespace std; diff --git a/src/tests/emb_table/embedding_static_test.cpp b/src/tests/emb_table/embedding_static_test.cpp index a08569b3..5d1f0ab7 100644 --- a/src/tests/emb_table/embedding_static_test.cpp +++ b/src/tests/emb_table/embedding_static_test.cpp @@ -21,7 +21,6 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_static.h" using namespace std; -- Gitee From c9a321e908b0290b60b1776b595e38b37d2698ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Thu, 4 Jul 2024 16:08:15 +0800 Subject: [PATCH 15/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=89=A9=E5=AE=B9?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=EF=BC=8Ctable.capacity=E5=87=BA?= =?UTF-8?q?=E7=8E=B0=E5=81=B6=E5=8F=91=E8=B4=9F=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/emb_table/embedding_ddr.cpp | 2 +- src/core/emb_table/embedding_dynamic.cpp | 4 ++-- src/core/emb_table/embedding_table.h | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp index ca706c73..b9ca70dc 100644 --- a/src/core/emb_table/embedding_ddr.cpp +++ b/src/core/emb_table/embedding_ddr.cpp @@ -45,7 +45,7 @@ void EmbeddingDDR::Key2Offset(std::vector& splitKey, int channel) int64_t EmbeddingDDR::capacity() const { - return capacity_; + return capacity_.load(); } /* diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index 7f8cd7e5..703d08ad 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -77,7 +77,7 @@ void EmbeddingDynamic::Key2Offset(std::vector& keys, int channel) int64_t EmbeddingDynamic::capacity() const { - return capacity_; + return capacity_.load(); } int64_t EmbeddingDynamic::GetEmptyEmbeddingAddress() @@ -103,7 +103,7 @@ void EmbeddingDynamic::MallocEmbeddingBlock(int embNum) float *embAddr = static_cast(block) + (i * extEmbSize_); embeddingList_.push_back(embAddr); } - capacity_ += embNum; + capacity_.fetch_add(embNum); } void EmbeddingDynamic::RandomInit(void* addr, size_t embNum) diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h index cbf15a7a..3396a8a0 100644 --- a/src/core/emb_table/embedding_table.h +++ b/src/core/emb_table/embedding_table.h @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and #ifndef MX_REC_EMBEDDING_TABLE_H #define MX_REC_EMBEDDING_TABLE_H +#include #include #include #include @@ -113,7 +114,7 @@ protected: size_t embSize_; size_t extEmbSize_; int seed_; - int64_t capacity_; + std::atomic capacity_; size_t rankId_; size_t rankSize_; vector loadOffset; -- Gitee From a55ed2e12febf7542083651f8ed95ffbd3894e90 Mon Sep 17 00:00:00 2001 From: steepcurve Date: Fri, 5 Jul 2024 08:01:45 +0000 Subject: [PATCH 16/37] fix: conflict in src/core/utils/common.h. Signed-off-by: steepcurve --- src/core/utils/common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index f8ff4565..8c7528f4 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -70,6 +70,7 @@ constexpr size_t MAX_VOCABULARY_SIZE = 1e10; constexpr int SSD_SIZE_INDEX = 2; constexpr int MAX_FILE_NUM = 1000; constexpr int EMBEDDING_THREAD_NUM = 2; +constexpr int HOST_TO_PREFILL_RATIO = 10; // for GLOG struct GlogConfig { static bool gStatOn; -- Gitee From 03a664699c4aa5fd7bd6e8353327181dcd677aec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Fri, 5 Jul 2024 16:38:57 +0800 Subject: [PATCH 17/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=89=A9=E5=AE=B9?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=EF=BC=8Ctable.capacity=E5=87=BA?= =?UTF-8?q?=E7=8E=B0=E5=81=B6=E5=8F=91=E8=B4=9F=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 591 +++++++++++++-------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 408 +++++++++--------- 2 files changed, 485 insertions(+), 514 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 61064fb4..3eb99685 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -15,23 +15,23 @@ See the License for the specific language governing permissions and #include "hybrid_mgmt.h" +#include + #include +#include #include -#include #include #include -#include +#include "checkpoint/checkpoint.h" +#include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" -#include "utils/time_cost.h" -#include "utils/logger.h" -#include "utils/common.h" -#include "checkpoint/checkpoint.h" -#include "key_process/key_process.h" #include "key_process/feature_admit_and_evict.h" -#include "emb_table/embedding_mgmt.h" - +#include "key_process/key_process.h" +#include "utils/common.h" +#include "utils/logger.h" +#include "utils/time_cost.h" using namespace MxRec; using namespace std; @@ -98,8 +98,8 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, InitRankInfo(rankInfo, embInfos); GlogConfig::gStatOn = GlobalEnv::statOn; - LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", - rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); + LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", rankInfo.localRankSize, + rankInfo.localRankId, rankInfo.rankId); mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; @@ -134,15 +134,15 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, Start(); } - for (const auto& info: embInfos) { - LOG_INFO(MGMT + "table:{}, vocab size dev+host:{}+{}, send count:{}", - info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); + for (const auto& info : embInfos) { + LOG_INFO(MGMT + "table:{}, vocab size dev+host:{}+{}, send count:{}", info.name, info.devVocabSize, + info.hostVocabSize, info.sendCount); } LOG_INFO(MGMT + "end initialize, rankId:{}, isDDR:{}, " "step[train_interval, eval_interval, save_interval, max_train_step]:[{}, {}, {}, {}]", - rankInfo.rankId, rankInfo.isDDR, - rankInfo.ctrlSteps.at(TRAIN_CHANNEL_ID), rankInfo.ctrlSteps.at(EVAL_CHANNEL_ID), - rankInfo.ctrlSteps.at(SAVE_STEP_INDEX), rankInfo.ctrlSteps.at(MAX_TRAIN_STEP_INDEX)); + rankInfo.rankId, rankInfo.isDDR, rankInfo.ctrlSteps.at(TRAIN_CHANNEL_ID), + rankInfo.ctrlSteps.at(EVAL_CHANNEL_ID), rankInfo.ctrlSteps.at(SAVE_STEP_INDEX), + rankInfo.ctrlSteps.at(MAX_TRAIN_STEP_INDEX)); #endif isInitialized = true; @@ -225,7 +225,7 @@ bool HybridMgmt::Load(const string& loadPath, vector warmStartTables) if (warmStartTables.size() == 0) { EmbeddingMgmt::Instance()->Load(loadPath, trainKeysSet); } else { - for (auto& tableName: warmStartTables) { + for (auto& tableName : warmStartTables) { EmbeddingMgmt::Instance()->Load(tableName, loadPath, trainKeysSet); } } @@ -373,17 +373,17 @@ void HybridMgmt::Start() void HybridMgmt::StartThreadForHBM() { #ifndef GTEST - auto parseKeysTaskForHBMTrain = [this]() { - TrainTask(TaskType::HBM); - LOG_INFO("parseKeysTaskForHBMTrain done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); - - auto parseKeysTaskForHBMEval = [this]() { - EvalTask(TaskType::HBM); - LOG_INFO("parseKeysTaskForHBMEval done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); + auto parseKeysTaskForHBMTrain = [this]() { + TrainTask(TaskType::HBM); + LOG_INFO("parseKeysTaskForHBMTrain done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); + + auto parseKeysTaskForHBMEval = [this]() { + EvalTask(TaskType::HBM); + LOG_INFO("parseKeysTaskForHBMEval done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); #endif } @@ -424,7 +424,7 @@ void HybridMgmt::Destroy() isRunning = false; mutexDestroy = true; - for (const auto& embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { cvLastUpdateFinishMap[embInfo.name][index].notify_all(); cvLastLookUpFinishMap[embInfo.name][index].notify_all(); @@ -456,7 +456,9 @@ void HybridMgmt::Destroy() // 停止预处理 KEY_PROCESS_INSTANCE->Destroy(); // stop embCache, even if the host emb is still allocating - if (embCache != nullptr) { embCache->Destroy(); } + if (embCache != nullptr) { + embCache->Destroy(); + } LOG_DEBUG(MGMT + "Destroy hybrid_mgmt module end."); } @@ -493,12 +495,10 @@ void HybridMgmt::EvalTask(TaskType type) do { hybridMgmtBlock->CheckAndSetBlock(channelId); if (hybridMgmtBlock->GetBlockStatus(channelId)) { - LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", - evalBatchId, hybridMgmtBlock->IsNeedWaitSave()); + LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, + hybridMgmtBlock->IsNeedWaitSave()); std::unique_lock checkSaveLocker(saveMutex); - cvCheckSave.wait(checkSaveLocker, [this] { - return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; - }); + cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); LOG_DEBUG("wake TrainTask"); hybridMgmtBlock->DoBlock(channelId); @@ -513,29 +513,28 @@ void HybridMgmt::EvalTask(TaskType type) #endif } -void HybridMgmt::SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo &info, - const unique_ptr> &infoVecs, bool isGrad) const +void HybridMgmt::SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo& info, const unique_ptr>& infoVecs, + bool isGrad) const { TimeCost sendUniqueKeysSyncTC; - LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}", - info.channelId, info.batchId, info.name, isGrad); + LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}", info.channelId, info.batchId, + info.name, isGrad); if (isGrad) { hdTransfer->Send(TransferChannel::UNIQKEYS, {infoVecs->back()}, info.channelId, info.name); } infoVecs->pop_back(); - LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", - info.channelId, info.batchId, sendUniqueKeysSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", info.channelId, info.batchId, + sendUniqueKeysSyncTC.ElapsedMS()); TimeCost sendUniqueRestoreVecSyncTC; if (isGrad) { hdTransfer->Send(TransferChannel::RESTORE_SECOND, {infoVecs->back()}, info.channelId, info.name); } infoVecs->pop_back(); - LOG_DEBUG("channelId:{} batchId:{}, sendUniqueRestoreVecSyncTC(ms):{}", - info.channelId, info.batchId, sendUniqueRestoreVecSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, sendUniqueRestoreVecSyncTC(ms):{}", info.channelId, info.batchId, + sendUniqueRestoreVecSyncTC.ElapsedMS()); } - /// 当前处理的batch是否是最后一个batch,涵盖train切换eval、save场景 /// \param batchId 已处理的batch数 /// \return @@ -544,13 +543,12 @@ bool HybridMgmt::IsTrainEndBatch(int batchId) const // case 1:需要切eval // case 2:需要save时,补发pos后被阻塞,等待save完成,避免embCache状态发送变化 // batchId是从0开始的,所以要+1对上step - bool isNeedSwitchToEval = mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] != -1 && - (batchId + 1) % mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] == 0; - bool isNeedSave = mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != -1 && - mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != 0 && + bool isNeedSwitchToEval = + mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] != -1 && (batchId + 1) % mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] == 0; + bool isNeedSave = mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != -1 && mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != 0 && (batchId + 1) % mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] == 0; - LOG_DEBUG("mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID]:{}, batchId:{}", - mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID], batchId); + LOG_DEBUG("mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID]:{}, batchId:{}", mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID], + batchId); LOG_DEBUG("isNeedSwitchToEval:{}, isNeedSave:{}", isNeedSwitchToEval, isNeedSave); return isNeedSwitchToEval || isNeedSave; } @@ -570,26 +568,23 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId, TaskType type) #ifndef GTEST LOG_INFO(MGMT + "channelId:{} batchId:{}, ParseKeys start.", channelId, batchId); TimeCost parseKeyTC; - bool remainBatch = true; // 是否从通道获取了数据 + bool remainBatch = true; // 是否从通道获取了数据 vector parseKeyThreadPool; for (const auto& embInfo : mgmtEmbInfo) { - EmbBaseInfo info = {.batchId=batchId, .channelId=channelId, .name=embInfo.name}; + EmbBaseInfo info = {.batchId = batchId, .channelId = channelId, .name = embInfo.name}; switch (type) { case TaskType::HBM: - parseKeyThreadPool.emplace_back([this, info, &remainBatch, embInfo]() { - ProcessEmbInfoHBM(info, remainBatch, embInfo.isGrad); - }); + parseKeyThreadPool.emplace_back( + [this, info, &remainBatch, embInfo]() { ProcessEmbInfoHBM(info, remainBatch, embInfo.isGrad); }); break; case TaskType::DDR: if (!isL3StorageEnabled) { - parseKeyThreadPool.emplace_back([this, info, &remainBatch, embInfo]() { - ProcessEmbInfoDDR(info, remainBatch); - }); + parseKeyThreadPool.emplace_back( + [this, info, &remainBatch, embInfo]() { ProcessEmbInfoDDR(info, remainBatch); }); } else { - parseKeyThreadPool.emplace_back([this, info, &remainBatch, embInfo]() { - ProcessEmbInfoL3Storage(info, remainBatch); - }); + parseKeyThreadPool.emplace_back( + [this, info, &remainBatch, embInfo]() { ProcessEmbInfoL3Storage(info, remainBatch); }); } break; default: @@ -608,14 +603,14 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId, TaskType type) if (!isRunning) { return false; } - LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end, parseKeyTC(ms):{}", - channelId, batchId, parseKeyTC.ElapsedMS()); + LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end, parseKeyTC(ms):{}", channelId, batchId, + parseKeyTC.ElapsedMS()); batchId++; #endif return true; } -void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo &info, bool& remainBatchOut, bool isGrad) +void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut, bool isGrad) { TimeCost parseKeysTc; LOG_DEBUG("ProcessEmbInfoHBM table:{}, batchId:{}, channel:{}", info.name, info.batchId, info.channelId); @@ -628,13 +623,13 @@ void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo &info, bool& remainBatchOut return; } if (infoVecs == nullptr) { - LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", - info.name, info.channelId, info.batchId); + LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, + info.batchId); remainBatchOut = false; return; } - LOG_DEBUG("table:{}, channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end", - info.name, info.channelId, info.batchId); + LOG_DEBUG("table:{}, channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end", info.name, info.channelId, + info.batchId); // 动态shape场景下,获取all2all向量(通信量矩阵) SendAll2AllVec(info, remainBatchOut); @@ -644,10 +639,10 @@ void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo &info, bool& remainBatchOut // 发送查询向量 TimeCost sendLookupSyncTC; - hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, info.channelId, info.name); + hdTransfer->Send(TransferChannel::LOOKUP, {infoVecs->back()}, info.channelId, info.name); infoVecs->pop_back(); - LOG_DEBUG("table:{}, channelId:{} batchId:{}, sendLookupSyncTC(ms):{}", - info.name, info.channelId, info.batchId, sendLookupSyncTC.ElapsedMS()); + LOG_DEBUG("table:{}, channelId:{} batchId:{}, sendLookupSyncTC(ms):{}", info.name, info.channelId, info.batchId, + sendLookupSyncTC.ElapsedMS()); // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 if (mgmtRankInfo.useSumSameIdGradients && info.channelId == TRAIN_CHANNEL_ID) { @@ -657,18 +652,17 @@ void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo &info, bool& remainBatchOut // 发送恢复向量 TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); - LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", - info.name, sendRestoreSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); + LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", info.name, + sendRestoreSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); - LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.", - info.name, info.channelId, info.batchId, info.name); + LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.", info.name, + info.channelId, info.batchId, info.name); if (info.channelId == TRAIN_CHANNEL_ID) { alreadyTrainOnce = true; } } - /// 构造训练所需的各种向量数据 /// \param embName 表名 /// \param batchId 已处理的batch数 @@ -680,7 +674,7 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut TimeCost getAndSendTensorsTC; LOG_DEBUG("ProcessEmbInfoDDR start, table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); - if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { + if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { HandleReachMaxStepCase(info, remainBatchOut); return; } @@ -718,10 +712,10 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut SendGlobalUniqueVec(info, uniqueKeys, restoreVecSec); TimeCost swapProcessTC; - auto &swapInPos = swapInKoPair.second; - auto &swapOutPos = swapOutKoPair.second; + auto& swapInPos = swapInKoPair.second; + auto& swapOutPos = swapOutKoPair.second; auto lastSwapInPos = lastSwapInPosMap[info.name]; - lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 + lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 auto isNeedReturn = HandleSpecialProcessStatusDDR(info, getAndSendTensorsTC, swapInKoPair, swapOutKoPair); if (isNeedReturn) { @@ -827,7 +821,6 @@ void HybridMgmt::EvictL3StorageKeys(const string& embName, const vectorEvictL3StorageEmbedding(embName, keys); } - /// 通过pyBind在python侧调用,通知hybridMgmt上层即将进行图的执行,需要进行唤醒 /// \param channelID 通道id /// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 @@ -919,16 +912,14 @@ void HybridMgmt::SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo } // L3Storage -void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo &info) +void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo& info) { uint64_t memSize = info.extEmbeddingSize * sizeof(float); const std::string hbmSwapKeyQueName = "HBMSwapKeyQue"; const std::string ddrSwapKeyQueName = "DDRSwapKeyQue"; - auto lookUpFunc = [this, memSize, info]( - std::map>> &fromQue, - std::map>> &toQue, - const string &swapStr, const string &fromQueName - ) { + auto lookUpFunc = [this, memSize, info](std::map>>& fromQue, + std::map>>& toQue, + const string& swapStr, const string& fromQueName) { std::vector keys = fromQue[info.name + swapStr].WaitAndPop(); if (!isRunning) { return; @@ -942,8 +933,8 @@ void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo &info) throw runtime_error("EmbeddingLookupAddrs failed! error code:" + std::to_string(rc)); } if (&fromQue == &DDRSwapKeyQue && swapStr == SWAP_OUT_STR) { - for (auto &addr : addrs) { - auto *newAddr = (float*)malloc(memSize); + for (auto& addr : addrs) { + auto* newAddr = (float*)malloc(memSize); rc = memcpy_s(newAddr, memSize, addr, memSize); if (rc != 0) { throw runtime_error("memcpy_s failed! error code:" + std::to_string(rc)); @@ -968,7 +959,7 @@ void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo &info) } // DDR -void HybridMgmt::LookUpSwapAddrs(const string &embName, const string &swapStr) +void HybridMgmt::LookUpSwapAddrs(const string& embName, const string& swapStr) { int id = 0; std::string swapName = embName + swapStr; @@ -977,21 +968,20 @@ void HybridMgmt::LookUpSwapAddrs(const string &embName, const string &swapStr) if (!isRunning) { return; } - vector addrs; + vector addrs; TimeCost lookupAddrsTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { lookupAddrSuccess = false; throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); } - LOG_DEBUG( - "table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", - embName, swapStr, keys.size(), addrs.size(), id, lookupAddrsTC.ElapsedMS()); + LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", embName, + swapStr, keys.size(), addrs.size(), id, lookupAddrsTC.ElapsedMS()); HBMSwapAddrsQue[swapName].Pushv(addrs); - if (swapStr==SWAP_IN_STR) { + if (swapStr == SWAP_IN_STR) { lookUpSwapInAddrsPushId[embName]++; - LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", - embName, id, lookUpSwapInAddrsPushId[embName]); + LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", embName, id, + lookUpSwapInAddrsPushId[embName]); } id++; } @@ -1006,15 +996,15 @@ void HybridMgmt::FetchDeviceEmb() if (mgmtRankInfo.isDDR) { // DDR模式保存host的emb表以及hashmap LOG_DEBUG(MGMT + "start host side save: ddr mode"); - for (const auto &embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { std::vector> koVec; embCache->ExportDeviceKeyOffsetPairs(embInfo.name, koVec); std::vector swapOutPos; - for (const auto &p : koVec) { + for (const auto& p : koVec) { swapOutPos.push_back(p.second); } - vector swapTensor; + vector swapTensor; swapTensor.emplace_back(Vec2TensorI32(swapOutPos)); swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); auto swapOutLen = swapTensor.back().flat(); @@ -1030,7 +1020,7 @@ void HybridMgmt::FetchDeviceEmb() // 这里就是新增的embedding处理线程 void HybridMgmt::EmbeddingTask() { - for (const auto& embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { lastUpdateFinishStepMap[embInfo.name] = 0; lastLookUpFinishStepMap[embInfo.name] = 0; lastSendFinishStepMap[embInfo.name] = 0; @@ -1045,7 +1035,7 @@ void HybridMgmt::EmbeddingTask() void HybridMgmt::MultiThreadEmbHDTransWrap() { for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { - for (const auto& embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { CreateEmbeddingLookUpAndSendThread(index, embInfo); CreateEmbeddingReceiveAndUpdateThread(index, embInfo); } @@ -1059,13 +1049,11 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo cvNotifyIndex = index + 1; } - EmbTaskInfo info = { - .batchId=batchId, - .threadIdx=index, - .cvNotifyIndex=cvNotifyIndex, - .extEmbeddingSize=embInfo.extEmbeddingSize, - .name=embInfo.name - }; + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; vector h2dEmb; auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); @@ -1084,13 +1072,11 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI cvNotifyIndex = index + 1; } - EmbTaskInfo info = { - .batchId=batchId, - .threadIdx=index, - .cvNotifyIndex=cvNotifyIndex, - .extEmbeddingSize=embInfo.extEmbeddingSize, - .name=embInfo.name - }; + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; float* ptr = nullptr; vector swapOutAddrs; @@ -1110,13 +1096,11 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E cvNotifyIndex = index + 1; } - EmbTaskInfo info = { - .batchId=batchId, - .threadIdx=index, - .cvNotifyIndex=cvNotifyIndex, - .extEmbeddingSize=embInfo.extEmbeddingSize, - .name=embInfo.name - }; + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; vector h2dEmb; auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); @@ -1135,13 +1119,11 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons cvNotifyIndex = index + 1; } - EmbTaskInfo info = { - .batchId=batchId, - .threadIdx=index, - .cvNotifyIndex=cvNotifyIndex, - .extEmbeddingSize=embInfo.extEmbeddingSize, - .name=embInfo.name - }; + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; float* ptr = nullptr; vector swapOutAddrs; @@ -1151,7 +1133,6 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); } - /// 构造训练所需的各种向量数据 /// \param embName 表名 /// \param batchId 已处理的batch数 @@ -1164,7 +1145,7 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa TimeCost getAndSendTensorsTC; LOG_DEBUG("ProcessEmbInfoL3Storage table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); - if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { + if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { HandleReachMaxStepCase(info, remainBatchOut); return; } @@ -1202,12 +1183,12 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa SendGlobalUniqueVec(info, uniqueKeys, restoreVecSec); TimeCost swapProcessTC; - auto &swapInKeys = swapInKoPair.first; - auto &swapInPos = swapInKoPair.second; - auto &swapOutKeys = swapOutKoPair.first; - auto &swapOutPos = swapOutKoPair.second; + auto& swapInKeys = swapInKoPair.first; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; auto lastSwapInPos = lastSwapInPosMap[info.name]; - lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 + lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 auto isNeedReturn = HandleSpecialProcessStatusL3Storage(info, getAndSendTensorsTC, swapInKoPair, swapOutKoPair); if (isNeedReturn) { @@ -1232,18 +1213,17 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa #endif } -void HybridMgmt::SendTensorForSwap(const EmbBaseInfo& info, - const vector &swapInPosUint, - const vector &swapOutPosUint) +void HybridMgmt::SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, + const vector& swapOutPosUint) { #ifndef GTEST vector swapTensor; swapTensor.emplace_back(Vec2TensorI32(swapInPosUint)); swapTensor.emplace_back(Vec2TensorI32(swapOutPosUint)); - swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); + swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); auto swapInLen = swapTensor.back().flat(); swapInLen(0) = swapInPosUint.size(); - swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); + swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); auto swapOutLen = swapTensor.back().flat(); swapOutLen(0) = swapOutPosUint.size(); @@ -1251,11 +1231,11 @@ void HybridMgmt::SendTensorForSwap(const EmbBaseInfo& info, #endif } -void HybridMgmt::InitDataPipelineForDDR(const string &embName) +void HybridMgmt::InitDataPipelineForDDR(const string& embName) { // 初始化公共队列 - HBMSwapKeyQue[embName+SWAP_IN_STR]; - HBMSwapKeyQue[embName+SWAP_OUT_STR]; + HBMSwapKeyQue[embName + SWAP_IN_STR]; + HBMSwapKeyQue[embName + SWAP_OUT_STR]; HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; @@ -1269,11 +1249,11 @@ void HybridMgmt::InitDataPipelineForDDR(const string &embName) LOG_DEBUG("data pipeline for ddr init"); } -void HybridMgmt::InitDataPipelineForL3Storage(const string &embName, int extEmbeddingSize) +void HybridMgmt::InitDataPipelineForL3Storage(const string& embName, int extEmbeddingSize) { // 初始化公共队列 - HBMSwapKeyQue[embName+SWAP_IN_STR]; - HBMSwapKeyQue[embName+SWAP_OUT_STR]; + HBMSwapKeyQue[embName + SWAP_IN_STR]; + HBMSwapKeyQue[embName + SWAP_OUT_STR]; HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; @@ -1300,7 +1280,7 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) EmbeddingMgmt::Instance()->SetEmbCacheForEmbTable(embCache); EmbeddingMgmt::Instance()->SetHDTransferForEmbTable(hdTransfer); - for (auto embInfo: embInfos) { + for (auto embInfo : embInfos) { if (isL3StorageEnabled) { InitDataPipelineForL3Storage(embInfo.name, embInfo.extEmbeddingSize); } else { @@ -1314,9 +1294,9 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) embInfo.name, embInfo.hostVocabSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); EmbCache::EmbCacheInfo embCacheInfo(embInfo.name, embInfo.hostVocabSize, embInfo.embeddingSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); - size_t prefill = std::max(embInfo.hostVocabSize/HOST_TO_PREFILL_RATIO, embInfo.devVocabSize); - int ret = embCache->CreateCacheForTable( - embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, EMBEDDING_THREAD_NUM); + size_t prefill = std::max(embInfo.hostVocabSize / HOST_TO_PREFILL_RATIO, embInfo.devVocabSize); + int ret = embCache->CreateCacheForTable(embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, + EMBEDDING_THREAD_NUM); if (ret != H_OK) { throw runtime_error(embInfo.name + "create cache for table failed, error code: " + std::to_string(ret)); } @@ -1325,22 +1305,22 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { - for (auto &p : HBMSwapAddrsQue) { + for (auto& p : HBMSwapAddrsQue) { p.second.DestroyQueue(); } - for (auto &p : HBMSwapKeyQue) { + for (auto& p : HBMSwapKeyQue) { p.second.DestroyQueue(); } - for (auto &p : HBMSwapKeyForL3StorageQue) { + for (auto& p : HBMSwapKeyForL3StorageQue) { p.second.DestroyQueue(); } - for (auto &p : DDRSwapKeyQue) { + for (auto& p : DDRSwapKeyQue) { p.second.DestroyQueue(); } - for (auto &p : DDRSwapKeyForL3StorageQue) { + for (auto& p : DDRSwapKeyForL3StorageQue) { p.second.DestroyQueue(); } - for (auto &p : DDRSwapAddrsQue) { + for (auto& p : DDRSwapAddrsQue) { p.second.DestroyQueue(); } for (auto& t : EmbeddingLookUpAndSendThreadPool) { @@ -1363,25 +1343,26 @@ void HybridMgmt::HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBat // 2. 如果切换过: // a. eval场景跑完,不用send,外面自然退出 // b. save场景,能触发,说明期望的train step已经跑完(由IsTrainEndBatch判定send),当前step也不用send - LOG_DEBUG("table:{}, batchId:{}, ProcessStatus:{}, reach maxTrainStep", - info.name, info.batchId, ProcessStatus2Str(ProcessStatus::NORMAL)); + LOG_DEBUG("table:{}, batchId:{}, ProcessStatus:{}, reach maxTrainStep", info.name, info.batchId, + ProcessStatus2Str(ProcessStatus::NORMAL)); if (specialProcessStatus[info.name] == ProcessStatus::NORMAL) { LOG_DEBUG("table:{}, batchId:{}, need send swap tensor" - " for last step to finish train", info.name, info.batchId); + " for last step to finish train", + info.name, info.batchId); std::vector emptySwapOutPos; SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); } else { - LOG_DEBUG("table:{}, batchId:{}, switch from eval or save, unnecessary to send emptySwapOutPos", - info.name, info.batchId); + LOG_DEBUG("table:{}, batchId:{}, switch from eval or save, unnecessary to send emptySwapOutPos", info.name, + info.batchId); } remainBatchOut = false; hybridMgmtBlock->SetBlockStatus(TRAIN_CHANNEL_ID, true); } -void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool &remainBatchOut) +void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) { - LOG_INFO("GetUniqueKeys get eos, handle final batch for current epoch, table:{}, channel:{}, batchId:{}", - info.name, info.channelId, info.batchId); + LOG_INFO("GetUniqueKeys get eos, handle final batch for current epoch, table:{}, channel:{}, batchId:{}", info.name, + info.channelId, info.batchId); bool sendAllChannel = false; if (info.channelId == TRAIN_CHANNEL_ID) { vector emptySwapOutPos; @@ -1418,8 +1399,8 @@ void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool &remainBatchOut) // train+eval+train场景 // 交给train的ProcessEmbInfoDDR启动最后n-1步eval // train发送pos让eval step n-1跑完,到eval step n时各channel遇到eos后结束(train、eval共享的channel除外) - LOG_INFO("GetUniqueKeys get eos, skip send pos for eval channel, table:{}, batchId:{}", - info.name, info.batchId); + LOG_INFO("GetUniqueKeys get eos, skip send pos for eval channel, table:{}, batchId:{}", info.name, + info.batchId); } } KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); @@ -1454,22 +1435,22 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (aclData == nullptr) { throw runtime_error("Acl get tensor data from dataset failed."); } - ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); // 判断拿到的embedding个数是否与swapOutKeys个数相等 size_t dimNum = acltdtGetDimNumFromItem(aclData); int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - LOG_DEBUG("table:{}, batchId:{}, dims[0]:{}, swapOutAddrs size:{}", - info.name, info.batchId, dims[0], swapOutAddrs.size()); + LOG_DEBUG("table:{}, batchId:{}, dims[0]:{}, swapOutAddrs size:{}", info.name, info.batchId, dims[0], + swapOutAddrs.size()); if (dims[0] != static_cast(swapOutAddrs.size())) { throw runtime_error("data dims[0] != swapOutKeys.size()"); } } - LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", - info.name, info.batchId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); + LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name, info.batchId, info.threadIdx, + EmbeddingRecvTC.ElapsedMS()); lastRecvFinishStepMap[info.name]++; cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); @@ -1486,8 +1467,8 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr uint64_t memSize = info.extEmbeddingSize * sizeof(float); uint64_t extEmbeddingSize = info.extEmbeddingSize; -# pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ - shared(swapOutAddrs, embPtr, extEmbeddingSize, memSize) +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ + shared(swapOutAddrs, embPtr, extEmbeddingSize, memSize) for (uint64_t i = 0; i < swapOutAddrs.size(); i++) { auto rc = memcpy_s(swapOutAddrs[i], memSize, embPtr + i * extEmbeddingSize, memSize); if (rc != 0) { @@ -1497,18 +1478,19 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { string sample; if (!swapOutAddrs.empty()) { - sample = FloatPtrToLimitStr(swapOutAddrs.front(), info.extEmbeddingSize); // print first element + sample = FloatPtrToLimitStr(swapOutAddrs.front(), info.extEmbeddingSize); // print first element } LOG_DEBUG("table:{}, batchId:{}, thread:{}, receive d2hEmb, ext emb:{}, emb size:{}, emb samples:{}, " - "EmbeddingUpdateTC(ms):{}", info.name.c_str(), info.batchId, info.threadIdx, - info.extEmbeddingSize, swapOutAddrs.size(), sample, EmbeddingUpdateTC.ElapsedMS()); + "EmbeddingUpdateTC(ms):{}", + info.name.c_str(), info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), sample, + EmbeddingUpdateTC.ElapsedMS()); } lastUpdateFinishStepMap[info.name]++; cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); } -bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo &info, vector& h2dEmb) +bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1537,7 +1519,7 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo &info, vector& h2d return true; } -void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo &info, vector& h2dEmb) +void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb) { std::unique_lock lastSendFinishLocker(lastSendFinishMutexMap[info.name][info.threadIdx]); cvLastSendFinishMap[info.name][info.threadIdx].wait(lastSendFinishLocker, [info, this] { @@ -1547,10 +1529,11 @@ void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo &info, vector& h2dEm hdTransfer->Send(TransferChannel::H2D, h2dEmb, TRAIN_CHANNEL_ID, info.name, info.batchId); lastSendFinishStepMap[info.name]++; cvLastSendFinishMap[info.name][info.cvNotifyIndex].notify_all(); - LOG_DEBUG("table:{}, batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", - info.name, info.batchId, info.threadIdx, SendTC.ElapsedMS()); + LOG_DEBUG("table:{}, batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name, info.batchId, info.threadIdx, + SendTC.ElapsedMS()); - // 对于end of sequence场景,key process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel + // 对于end of sequence场景,key + // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel hybridMgmtBlock->h2dNextBatchId[info.name]++; LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); } @@ -1603,8 +1586,8 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& }); } -bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo &info, float *&ptr, - vector &swapOutAddrs, int64_t& dims0) +bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, + int64_t& dims0) { std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { @@ -1635,26 +1618,26 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo &info, float *&ptr, if (aclData == nullptr) { throw runtime_error("Acl get tensor data from dataset failed."); } - ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); // 判断拿到的embedding个数是否与swapOutKeys个数相等 size_t dimNum = acltdtGetDimNumFromItem(aclData); int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - LOG_DEBUG("table:{}, batchId:{}, recv d2h, dims[0]:{}, swapOutAddrs.size:{}", - info.name, info.batchId, dims[0], swapOutAddrs.size()); + LOG_DEBUG("table:{}, batchId:{}, recv d2h, dims[0]:{}, swapOutAddrs.size:{}", info.name, info.batchId, dims[0], + swapOutAddrs.size()); dims0 = dims[0]; } - LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", - info.name.c_str(), info.batchId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); + LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name.c_str(), info.batchId, + info.threadIdx, EmbeddingRecvTC.ElapsedMS()); lastRecvFinishStepMap[info.name]++; cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); return true; } -void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float *embPtr, - vector& swapOutAddrs, int64_t& dims0) +void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, + int64_t& dims0) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1669,16 +1652,16 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float *embPtr uint64_t memSize = info.extEmbeddingSize * sizeof(float); uint64_t extEmbeddingSize = info.extEmbeddingSize; // DDR更新 -# pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ - shared(swapOutAddrs, swapOutDDRAddrOffs, embPtr, extEmbeddingSize, memSize) +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ + shared(swapOutAddrs, swapOutDDRAddrOffs, embPtr, extEmbeddingSize, memSize) for (uint64_t i = 0; i < swapOutAddrs.size(); i++) { auto rc = memcpy_s(swapOutAddrs[i], memSize, embPtr + swapOutDDRAddrOffs[i] * extEmbeddingSize, memSize); if (rc != 0) { throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); } } - LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingUpdateTC(ms):{}", - info.name.c_str(), info.batchId, info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); + LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingUpdateTC(ms):{}", info.name.c_str(), info.batchId, + info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); // L3Storage更新 TimeCost L3StorageUpdateTC = TimeCost(); @@ -1693,8 +1676,8 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float *embPtr } cacheManager->UpdateL3StorageEmb(info.name, embPtr, extEmbeddingSize, swapOutL3StorageKeys, swapOutL3StorageAddrOffs); - LOG_DEBUG("table:{}, batchId:{}, thread{}, L3StorageUpdateTC(ms):{}", - info.name.c_str(), info.batchId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); + LOG_DEBUG("table:{}, batchId:{}, thread{}, L3StorageUpdateTC(ms):{}", info.name.c_str(), info.batchId, + info.threadIdx, L3StorageUpdateTC.ElapsedMS()); lastUpdateFinishStepMap[info.name]++; cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); @@ -1726,8 +1709,8 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorTransferDDR2L3Storage(info.name, info.extEmbeddingSize, DDR2L3StorageKeys, DDR2L3StorageAddrs); - LOG_DEBUG("table:{}, thread:{}, transferDDR2L3StorageTC(ms):{}", - info.name.c_str(), info.threadIdx, transferDDR2L3StorageTC.ElapsedMS()); + LOG_DEBUG("table:{}, thread:{}, transferDDR2L3StorageTC(ms):{}", info.name.c_str(), info.threadIdx, + transferDDR2L3StorageTC.ElapsedMS()); TimeCost fetchL3StorageEmb2DDRTC = TimeCost(); // swapInKeys中在L3Storage的挪到DDR @@ -1737,8 +1720,8 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorFetchL3StorageEmb2DDR(info.name, info.extEmbeddingSize, L3Storage2DDRKeys, L3Storage2DDRAddrs); - LOG_DEBUG("table:{}, thread:{}, fetchL3StorageEmb2DDRTC(ms):{}", - info.name.c_str(), info.threadIdx, fetchL3StorageEmb2DDRTC.ElapsedMS()); + LOG_DEBUG("table:{}, thread:{}, fetchL3StorageEmb2DDRTC(ms):{}", info.name.c_str(), info.threadIdx, + fetchL3StorageEmb2DDRTC.ElapsedMS()); bool isSuccess = BuildH2DEmbedding(info, h2dEmb); if (!isSuccess) { @@ -1763,12 +1746,13 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& cvLastSendFinishMap[info.name][info.cvNotifyIndex].notify_all(); LOG_DEBUG("table:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name.c_str(), info.threadIdx, SendTC.ElapsedMS()); - // 对于end of sequence场景,key process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel + // 对于end of sequence场景,key + // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel hybridMgmtBlock->h2dNextBatchId[info.name]++; LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); } -void HybridMgmt::HandleEosCaseHBM(const string &embName, int batchId, int channelId, bool &remainBatchOut) +void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut) { bool sendAllChannel = false; if (channelId == EVAL_CHANNEL_ID) { @@ -1813,19 +1797,19 @@ void HybridMgmt::HandleFirstBatchCaseDDR(const EmbBaseInfo& info, pair, vector>& swapOutKoPair) { TimeCost swapProcessTC; - auto &swapInKeys = swapInKoPair.first; - auto &swapInPos = swapInKoPair.second; - auto &swapOutKeys = swapOutKoPair.first; - auto &swapOutPos = swapOutKoPair.second; + auto& swapInKeys = swapInKoPair.first; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; vector emptySwapOutKeys; - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); trainTestSwitchInfoStore[info.name] = {swapOutKeys, swapOutPos}; LOG_DEBUG("handle first batch case, delay sending swapInPos, table:{}", info.name); - LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKeys.size(), emptySwapOutKeys.size()); + LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, + info.batchId, info.channelId, swapInKeys.size(), emptySwapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); } @@ -1836,8 +1820,8 @@ void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, { // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos vector emptySwapOutKeys; - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); trainTestSwitchInfoStore[info.name] = {swapOutKoPair.first, swapOutKoPair.second}; TimeCost ProcessSwapInKeysTC = TimeCost(); @@ -1851,14 +1835,14 @@ void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, vector emptySwapOutL3StorageKeys; vector emptySwapOutL3StorageAddrOff; - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKoPair.first.size(), swapOutKoPair.first.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), swapOutKoPair.first.size()); LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapOutDDRKeys.size:{}, swapOutDDRAddrOffs.size:{}, " "swapOutL3StorageKeys.size:{}, swapOutL3StorageAddrOff.size:{}", info.name, info.batchId, info.channelId, emptySwapOutDDRKeys.size(), emptySwapOutDDRAddrOffs.size(), emptySwapOutL3StorageKeys.size(), emptySwapOutL3StorageAddrOff.size()); - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDRToL3StorageKeys.size:{}, L3StorageToDDRKeys.size:{}", - info.name, info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDRToL3StorageKeys.size:{}, L3StorageToDDRKeys.size:{}", info.name, + info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); auto DDRToL3StorageKeysForL3S = DDRToL3StorageKeys; auto L3StorageToDDRKeysForL3S = L3StorageToDDRKeys; @@ -1879,8 +1863,8 @@ void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(emptySwapOutL3StorageAddrOff); } -void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, - vector &swapInKeys, vector &swapOutKeys) +void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, + vector& swapOutKeys) { TimeCost ProcessSwapInKeysTC; vector L3StorageToDDRKeys; @@ -1893,15 +1877,15 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, cacheManager->ProcessSwapOutKeys(info.name, swapOutKeys, hbmSwapInfo); LOG_DEBUG("ProcessSwapOutKeysTC(ms):{} ", ProcessSwapOutKeysTC.ElapsedMS()); - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKeys.size(), swapOutKeys.size()); LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swap out, HBM2DDR Keys:{}, HBM2DDR AddrOffs:{}, " "HBM2L3Storage Keys:{}, HBM2L3Storage AddrOff:{}", info.name, info.batchId, info.channelId, hbmSwapInfo.swapOutDDRKeys.size(), hbmSwapInfo.swapOutDDRAddrOffs.size(), hbmSwapInfo.swapOutL3StorageKeys.size(), hbmSwapInfo.swapOutL3StorageAddrOffs.size()); - LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDR2L3Storage Keys:{}, L3Storage2DDR Keys:{}", - info.name, info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDR2L3Storage Keys:{}, L3Storage2DDR Keys:{}", info.name, + info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); auto DDRToL3StorageKeysForL3S = DDRToL3StorageKeys; auto L3StorageToDDRKeysForL3S = L3StorageToDDRKeys; @@ -1922,22 +1906,20 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); } -bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo &info, vector &h2dEmb) +bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb) { std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); if (!isRunning) { return false; } - h2dEmb.emplace_back(Tensor(tensorflow::DT_FLOAT, { - int(swapInAddrs.size()), static_cast(info.extEmbeddingSize) - })); - auto &tmpTensor = h2dEmb.back(); - float *h2dEmbAddr = tmpTensor.flat().data(); + h2dEmb.emplace_back( + Tensor(tensorflow::DT_FLOAT, {int(swapInAddrs.size()), static_cast(info.extEmbeddingSize)})); + auto& tmpTensor = h2dEmb.back(); + float* h2dEmbAddr = tmpTensor.flat().data(); TimeCost embeddingLookupTC = TimeCost(); uint64_t memSize = info.extEmbeddingSize * sizeof(float); -# pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ - shared(swapInAddrs, h2dEmbAddr, info, memSize) +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(swapInAddrs, h2dEmbAddr, info, memSize) for (uint64_t i = 0; i < swapInAddrs.size(); i++) { auto rc = memcpy_s(h2dEmbAddr + i * info.extEmbeddingSize, memSize, swapInAddrs[i], memSize); if (rc != 0) { @@ -1951,7 +1933,7 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo &info, vector &h2dE return true; } -vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo &info, bool &remainBatchOut) +vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut) { bool isEos = false; auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos, lookUpSwapInAddrsPushId); @@ -1961,8 +1943,8 @@ vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo &info, bool &remain } if (uniqueKeys.empty()) { remainBatchOut = false; - LOG_WARN("table:{}, channelId:{} batchId:{}, UniqueKeys result is empty", - info.name, info.channelId, info.batchId); + LOG_WARN("table:{}, channelId:{} batchId:{}, UniqueKeys result is empty", info.name, info.channelId, + info.batchId); return uniqueKeys; } @@ -1971,7 +1953,7 @@ vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo &info, bool &remain trainKeysSet[info.name].insert(uniqueKeys.begin(), uniqueKeys.end()); LOG_DEBUG("table:{}, batchId:{}, KeyMaintainTC(ms):{}", info.name, info.batchId, KeyMaintainTC.ElapsedMS()); } else { - for (auto &key : uniqueKeys) { + for (auto& key : uniqueKeys) { if (trainKeysSet[info.name].find(key) == trainKeysSet[info.name].end()) { key = INVALID_KEY_VALUE; LOG_TRACE("find key not train before, set as invalid key"); @@ -1983,28 +1965,27 @@ vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo &info, bool &remain return uniqueKeys; } -vector HybridMgmt::GetRestoreVecSec(const EmbBaseInfo &info, bool &remainBatchOut) +vector HybridMgmt::GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut) { auto restoreVecSec = KEY_PROCESS_INSTANCE->GetRestoreVecSec(info); if (restoreVecSec.empty()) { remainBatchOut = false; - LOG_WARN("table:{}, channelId:{} batchId:{}, restoreVecSec result is empty", - info.name, info.channelId, info.batchId); + LOG_WARN("table:{}, channelId:{} batchId:{}, restoreVecSec result is empty", info.name, info.channelId, + info.batchId); return restoreVecSec; } LOG_DEBUG("table:{}, channelId:{} batchId:{}, GetRestoreVecSec end", info.name, info.channelId, info.batchId); return restoreVecSec; } -void HybridMgmt::SendAll2AllVec(const EmbBaseInfo &info, bool &remainBatchOut) +void HybridMgmt::SendAll2AllVec(const EmbBaseInfo& info, bool& remainBatchOut) { if (!mgmtRankInfo.useStatic) { bool isEos = false; // useless, adapt to HBM mode TimeCost getAll2AllTC; - unique_ptr> all2all = KEY_PROCESS_INSTANCE->GetInfoVec( - info, ProcessedInfo::ALL2ALL, isEos); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, GetInfoVec all2all end, GetAll2AllTC(ms):{}", - info.name, info.channelId, info.batchId, getAll2AllTC.ElapsedMS()); + unique_ptr> all2all = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::ALL2ALL, isEos); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, GetInfoVec all2all end, GetAll2AllTC(ms):{}", info.name, + info.channelId, info.batchId, getAll2AllTC.ElapsedMS()); if (all2all == nullptr) { remainBatchOut = false; LOG_WARN("Information vector is nullptr!"); @@ -2012,17 +1993,16 @@ void HybridMgmt::SendAll2AllVec(const EmbBaseInfo &info, bool &remainBatchOut) } TimeCost sendAll2AllTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, info.channelId, info.name); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send all2all end, sendAll2AllTC(ms):{}", - info.name, info.channelId, info.batchId, sendAll2AllTC.ElapsedMS()); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send all2all end, sendAll2AllTC(ms):{}", info.name, + info.channelId, info.batchId, sendAll2AllTC.ElapsedMS()); } } -void HybridMgmt::SendRestoreVec(const EmbBaseInfo &info, bool &remainBatchOut) +void HybridMgmt::SendRestoreVec(const EmbBaseInfo& info, bool& remainBatchOut) { bool isEos = false; // useless, adapt to HBM mode TimeCost getRestoreTC; - unique_ptr> infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec( - info, ProcessedInfo::RESTORE, isEos); + unique_ptr> infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::RESTORE, isEos); if (infoVecs == nullptr) { remainBatchOut = false; if (isRunning) { @@ -2030,66 +2010,67 @@ void HybridMgmt::SendRestoreVec(const EmbBaseInfo &info, bool &remainBatchOut) } return; } - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, get restore end, getRestoreTC(ms):{}", - info.name, info.channelId, info.batchId, getRestoreTC.ElapsedMS()); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, get restore end, getRestoreTC(ms):{}", info.name, info.channelId, + info.batchId, getRestoreTC.ElapsedMS()); TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send restore end, sendRestoreSyncTC(ms):{}", - info.name, info.channelId, info.batchId, sendRestoreSyncTC.ElapsedMS()); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send restore end, sendRestoreSyncTC(ms):{}", info.name, + info.channelId, info.batchId, sendRestoreSyncTC.ElapsedMS()); } -void HybridMgmt::SendLookupOffsets(const EmbBaseInfo &info, - vector &uniqueKeys, vector &restoreVecSec) +void HybridMgmt::SendLookupOffsets(const EmbBaseInfo& info, vector& uniqueKeys, + vector& restoreVecSec) { // uniqueKeys already transfer to offset in GetSwapPairsAndKey2Offset // graph will filter out invalid offset(-1). see function _set_specific_value_for_non_valid_key TimeCost sendLookupOffsetsTC; std::vector lookupOffsets; - for (const auto &index : restoreVecSec) { + for (const auto& index : restoreVecSec) { if (index == INVALID_INDEX_VALUE) { lookupOffsets.emplace_back(static_cast(INVALID_KEY_VALUE)); continue; } lookupOffsets.emplace_back(uniqueKeys[index]); } - hdTransfer->Send(TransferChannel::LOOKUP, { Vec2TensorI32(lookupOffsets) }, info.channelId, info.name); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send lookupOffset, sendLookupOffsetsTC(ms):{}", - info.name, info.channelId, info.batchId, sendLookupOffsetsTC.ElapsedMS()); + hdTransfer->Send(TransferChannel::LOOKUP, {Vec2TensorI32(lookupOffsets)}, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send lookupOffset, sendLookupOffsetsTC(ms):{}", info.name, + info.channelId, info.batchId, sendLookupOffsetsTC.ElapsedMS()); } -void HybridMgmt::SendGlobalUniqueVec(const EmbBaseInfo &info, - vector &uniqueKeys, vector &restoreVecSec) +void HybridMgmt::SendGlobalUniqueVec(const EmbBaseInfo& info, vector& uniqueKeys, + vector& restoreVecSec) { if (!(info.channelId == TRAIN_CHANNEL_ID && mgmtRankInfo.useSumSameIdGradients)) { return; } TimeCost sendUniqueKeysSyncTC; - hdTransfer->Send(TransferChannel::UNIQKEYS, {mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : - Vec2TensorI32(uniqueKeys) }, info.channelId, info.name); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendUniqueKeysSyncTC(ms):{}", - info.name, info.channelId, info.batchId, sendUniqueKeysSyncTC.ElapsedMS()); + hdTransfer->Send(TransferChannel::UNIQKEYS, + {mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)}, + info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendUniqueKeysSyncTC(ms):{}", info.name, info.channelId, + info.batchId, sendUniqueKeysSyncTC.ElapsedMS()); TimeCost sendRestoreVecSecSyncTC; - hdTransfer->Send(TransferChannel::RESTORE_SECOND, {Vec2TensorI32(restoreVecSec) }, info.channelId, info.name); - LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendRestoreVecSecSyncTC(ms):{}", - info.name, info.channelId, info.batchId, sendRestoreVecSecSyncTC.ElapsedMS()); + hdTransfer->Send(TransferChannel::RESTORE_SECOND, {Vec2TensorI32(restoreVecSec)}, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendRestoreVecSecSyncTC(ms):{}", info.name, info.channelId, + info.batchId, sendRestoreVecSecSyncTC.ElapsedMS()); } -bool HybridMgmt::HandleSpecialProcessStatusDDR(const EmbBaseInfo &info, TimeCost& getAndSendTensorsTC, - pair, vector> &swapInKoPair, - pair, vector> &swapOutKoPair) +bool HybridMgmt::HandleSpecialProcessStatusDDR(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) { TimeCost swapProcessTC; - auto &swapInPos = swapInKoPair.second; - auto &swapOutKeys = swapOutKoPair.first; - auto &swapOutPos = swapOutKoPair.second; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos HandleFirstBatchCaseDDR(info, swapInKoPair, swapOutKoPair); - LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", - info.name, info.channelId, info.batchId); + LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); if (mgmtRankInfo.ctrlSteps[info.channelId] == 1) { vector emptySwapOutPos; @@ -2110,32 +2091,33 @@ bool HybridMgmt::HandleSpecialProcessStatusDDR(const EmbBaseInfo &info, TimeCost swapOutKeys.insert(swapOutKeys.end(), tempStore[0].begin(), tempStore[0].end()); swapOutPos.insert(swapOutPos.end(), tempStore[1].begin(), tempStore[1].end()); specialProcessStatus[info.name] = ProcessStatus::NORMAL; - LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", - info.name, info.channelId, info.batchId); + LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); } return false; } -bool HybridMgmt::HandleSpecialProcessStatusL3Storage(const EmbBaseInfo &info, TimeCost &getAndSendTensorsTC, - pair, vector> &swapInKoPair, - pair, vector> &swapOutKoPair) +bool HybridMgmt::HandleSpecialProcessStatusL3Storage(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) { TimeCost swapProcessTC; - auto &swapInPos = swapInKoPair.second; - auto &swapOutKeys = swapOutKoPair.first; - auto &swapOutPos = swapOutKoPair.second; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos HandleFirstBatchCaseL3Storage(info, swapInKoPair, swapOutKoPair); - LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", - info.name, info.channelId, info.batchId); + LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); if (mgmtRankInfo.ctrlSteps[info.channelId] == 1) { vector emptySwapOutPos; SendTensorForSwap(info, swapInPos, emptySwapOutPos); LOG_DEBUG("ProcessEmbInfoL3Storage special case, user only run one step, " - "table:{}, channelId:{}, batchId:{}", info.name, info.channelId, info.batchId); + "table:{}, channelId:{}, batchId:{}", + info.name, info.channelId, info.batchId); } specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_SECOND_BATCH; @@ -2149,13 +2131,12 @@ bool HybridMgmt::HandleSpecialProcessStatusL3Storage(const EmbBaseInfo &info, Ti swapOutKeys.insert(swapOutKeys.end(), tempStore[0].begin(), tempStore[0].end()); swapOutPos.insert(swapOutPos.end(), tempStore[1].begin(), tempStore[1].end()); specialProcessStatus[info.name] = ProcessStatus::NORMAL; - LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", - info.name, info.channelId, info.batchId); + LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); } return false; } - void HybridMgmt::CheckLookupAddrSuccessDDR() { if (!lookupAddrSuccess) { @@ -2169,20 +2150,19 @@ void HybridMgmt::CheckLookupAddrSuccessDDR() } } - -void HybridMgmt::GetSwapPairsAndKey2Offset(const EmbBaseInfo &info, vector &uniqueKeys, - pair, vector> &swapInKoPair, - pair, vector> &swapOutKoPair) +void HybridMgmt::GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector& uniqueKeys, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) { TimeCost GetSwapPairsAndKey2OffsetTC; int swapInCode = embCache->GetSwapPairsAndKey2Offset(info.name, uniqueKeys, swapInKoPair, swapOutKoPair); if (swapInCode != H_OK) { - string errMsg = StringFormat("table:%s, GetSwapPairsAndKey2Offset failed! error code:%d", - info.name.c_str(), swapInCode); + string errMsg = + StringFormat("table:%s, GetSwapPairsAndKey2Offset failed! error code:%d", info.name.c_str(), swapInCode); throw runtime_error(errMsg); } - LOG_DEBUG("table:{}, channel:{}, batchId:{}, GetSwapPairsAndKey2OffsetTC(ms):{}", - info.name, info.channelId, info.batchId, GetSwapPairsAndKey2OffsetTC.ElapsedMS()); + LOG_DEBUG("table:{}, channel:{}, batchId:{}, GetSwapPairsAndKey2OffsetTC(ms):{}", info.name, info.channelId, + info.batchId, GetSwapPairsAndKey2OffsetTC.ElapsedMS()); LOG_DEBUG("table:{}, channel:{}, batchId:{}, swapIn keys:{}, swapIn pos:{}, swapOut keys:{}, swapOut pos:{}", info.name, info.channelId, info.batchId, VectorToString(swapInKoPair.first), @@ -2190,15 +2170,14 @@ void HybridMgmt::GetSwapPairsAndKey2Offset(const EmbBaseInfo &info, vector, vector>& swapInKoPair, +void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, vector>& swapInKoPair, pair, vector>& swapOutKoPair) { - auto &swapInKeys = swapInKoPair.first; - auto &swapOutKeys = swapOutKoPair.first; + auto& swapInKeys = swapInKoPair.first; + auto& swapOutKeys = swapOutKoPair.first; - LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", - info.name, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); + LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, + info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(swapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); @@ -2208,7 +2187,7 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo &info, bool HybridMgmt::IsTrainAndEvalCase() { bool isChannelSwitchCase = false; - for (auto& i: mgmtEmbInfo) { + for (auto& i : mgmtEmbInfo) { if (specialProcessStatus[i.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { isChannelSwitchCase = true; break; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index f5897861..ab34b19f 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -17,308 +17,300 @@ See the License for the specific language governing permissions and #define MX_REC_EMB_MGMT_H #include -#include #include #include +#include #include "absl/container/flat_hash_map.h" - +#include "emb_table/embedding_table.h" +#include "hd_transfer/hd_transfer.h" +#include "hybrid_mgmt_block.h" +#include "l3_storage/cache_manager.h" +#include "ock_ctr_common/include/embedding_cache.h" +#include "ock_ctr_common/include/error_code.h" +#include "ock_ctr_common/include/factory.h" #include "utils/common.h" #include "utils/config.h" #include "utils/singleton.h" #include "utils/task_queue.h" #include "utils/time_cost.h" -#include "ock_ctr_common/include/factory.h" -#include "ock_ctr_common/include/embedding_cache.h" -#include "ock_ctr_common/include/error_code.h" - -#include "hd_transfer/hd_transfer.h" -#include "l3_storage/cache_manager.h" -#include "hybrid_mgmt_block.h" -#include "emb_table/embedding_table.h" namespace MxRec { - using namespace std; - using namespace tensorflow; - using namespace Common; - - enum class TaskType { - HBM, - DDR - }; - - enum class ProcessStatus { - NORMAL, - AFTER_SWITCH_FIRST_BATCH, - AFTER_SWITCH_SECOND_BATCH - }; - - inline string ProcessStatus2Str(ProcessStatus s) +using namespace std; +using namespace tensorflow; +using namespace Common; + +enum class TaskType { + HBM, + DDR +}; + +enum class ProcessStatus { + NORMAL, + AFTER_SWITCH_FIRST_BATCH, + AFTER_SWITCH_SECOND_BATCH +}; + +inline string ProcessStatus2Str(ProcessStatus s) +{ + switch (s) { + case ProcessStatus::NORMAL: + return "normal"; + case ProcessStatus::AFTER_SWITCH_FIRST_BATCH: + return "afterSwitchFirstBatch"; + case ProcessStatus::AFTER_SWITCH_SECOND_BATCH: + return "afterSwitchSecondBatch"; + default: + throw std::invalid_argument("Invalid ProcessStatus"); + } +}; + +struct EmbTaskInfo { + int batchId; + int threadIdx; + int cvNotifyIndex; + int extEmbeddingSize; + string name; +}; + +class HybridMgmt { +public: + HybridMgmt() = default; + + ~HybridMgmt() { - switch (s) { - case ProcessStatus::NORMAL: - return "normal"; - case ProcessStatus::AFTER_SWITCH_FIRST_BATCH: - return "afterSwitchFirstBatch"; - case ProcessStatus::AFTER_SWITCH_SECOND_BATCH: - return "afterSwitchSecondBatch"; - default: - throw std::invalid_argument("Invalid ProcessStatus"); - } - }; - - struct EmbTaskInfo { - int batchId; - int threadIdx; - int cvNotifyIndex; - int extEmbeddingSize; - string name; - }; - - class HybridMgmt { - public: - HybridMgmt() = default; - - ~HybridMgmt() - { - if (isRunning) { - Destroy(); - } + if (isRunning) { + Destroy(); } + } - HybridMgmt(const HybridMgmt&) = delete; + HybridMgmt(const HybridMgmt&) = delete; - HybridMgmt& operator=(const HybridMgmt&) = delete; + HybridMgmt& operator=(const HybridMgmt&) = delete; - bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, - const vector& thresholdValues, bool ifLoad); + bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, + const vector& thresholdValues, bool ifLoad); - void Save(const string& savePath); + void Save(const string& savePath); - bool Load(const string& loadPath, vector warmStartTables); + bool Load(const string& loadPath, vector warmStartTables); - OffsetT SendHostMap(const string tableName); + OffsetT SendHostMap(const string tableName); - OffsetT SendLoadMap(const string tableName); + OffsetT SendLoadMap(const string tableName); - void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); + void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); - void Start(); + void Start(); - void StartThreadForHBM(); + void StartThreadForHBM(); - void StartThreadForDDR(); + void StartThreadForDDR(); - void Destroy(); + void Destroy(); - bool ParseKeys(int channelId, int& batchId, TaskType type); + bool ParseKeys(int channelId, int& batchId, TaskType type); - bool Evict(); + bool Evict(); - void NotifyBySessionRun(int channelID) const; + void NotifyBySessionRun(int channelID) const; - void CountStepBySessionRun(int channelID, int steps) const; + void CountStepBySessionRun(int channelID, int steps) const; - int64_t GetTableSize(const string& embName) const; + int64_t GetTableSize(const string& embName) const; - int64_t GetTableCapacity(const string& embName) const; + int64_t GetTableCapacity(const string& embName) const; - void SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo) const; + void SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo) const; - void FetchDeviceEmb(); + void FetchDeviceEmb(); - void ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut, bool isGrad); + void ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut, bool isGrad); - void ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut); + void ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut); - void ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBatchOut); + void ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBatchOut); - GTEST_PRIVATE: - bool mutexDestroy { false }; - std::mutex lookUpAndSendBatchIdMtx; - std::mutex receiveAndUpdateBatchIdMtx; - std::map lookUpAndSendTableBatchMap; - std::map receiveAndUpdateTableBatchMap; + GTEST_PRIVATE : bool mutexDestroy{false}; + std::mutex lookUpAndSendBatchIdMtx; + std::mutex receiveAndUpdateBatchIdMtx; + std::map lookUpAndSendTableBatchMap; + std::map receiveAndUpdateTableBatchMap; - std::map> lastUpdateFinishMutexMap; - std::map> cvLastUpdateFinishMap; - std::map lastUpdateFinishStepMap; - std::map> lastLookUpFinishMutexMap; - std::map> cvLastLookUpFinishMap; - std::map lastLookUpFinishStepMap; - std::map> lastSendFinishMutexMap; - std::map> cvLastSendFinishMap; - std::map lastSendFinishStepMap; - std::map> lastRecvFinishMutexMap; - std::map> cvLastRecvFinishMap; - std::map lastRecvFinishStepMap; + std::map> lastUpdateFinishMutexMap; + std::map> cvLastUpdateFinishMap; + std::map lastUpdateFinishStepMap; + std::map> lastLookUpFinishMutexMap; + std::map> cvLastLookUpFinishMap; + std::map lastLookUpFinishStepMap; + std::map> lastSendFinishMutexMap; + std::map> cvLastSendFinishMap; + std::map lastSendFinishStepMap; + std::map> lastRecvFinishMutexMap; + std::map> cvLastRecvFinishMap; + std::map lastRecvFinishStepMap; - std::vector EmbeddingLookUpAndSendThreadPool; - std::vector EmbeddingReceiveAndUpdateThreadPool; - std::vector> lookUpSwapOutAddrsThreads; - std::vector> lookUpSwapInAddrsThreads; + std::vector EmbeddingLookUpAndSendThreadPool; + std::vector EmbeddingReceiveAndUpdateThreadPool; + std::vector> lookUpSwapOutAddrsThreads; + std::vector> lookUpSwapInAddrsThreads; - std::map>> HBMSwapKeyQue; - std::map>> HBMSwapKeyForL3StorageQue; - std::map>> DDRSwapKeyQue; - std::map>> DDRSwapKeyForL3StorageQue; - std::map>> HBMSwapAddrsQue; - std::map>> DDRSwapAddrsQue; + std::map>> HBMSwapKeyQue; + std::map>> HBMSwapKeyForL3StorageQue; + std::map>> DDRSwapKeyQue; + std::map>> DDRSwapKeyForL3StorageQue; + std::map>> HBMSwapAddrsQue; + std::map>> DDRSwapAddrsQue; - std::mutex evictMut; + std::mutex evictMut; - std::map> trainKeysSet; - const string SWAP_IN_STR = "SwapIn"; - const string SWAP_OUT_STR = "SwapOut"; + std::map> trainKeysSet; + const string SWAP_IN_STR = "SwapIn"; + const string SWAP_OUT_STR = "SwapOut"; - const string ADDR_STR = "Addr"; - ock::ctr::EmbCacheManagerPtr embCache = nullptr; - std::map> lastSwapInPosMap {}; - std::map>> trainTestSwitchInfoStore {}; - std::atomic lookupAddrSuccess {true}; + const string ADDR_STR = "Addr"; + ock::ctr::EmbCacheManagerPtr embCache = nullptr; + std::map> lastSwapInPosMap{}; + std::map>> trainTestSwitchInfoStore{}; + std::atomic lookupAddrSuccess{true}; - std::mutex saveMutex; - std::condition_variable cvCheckSave; + std::mutex saveMutex; + std::condition_variable cvCheckSave; - void SetFeatureTypeForLoad(vector& loadFeatures); + void SetFeatureTypeForLoad(vector& loadFeatures); - void EvictKeys(const string& embName, const vector& keys); + void EvictKeys(const string& embName, const vector& keys); - void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const; + void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const; - void EvictL3StorageKeys(const string& embName, const vector& keys) const; + void EvictL3StorageKeys(const string& embName, const vector& keys) const; - void LookUpAndRemoveAddrs(const EmbTaskInfo &info); // L3Storage, synchronous + void LookUpAndRemoveAddrs(const EmbTaskInfo& info); // L3Storage, synchronous - void LookUpSwapAddrs(const std::string &embName, const std::string &swapStr); // DDR, asynchronous + void LookUpSwapAddrs(const std::string& embName, const std::string& swapStr); // DDR, asynchronous - void EmbeddingTask(); + void EmbeddingTask(); - void MultiThreadEmbHDTransWrap(); + void MultiThreadEmbHDTransWrap(); - void EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo& embInfo); + void EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo& embInfo); - void EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo); + void EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo); - void EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo); + void EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo); - void EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo); + void EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo); - void SendTensorForSwap(const EmbBaseInfo& info, - const vector &swapInPosUint, - const vector &swapOutPosUint); + void SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, + const vector& swapOutPosUint); - private: - HybridMgmtBlock* hybridMgmtBlock; - vector mgmtEmbInfo; - RankInfo mgmtRankInfo; - CacheManager* cacheManager; - vector> procThreads {}; - map> evictKeyMap {}; - HDTransfer *hdTransfer; - OffsetMapT offsetMapToSend; - OffsetMapT loadOffsetToSend; - bool isL3StorageEnabled { false }; - bool isRunning; - bool isLoad { false }; - bool isInitialized { false }; - bool alreadyTrainOnce = false; // 用于判断是否为predict模式 - map lookUpSwapInAddrsPushId; // 用于处理eos场景,当消费者追上生产者且长时间无上游数据,会触发eos - map specialProcessStatus; +private: + HybridMgmtBlock* hybridMgmtBlock; + vector mgmtEmbInfo; + RankInfo mgmtRankInfo; + CacheManager* cacheManager; + vector> procThreads{}; + map> evictKeyMap{}; + HDTransfer* hdTransfer; + OffsetMapT offsetMapToSend; + OffsetMapT loadOffsetToSend; + bool isL3StorageEnabled{false}; + bool isRunning; + bool isLoad{false}; + bool isInitialized{false}; + bool alreadyTrainOnce = false; // 用于判断是否为predict模式 + map lookUpSwapInAddrsPushId; // 用于处理eos场景,当消费者追上生产者且长时间无上游数据,会触发eos + map specialProcessStatus; - void TrainTask(TaskType type); + void TrainTask(TaskType type); - void EvalTask(TaskType type); + void EvalTask(TaskType type); - void SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo &info, - const unique_ptr> &infoVecs, bool isGrad) const; + void SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo& info, const unique_ptr>& infoVecs, + bool isGrad) const; - void HandleEndBatchCase(const EmbBaseInfo& info, vector& swapInPos); + void HandleEndBatchCase(const EmbBaseInfo& info, vector& swapInPos); - bool IsTrainEndBatch(int batchId) const; + bool IsTrainEndBatch(int batchId) const; - bool IsEvalEndBatch(int batchId) const; + bool IsEvalEndBatch(int batchId) const; - void InitEmbeddingCache(const vector& embInfos); + void InitEmbeddingCache(const vector& embInfos); - void InitDataPipelineForDDR(const string &embName); + void InitDataPipelineForDDR(const string& embName); - void InitDataPipelineForL3Storage(const string &embName, int extEmbeddingSize); + void InitDataPipelineForL3Storage(const string& embName, int extEmbeddingSize); - void JoinEmbeddingCacheThread(); + void JoinEmbeddingCacheThread(); - void HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBatchOut); + void HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBatchOut); - void HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut); + void HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut); - void HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut); + void HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); + bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); - void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); + void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); - bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb); + bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb); - void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); + void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); - bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, - int64_t& dims0); + bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0); - void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, - int64_t& dims0); + void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); - bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb); + bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb); - void EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& h2dEmb); + void EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& h2dEmb); - void CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo); + void CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo); - void CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo); + void CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo); - void HandleFirstBatchCaseDDR(const EmbBaseInfo& info, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + void HandleFirstBatchCaseDDR(const EmbBaseInfo& info, std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + void HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void HandleDataSwapForL3Storage(const EmbBaseInfo& info, - vector &swapInKeys, vector &swapOutKeys); + void HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, + vector& swapOutKeys); - bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); + bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); - vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); + vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); - vector GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); + vector GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); - void SendAll2AllVec(const EmbBaseInfo& info, bool& remainBatchOut); + void SendAll2AllVec(const EmbBaseInfo& info, bool& remainBatchOut); - void SendRestoreVec(const EmbBaseInfo& info, bool& remainBatchOut); + void SendRestoreVec(const EmbBaseInfo& info, bool& remainBatchOut); - void SendLookupOffsets(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); + void SendLookupOffsets(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); - void SendGlobalUniqueVec(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); + void SendGlobalUniqueVec(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); - bool HandleSpecialProcessStatusDDR(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + bool HandleSpecialProcessStatusDDR(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - bool HandleSpecialProcessStatusL3Storage(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + bool HandleSpecialProcessStatusL3Storage(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void CheckLookupAddrSuccessDDR(); + void CheckLookupAddrSuccessDDR(); - void GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector &uniqueKeys, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + void GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector& uniqueKeys, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void EnqueueSwapInfo(const EmbBaseInfo& info, - std::pair, vector>& swapInKoPair, - std::pair, vector>& swapOutKoPair); + void EnqueueSwapInfo(const EmbBaseInfo& info, std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - bool IsTrainAndEvalCase(); - }; -} -#endif // MX_REC_EMB_MGMT_H + bool IsTrainAndEvalCase(); +}; +} // namespace MxRec +#endif // MX_REC_EMB_MGMT_H -- Gitee From ade2d1089abee6af3a3e4ef09313c5b1d5522bd7 Mon Sep 17 00:00:00 2001 From: penghuiyang <1060916628@qq.com> Date: Fri, 5 Jul 2024 13:21:35 +0000 Subject: [PATCH 18/37] =?UTF-8?q?!209=20=E3=80=90FIX=E3=80=91=E5=A4=9A?= =?UTF-8?q?=E6=9C=BA=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E9=80=82=E9=85=8D=20*=20=E3=80=90FIX?= =?UTF-8?q?=E3=80=91=E5=A4=9A=E6=9C=BA=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E5=8A=A0=E8=BD=BD=E9=80=82=E9=85=8D+hdfs=20*?= =?UTF-8?q?=20=E3=80=90FIX=E3=80=91=E5=A4=9A=E6=9C=BA=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E4=BF=9D=E5=AD=98=E5=8A=A0=E8=BD=BD=E9=80=82?= =?UTF-8?q?=E9=85=8D+hdfs=20*=20=E3=80=90FIX=E3=80=91=E5=A4=9A=E6=9C=BA?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE=E4=BF=9D=E5=AD=98=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E9=80=82=E9=85=8D=20*=20=E3=80=90FIX=E3=80=91?= =?UTF-8?q?=E5=A4=9A=E6=9C=BA=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E5=8A=A0=E8=BD=BD=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 8 +++++--- mx_rec/saver/saver.py | 27 ++++++++++++++++++--------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index dcdf95ca..0f3a237b 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -44,7 +44,8 @@ from tensorflow.python.training.saving import saveable_object_util import numpy as np from mpi4py import MPI -from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid +from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid, should_write_data +from mx_rec.util.communication.hccl_ops import get_local_rank_size from mx_rec.util.initialize import ConfigInitializer from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ OptionalStringValidator, DirectoryValidator @@ -253,7 +254,7 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra comm = MPI.COMM_WORLD rank = comm.Get_rank() comm.Barrier() - if rank == 0: + if should_write_data(rank, save_path): model_checkpoint_path = compat.as_str(get_model_checkpoint_path(self, checkpoint_file, sess)) if write_state: update_checkpoint_state(self, model_checkpoint_path, save_path_parent, latest_filename, meta_graph_suffix, @@ -453,10 +454,11 @@ def patch_for_write_graph_func(func): comm = MPI.COMM_WORLD rank = comm.Get_rank() # In the case of multiple processes, choose one process to write graph. - if rank == 0: + if len(args) > 1 and should_write_data(rank, args[1]): return func(*args, **kwargs) else: return None + return wrapper diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 9e0e1d29..a6362506 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -35,7 +35,6 @@ from mx_rec.util.log import logger from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.tf_version_adapter import npu_ops - SAVE_SPARSE_PATH_PREFIX = "sparse" @@ -171,7 +170,7 @@ class Saver(object): comm = MPI.COMM_WORLD rank = comm.Get_rank() comm.Barrier() - if rank == 0: + if should_write_data(rank, saving_path): table_list = self.save_op_dict.keys() for table_name in table_list: self.merge_sparse_file(saving_path, table_name) @@ -267,7 +266,7 @@ class Saver(object): else: self._save_ddr(sess, root_dir) logger.debug(f"Host data was saved.") - + def _save_hbm(self, sess, root_dir): self.config_instance.hybrid_manager_config.save_host_data(root_dir) if self.config_instance.use_dynamic_expansion: @@ -285,7 +284,7 @@ class Saver(object): for thread in threads: thread.join() - + def _save_ddr(self, sess, root_dir): # 接受host侧传来的需要swap_out的offset用于更新host侧并保存 self.config_instance.hybrid_manager_config.fetch_device_emb() @@ -306,7 +305,7 @@ class Saver(object): channel_name=f'{table_name}_save_h2d_{TRAIN_CHANNEL_ID}') if use_static: swap_out_pos = swap_out_pos[:swap_out_len] - + table = [var] optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(table_name) if optimizer is not None: @@ -382,7 +381,6 @@ class Saver(object): else: placeholder_dict, restore_fetch_list = self.placeholder_dict, self.restore_fetch_dict - for table_name in placeholder_dict: optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance if optimizer_instance: @@ -395,7 +393,7 @@ class Saver(object): table_instance0 = self.config_instance.sparse_embed_config.get_table_instance(self.var_list[0]) if not table_instance0.is_hbm: return - + if self.config_instance.use_dynamic_expansion: # Data related to dynamic expansion needs to be restored only on the host side. return @@ -405,7 +403,7 @@ class Saver(object): for table_name, sub_placeholder_dict in placeholder_dict.items(): load_offset = self.config_instance.hybrid_manager_config.get_load_offset(table_name) fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, - NameDescriptor(table_name, DataName.EMBEDDING.value), load_offset) + NameDescriptor(table_name, DataName.EMBEDDING.value), load_offset) if "optimizer" in sub_placeholder_dict: optimizer_state_placeholder_dict_group = sub_placeholder_dict.get("optimizer") @@ -698,4 +696,15 @@ def set_optimizer_info(optimizer: CustomizedOptimizer, table_name: str): """ from mxrec_pybind import OptimizerInfo optim_info = OptimizerInfo(optimizer.optimizer_type, optimizer.optim_param_list) - ConfigInitializer.get_instance().hybrid_manager_config.set_optim_info(table_name, optim_info) \ No newline at end of file + ConfigInitializer.get_instance().hybrid_manager_config.set_optim_info(table_name, optim_info) + + +def should_write_data(rank_id: int, save_path: str) -> bool: + # When using hdfs filesystem, only the rank0 process execute write data operation, assuming use same hdfs path in + # multi-machine. + # When using local filesystem, the process which `rank_id % local_rank_size == 0` execute write data operation. + # When using hdfs filesystem, and use different hdfs path to save data, should modify check condition + # as same as local filesystem. + is_hdfs = check_file_system_is_hdfs(save_path) + local_rank_size = get_local_rank_size() + return rank_id == 0 if is_hdfs else rank_id % local_rank_size == 0 -- Gitee From f07efc133ddc6a416d6e7f5fec9e7bb2fddacd21 Mon Sep 17 00:00:00 2001 From: penghuiyang <1060916628@qq.com> Date: Mon, 8 Jul 2024 11:46:03 +0800 Subject: [PATCH 19/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91criteo=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=A4=84=E7=90=86=E8=84=9A=E6=9C=AC=E5=88=A4=E6=96=AD?= =?UTF-8?q?=E6=9D=A1=E4=BB=B6=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/dlrm/criteo_tb/gen_ttf.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/dlrm/criteo_tb/gen_ttf.py b/examples/dlrm/criteo_tb/gen_ttf.py index 8715f048..986bc6df 100644 --- a/examples/dlrm/criteo_tb/gen_ttf.py +++ b/examples/dlrm/criteo_tb/gen_ttf.py @@ -224,9 +224,9 @@ def make_example(label_list, dense_feat_list, sparse_feat_list): sparse_feature = np.array(sparse_feat_list, dtype=np.int64).reshape(-1) label = np.array(label_list, dtype=np.int64).reshape(-1) feature_dict = {"dense_feature": tf.train.Feature(float_list=tf.train.FloatList(value=dense_feature)), - "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)), - "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label)) - } + "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)), + "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label)) + } example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) return example @@ -273,10 +273,10 @@ def convert_input2tfrd_multiprocess(proc_num, proc_id, in_file_path, output_file label = int(items[0]) values = items[1:14] cats = items[14:] - if len(values) == 13: - raise ValueError("values.size: {}".format(len(values))) - if len(cats) == 26: - raise ValueError("cats.size: {}".format(len(cats))) + if len(values) != 13: + raise ValueError("dense feature length must be 13, current values.size: {}".format(len(values))) + if len(cats) != 26: + raise ValueError("sparse feature length must be 26, current cats.size: {}".format(len(cats))) val_list, cat_list = criteo_stats_dict.map_cat2id(values, cats) dense_res_list.append(val_list) cat_res_list.append(cat_list) @@ -363,7 +363,7 @@ if __name__ == "__main__": process_num = args.train_process_num if len(train_data_files) == 0: raise ValueError(f'file not exist in train_data_dir:{train_data_dir}') - if process_num % len(train_data_files) == 0: + if process_num % len(train_data_files) != 0: raise ValueError(f'process_num {process_num} must exact div length of train_data_files {len(train_data_files)}') for process_id in range(process_num): @@ -387,7 +387,7 @@ if __name__ == "__main__": process_num = args.test_process_num if len(test_data_files) == 0: raise ValueError(f'file not exist in test_data_dir:{test_data_dir}') - if process_num % len(test_data_files) == 0: + if process_num % len(test_data_files) != 0: raise ValueError(f'process_num {process_num} must exact div length of test_data_files {len(test_data_files)}') for process_id in range(process_num): -- Gitee From 33991245ee3d8f68cccb5d18b5ae6a20fab07014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 9 Jul 2024 10:49:44 +0800 Subject: [PATCH 20/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=89=A9=E5=AE=B9?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=EF=BC=8Ctable.capacity=E5=87=BA?= =?UTF-8?q?=E7=8E=B0=E5=81=B6=E5=8F=91=E8=B4=9F=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/emb_table/embedding_table.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h index 3396a8a0..ef741887 100644 --- a/src/core/emb_table/embedding_table.h +++ b/src/core/emb_table/embedding_table.h @@ -114,7 +114,7 @@ protected: size_t embSize_; size_t extEmbSize_; int seed_; - std::atomic capacity_; + std::atomic capacity_{0}; size_t rankId_; size_t rankSize_; vector loadOffset; -- Gitee From 1b81040851f1bf326983ce1e4e6589c0c4a5986d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Tue, 9 Jul 2024 15:45:28 +0800 Subject: [PATCH 21/37] =?UTF-8?q?=E4=BF=AE=E6=94=B9mxRec=E9=95=9C=E5=83=8F?= =?UTF-8?q?=E4=BB=93=E7=9A=84=E9=93=BE=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 17d38fcd..f6bfb828 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,6 @@ mxRec所支持的使用环境、功能特性、API接口与使用样例请参考 mxRec框架基础镜像,基于TensorFlow 1.15.0、tensorflow2.6.5制作的基础镜像,安装mxRec后即可开始训练,以及样例使用介绍。 -1. https://ascendhub.huawei.com/#/detail/mxrec-tf1 +1. https://www.hiascend.com/developer/ascendhub/detail/mxrec-tf1 -2. https://ascendhub.huawei.com/#/detail/mxrec-tf2 +2. https://www.hiascend.com/developer/ascendhub/detail/mxrec-tf2 -- Gitee From 42ac8e68ecab452042587c8fe7bac19c7abca82c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Tue, 9 Jul 2024 15:46:14 +0800 Subject: [PATCH 22/37] =?UTF-8?q?=E6=B7=BB=E5=8A=A0dlrm=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/dlrm/README.md | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 examples/dlrm/README.md diff --git a/examples/dlrm/README.md b/examples/dlrm/README.md new file mode 100644 index 00000000..85293c0c --- /dev/null +++ b/examples/dlrm/README.md @@ -0,0 +1,60 @@ +# DLRM模型运行说明 + +## 代码结构 +```shell +. +├── criteo_tb +│ ├── gen_ttf.py # criteo_tb原始数据转换成tfrecord格式的脚本 +│ └── README.md # 数据格式转换脚本说明 +├── model +│ ├── config.py # 模型配置文件 +│ ├── delay_loss_scale.py # loss缩放函数 +│ ├── gradient_descent_w.py # 自定义SGD优化器 +│ ├── main_mxrec.py # 主函数 +│ ├── mean_auc.py # 计算acu的脚本 +│ ├── model.py # DLRM模型 +│ ├── op_impl_mode.ini # 算子执行模式配置 +│ ├── optimizer.py # 优化器 +│ └── run.sh # 运行DLRM模型的脚本 +└── README.md # DLRM模型运行说明 +``` + +## 1.准备数据 +参考criteo_tb目录下的说明文档准备好模型所需要的数据集,放在一个目录下,比如:/data/criteo_tb/。 + +## 2.准备运行环境 +运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”章节进行准备。 + +## 3.安装mxRec +mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”>“环境准备”>“获取软件包”章节提供的链接进行下载,选择自己需要的架构(x86或者arm)的mxRec包。下载完成之后,将mxRec包解压,进入解压后的目录(mindxsdk-mxrec) +如下: +```shell +. +├── cust_op +│ └── cust_op_by_addr +├── examples +│ ├── DCNv2 +│ ├── demo +│ └── dlrm +├── tf1_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +├── tf2_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +└── version.info +``` +其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 +确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 + +## 4.运行DLRM模型 +执行完以上步骤之后,接下来就可以运行DLRM模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入5个参数,分别对应:so_path、mx_rec_package_path、hccl_cfg_json、 +dlrm_criteo_data_path和ip。运行命令如: +```shell +bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {ip} +``` +* so_path:so_path是mxRec中动态库的目录,一般在mxRec的安装目录下的libasc目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc。 +* mx_rec_package_path:mx_rec_package_path是mxRec的安装目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec。 +* hccl_cfg_json:hccl_cfg_json是hccl通信配置文件,如果配置了ip参数,这个参数就不用了,直接给一个""空字符串即可。 +* dlrm_criteo_data_path:dlrm_criteo_data_path是数据集所在的目录,比如/data/criteo_tb/。 +* ip:ip是运行模型的机器所在的ip,建议配置。 -- Gitee From 31aa8b6db348a4e8dd2688b1331559eb20264aa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Tue, 9 Jul 2024 15:54:15 +0800 Subject: [PATCH 23/37] =?UTF-8?q?=E6=B7=BB=E5=8A=A0DCNv2=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/DCNv2/README.md | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 examples/DCNv2/README.md diff --git a/examples/DCNv2/README.md b/examples/DCNv2/README.md new file mode 100644 index 00000000..e9b8a75f --- /dev/null +++ b/examples/DCNv2/README.md @@ -0,0 +1,54 @@ +# DCNv2模型运行说明 + +## 代码结构 +```shell +. +├── config.py # 模型配置文件 +├── delay_loss_scale.py # loss缩放函数 +├── main_mxrec.py # 主函数 +├── model.py # DCNv2模型 +├── op_impl_mode.ini # 算子执行模式配置 +├── optimizer.py # 优化器 +├── README.md # DCNv2模型运行说明 +└── run.sh # 运行DCNv2模型的脚本 +``` + +## 1.准备数据 +参考dlrm模型中criteo_tb目录下的说明文档准备好模型所需要的数据集,放在一个目录下,比如:/data/criteo_tb/。 + +## 2.准备运行环境 +运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”章节进行准备。 + +## 3.安装mxRec +mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”>“环境准备”>“获取软件包”章节提供的链接进行下载,选择自己需要的架构(x86或者arm)的mxRec包。下载完成之后,将mxRec包解压,进入解压后的目录(mindxsdk-mxrec) +如下: +```shell +. +├── cust_op +│ └── cust_op_by_addr +├── examples +│ ├── DCNv2 +│ ├── demo +│ └── dlrm +├── tf1_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +├── tf2_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +└── version.info +``` +其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 +确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 + +## 4.运行DLRM模型 +执行完以上步骤之后,接下来就可以运行DLRM模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入5个参数,分别对应:so_path、mx_rec_package_path、hccl_cfg_json、 +dlrm_criteo_data_path和ip。运行命令如: +```shell +bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {ip} +``` +* so_path:so_path是mxRec中动态库的目录,一般在mxRec的安装目录下的libasc目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc。 +* mx_rec_package_path:mx_rec_package_path是mxRec的安装目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec。 +* hccl_cfg_json:hccl_cfg_json是hccl通信配置文件,如果配置了ip参数,这个参数就不用了,直接给一个""空字符串即可。 +* dlrm_criteo_data_path:dlrm_criteo_data_path是数据集所在的目录,比如/data/criteo_tb/。 +* ip:ip是运行模型的机器所在的ip,建议配置。 -- Gitee From 39fa9310b122431bffc75204d7fb8d18343db93f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Tue, 9 Jul 2024 16:10:33 +0800 Subject: [PATCH 24/37] =?UTF-8?q?WideDeep=E6=A0=B7=E4=BE=8B=20README?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=AE=8C=E5=96=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/WideDeep/{README_WD.md => README.md} | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) rename examples/WideDeep/{README_WD.md => README.md} (89%) diff --git a/examples/WideDeep/README_WD.md b/examples/WideDeep/README.md similarity index 89% rename from examples/WideDeep/README_WD.md rename to examples/WideDeep/README.md index beb592c9..aef2379f 100644 --- a/examples/WideDeep/README_WD.md +++ b/examples/WideDeep/README.md @@ -4,7 +4,7 @@ *** ## 开源项目链接 - +Commits on Apr 29, 2022, 提交的SHA-1 hash值(提交ID):4bbfb492b872c5a3290a2bce1ed5c160162558a3 ```shell https://github.com/ZiyaoGeng/RecLearn ``` @@ -41,7 +41,7 @@ python critro.py --data_path data_path --output_path output_path ```python # get txt_list -split_file_list = get_split_file_path(dataset_path = dataset_path) +file_split_list = get_split_file_path(dataset_path=data_path) ``` *** #### 2. 建立特征映射 @@ -49,7 +49,7 @@ split_file_list = get_split_file_path(dataset_path = dataset_path) ```python # get feature_map -fea_map = get_fea_map(split_file_list=split_file_list) +feature_map = get_fea_map(split_file_list=file_split_list) ``` *** #### 3. dense_feature分桶离散化 @@ -57,7 +57,7 @@ fea_map = get_fea_map(split_file_list=split_file_list) ```python # dense feature: Bin continuous data into intervals. -data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, fea_map) +data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, feature_map) ``` *** #### 4. sparse_feature特征映射 @@ -66,7 +66,10 @@ data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, f ```python # sparse feature: mapping for col in sparse_features: - data_df[col] = data_df[col].map(lambda x: fea_map[col][x]) + try: + data_df[col] = data_df[col].map(lambda x: feature_map[col][x]) + except KeyError as er: + raise KeyError("Feature {} not found in dataset".format(col)) from er ``` *** #### 5. 39个特征增加偏移项 @@ -74,12 +77,14 @@ for col in sparse_features: ```python # add offsets -slot_size_array = [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, - 1462, 585, 10131228, 2202609, 307, 25, 12519, 635, 5, 93147, 5685, 8351594, 3196, - 29, 14994, 5461307, 12, 5654, 2174, 5, 7046548, 19, 17, 286182, 106, 142573] +slot_size_array = [ + 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, + 1462, 585, 10131228, 2202609, 307, 25, 12519, 635, 5, 93147, 5685, 8351594, 3196, + 29, 14994, 5461307, 12, 5654, 2174, 5, 7046548, 19, 17, 286182, 106, 142573 +] offset_size_list = np.cumsum([0] + slot_size_array[:-1]) -for j in range(1,len(offset_size_list)+1): - data_df.iloc[:, j] += offset_size_list[j-1] +for col_index in range(1, len(offset_size_list) + 1): + data_df.iloc[:, col_index] += offset_size_list[col_index - 1] ``` *** #### 6. 数据集格式转换:txt >> tfrecord @@ -93,13 +98,25 @@ convert_input2tfrd(in_file_path=file, out_file_path=output_path) ## 模型运行 -参考mxrec的`README.md`文件在NPU服务器上配置环境后,可按照[mxrec-tf1](https://ascendhub.huawei.com/#/detail/mxrec-tf1)中DLRM模型运行命令启动模型训练。`so_path`、`mx_rec_package_path`、`hccl_cfg_json`配置不变,根据实际数据集路径配置`dlrm_criteo_data_path`。 +参考mxrec的`README.md`文件在NPU服务器上配置环境并安装镜像创建容器后,可参考DLRM模型运行命令启动模型训练。模型运行脚本是run.sh,运行此脚本需要四个参数:so_path、mx_rec_package_path、hccl_cfg_json以及dlrm_criteo_data_path。其中, +- so_path: mxrec中libasc所在路径,在镜像中已经安装过mxrec,所以so_path是:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc/ +- mx_rec_package_path: mxrec这个包的安装路径,镜像中是:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/ +- hccl_cfg_json: hccl配置文件所在路径,一般是当前路径下的hccl文件 +- dlrm_criteo_data_path: Wide&Deep模型需要的数据所在路径,根据实际情况进行配置 +运行mxRec有两种方式,一种是使用hccl配置文件(rank table方案),一种是不使用hccl配置文件(去rank table方案)。 +- 使用hccl配置文件(rank table方案) ```shell -# 运行命令 bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} ``` *** +- 不使用hccl配置文件(去rank table方案) +```shell +bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {IP} +``` +如:bash run.sh /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc/ /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/ hccl_json_8p.json /dataset 10.10.10.10。 +**注意:** 去rank table方案,当前路径下不存在hccl文件,模型仍可正常运行。 + ## 模型结果 [开源项目](https://github.com/ZiyaoGeng/RecLearn)使用Criteo4500W数据集在GPU上训练模型,结果为`Log Loss=0.4692`、`AUC=0.7930`。适配完成模型后,固定`CACHE_MODE="HBM"`、`USE_FAAE=0`,在`run.sh`中配置其他选项卡,运行结果如下。 @@ -135,8 +152,8 @@ bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_pa *** ## 模型迁移 -**迁移思路:** 在现有已适配好的dlrm模型框架下,改动相关代码逻辑,完成Wide&deep模型的适配。**核心:根据开源项目model代码修改`model.py`;数据处理操作一部分放入`criteo.py`,一部分放入`main_mxrec.py`中`make_batch_and_iterator()`内;`main_mxrec.py`中其他相关代码改动主要是为了适配mxrec提供的相关特性。** - +**迁移思路:** 在现有已适配好的dlrm模型框架下,改动相关代码逻辑,完成Wide&deep模型的适配。**核心:根据开源项目model代码修改`model.py`;数据处理操作一部分放入`criteo.py`,一部分放入`main_mxrec.py`中`make_batch_and_iterator()`内;`main_mxrec.py`中其他相关代码改动主要是为了适配mxrec提供的相关特性。** +详细改动见https://gitee.com/ascend/mxrec/pulls/171/commits,Commits ID:7a05b033d41af51df9aed7414ad04216dff821cc。 下文所提到的`动态扩容`、`动态shape`、`自动改图`、`一表多查`是mxrec提供的相关特性,开关选项见`run.sh`。 ```shell -- Gitee From 909ace13858217f2812884cd13d0ad8aeaaf7d19 Mon Sep 17 00:00:00 2001 From: penghuiyang <1060916628@qq.com> Date: Tue, 9 Jul 2024 16:23:46 +0800 Subject: [PATCH 25/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91LazyAdam=E8=9E=8D?= =?UTF-8?q?=E5=90=88=E7=AE=97=E5=AD=90=E6=8F=8F=E8=BF=B0=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cust_op/fused_lazy_adam/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cust_op/fused_lazy_adam/README.md b/cust_op/fused_lazy_adam/README.md index 13ed6994..3cb69f2d 100644 --- a/cust_op/fused_lazy_adam/README.md +++ b/cust_op/fused_lazy_adam/README.md @@ -6,7 +6,7 @@ ├── aclnn_lazy_adam_test # 单算子测试用例 ├── lazy_adam.json # 算子原型配置 ├── op_host # LazyAdam融合算子Host侧实现 -├── op_kernel # LazyAdam融合算子Kernel测实现 +├── op_kernel # LazyAdam融合算子Kernel侧实现 ├── README.md # LazyAdam融合算子说明文档 └── run.sh # LazyAdam融合算子安装脚本 ``` -- Gitee From 14ac6e7f2f7d5b62f9ba4aaae3e57c2082ea036a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Tue, 9 Jul 2024 17:14:57 +0800 Subject: [PATCH 26/37] =?UTF-8?q?=E6=A3=80=E8=A7=86=E6=84=8F=E8=A7=81?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/WideDeep/README.md | 5 +++-- examples/WideDeep/criteo.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/WideDeep/README.md b/examples/WideDeep/README.md index aef2379f..f4815cd9 100644 --- a/examples/WideDeep/README.md +++ b/examples/WideDeep/README.md @@ -5,6 +5,7 @@ *** ## 开源项目链接 Commits on Apr 29, 2022, 提交的SHA-1 hash值(提交ID):4bbfb492b872c5a3290a2bce1ed5c160162558a3 +commit的链接: https://github.com/ZiyaoGeng/RecLearn/tree/4bbfb492b872c5a3290a2bce1ed5c160162558a3 ```shell https://github.com/ZiyaoGeng/RecLearn ``` @@ -68,8 +69,8 @@ data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, f for col in sparse_features: try: data_df[col] = data_df[col].map(lambda x: feature_map[col][x]) - except KeyError as er: - raise KeyError("Feature {} not found in dataset".format(col)) from er + except KeyError as e: + raise KeyError("Feature {} not found in dataset".format(col)) from e ``` *** #### 5. 39个特征增加偏移项 diff --git a/examples/WideDeep/criteo.py b/examples/WideDeep/criteo.py index 617c76f6..3c8ea430 100644 --- a/examples/WideDeep/criteo.py +++ b/examples/WideDeep/criteo.py @@ -248,8 +248,8 @@ if __name__ == '__main__': for col in sparse_features: try: data_df[col] = data_df[col].map(lambda x: feature_map[col][x]) - except KeyError as er: - raise KeyError("Feature {} not found in dataset".format(col)) from er + except KeyError as e: + raise KeyError("Feature {} not found in dataset".format(col)) from e # dense feature: Bin continuous data into intervals. data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, feature_map) # add offsets -- Gitee From 30d416ea128496119c1e95ed43240628727c7ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Tue, 9 Jul 2024 20:18:51 +0800 Subject: [PATCH 27/37] =?UTF-8?q?=E6=B7=BB=E5=8A=A0demo=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/DCNv2/README.md | 6 +- examples/demo/README.md | 13 +++++ examples/demo/little_demo/README.md | 56 ++++++++++++++++++ examples/demo/little_demo_estimator/README.md | 57 +++++++++++++++++++ 4 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 examples/demo/README.md create mode 100644 examples/demo/little_demo/README.md create mode 100644 examples/demo/little_demo_estimator/README.md diff --git a/examples/DCNv2/README.md b/examples/DCNv2/README.md index e9b8a75f..f1940ebe 100644 --- a/examples/DCNv2/README.md +++ b/examples/DCNv2/README.md @@ -14,7 +14,7 @@ ``` ## 1.准备数据 -参考dlrm模型中criteo_tb目录下的说明文档准备好模型所需要的数据集,放在一个目录下,比如:/data/criteo_tb/。 +参考DLRM模型中criteo_tb目录下的说明文档准备好模型所需要的数据集,放在一个目录下,比如:/data/criteo_tb/。 ## 2.准备运行环境 运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) @@ -41,8 +41,8 @@ mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/ 其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 -## 4.运行DLRM模型 -执行完以上步骤之后,接下来就可以运行DLRM模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入5个参数,分别对应:so_path、mx_rec_package_path、hccl_cfg_json、 +## 4.运行DCNv2模型 +执行完以上步骤之后,接下来就可以运行DCNv2模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入5个参数,分别对应:so_path、mx_rec_package_path、hccl_cfg_json、 dlrm_criteo_data_path和ip。运行命令如: ```shell bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {ip} diff --git a/examples/demo/README.md b/examples/demo/README.md new file mode 100644 index 00000000..931f8c26 --- /dev/null +++ b/examples/demo/README.md @@ -0,0 +1,13 @@ +# demo样例说明 + +## 代码结构 +```shell +. +├── little_demo # sess.run模式的demo +├── little_demo_estimator # estimator模式的demo +└── README.md # demo样例说明 +``` + +mxRec提供了一个非常简单的样例模型demo,用于快速体验mxRec。在TensorFlow中,运行模型有sess.run和estimator两种模式。因此,mxRec也提供了两种 +模式下的样例。其中little_demo是sess.run模式的样例;little_demo_estimator是estimator模式的样例。用户可以选择自己需要或者感兴趣的模式进行 +体验,各个模式的样例的说明见对应目录下的README文档。 \ No newline at end of file diff --git a/examples/demo/little_demo/README.md b/examples/demo/little_demo/README.md new file mode 100644 index 00000000..dabe105b --- /dev/null +++ b/examples/demo/little_demo/README.md @@ -0,0 +1,56 @@ +# sess.run模式下demo模型运行说明 + +## 代码结构 +```shell +. +├── config.py # 模型配置文件 +├── dataset.py # 生成数据集的脚本 +├── deterministic_loss # 确定性计算loss样例 +├── main.py # 主函数 +├── model.py # demo模型 +├── op_impl_mode.ini # 算子执行模式配置 +├── optimizer.py # 优化器 +├── random_data_generator.py # 数据生成器 +├── README.md # demo模型运行说明 +├── run_deterministic.sh # 运行确定性计算的脚本 +├── run_mode.py # 执行模型train、evaluate和predict的脚本 +└── run.sh # demo运行脚本 +``` + +## 1.准备数据 +demo样例无需从其他地方下载数据集,在demo样例中mxRec会自动生成数据集,详情见dataset.py和random_data_generator.py。 + +## 2.准备运行环境 +运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”章节进行准备。 + +## 3.安装mxRec +mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”>“环境准备”>“获取软件包”章节提供的链接进行下载,选择自己需要的架构(x86或者arm)的mxRec包。下载完成之后,将mxRec包解压,进入解压后的目录(mindxsdk-mxrec) +如下: +```shell +. +├── cust_op +│ └── cust_op_by_addr +├── examples +│ ├── DCNv2 +│ ├── demo +│ └── dlrm +├── tf1_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +├── tf2_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +└── version.info +``` +其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 +确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 + +## 4.运行demo模型 +执行完以上步骤之后,接下来就可以运行demo模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入ip这个参数,运行命令如: +```shell +bash run.sh main.py {ip} +``` +* ip:ip是运行模型的机器所在的ip。 + +**Tips**:run.sh脚本中有一个参数是mx_rec_package_path,mx_rec_package_path是mxRec的安装目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec。 +这个参数在脚本是默认的,用户需要根据自己环境中mxRec实际安装的路径进行配置。 \ No newline at end of file diff --git a/examples/demo/little_demo_estimator/README.md b/examples/demo/little_demo_estimator/README.md new file mode 100644 index 00000000..aca25a34 --- /dev/null +++ b/examples/demo/little_demo_estimator/README.md @@ -0,0 +1,57 @@ +# estimator模式下demo模型运行说明 + +## 代码结构 +```shell +. +├── config.py # 模型配置文件 +├── dataset.py # 生成数据集的脚本 +├── main.py # 主函数 +├── nn_model_build.py # demo模型 +├── nn_model_input.py # 定义model_fn +├── nn_optim.py # 定义train的各个op +├── nn_reader.py # 定义input_fn +├── op_precision.ini # 算子执行模式配置 +├── random_data_generator.py # 数据生成器 +├── README.md # demo模型运行说明 +├── run.sh # demo运行脚本 +├── tf_adapter.py # 导入tf adapter +└── utils.py # 公共函数 +``` + +## 1.准备数据 +demo样例无需从其他地方下载数据集,在demo样例中mxRec会自动生成数据集,详情见dataset.py和random_data_generator.py。 + +## 2.准备运行环境 +运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”章节进行准备。 + +## 3.安装mxRec +mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”>“环境准备”>“获取软件包”章节提供的链接进行下载,选择自己需要的架构(x86或者arm)的mxRec包。下载完成之后,将mxRec包解压,进入解压后的目录(mindxsdk-mxrec) +如下: +```shell +. +├── cust_op +│ └── cust_op_by_addr +├── examples +│ ├── DCNv2 +│ ├── demo +│ └── dlrm +├── tf1_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +├── tf2_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +└── version.info +``` +其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 +确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 + +## 4.运行demo模型 +执行完以上步骤之后,接下来就可以运行demo模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入ip这个参数,运行命令如: +```shell +bash run.sh main.py {ip} +``` +* ip:ip是运行模型的机器所在的ip。 + +**Tips**:run.sh脚本中有一个参数是mx_rec_package_path,mx_rec_package_path是mxRec的安装目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec。 +这个参数在脚本是默认的,用户需要根据自己环境中mxRec实际安装的路径进行配置。 \ No newline at end of file -- Gitee From 38866e896710f3ff873083c02216bb638672aecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 10 Jul 2024 16:16:29 +0800 Subject: [PATCH 28/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E5=9C=A8device=E6=B5=8B=E8=BF=90=E8=A1=8C=E8=BE=83?= =?UTF-8?q?=E5=BF=AB=E7=9A=84=E6=83=85=E5=86=B5=E4=B8=8B=EF=BC=8Chost?= =?UTF-8?q?=E6=B5=8B=E7=94=B3=E8=AF=B7=E5=86=85=E5=AD=98=E5=92=8C=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E6=85=A2=EF=BC=8C=E5=AF=BC=E8=87=B4=E6=8A=A5?= =?UTF-8?q?=E9=94=99=E9=80=80=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h index 46daaf29..3b87e6e6 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h @@ -109,7 +109,7 @@ public: fullCv.notify_all(); } - BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 1000) + BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 10000) { for (uint32_t i = 0; i < maxRetry; i++) { if (BufferBin.pop(value)) { @@ -252,7 +252,7 @@ public: FkvState FindAndPutIfNotFound(uint64_t key, uint64_t& value) { FkvState ret = MapperBase::FindAndPutIfNotFound(key, value, [&]() { - if (HM_UNLIKELY(current_size.load() >= hostVocabSize)) { + if (HM_UNLIKELY(current_size.load() > hostVocabSize)) { ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "host does not have enough space"); return BeforePutFuncState::BEFORE_NO_SPACE; } -- Gitee From d076f903d3669c8194312b793de1349586f9f1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 10 Jul 2024 16:55:41 +0800 Subject: [PATCH 29/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E6=8A=A5=E9=94=99host=E7=A9=BA=E9=97=B4=E4=B8=8D?= =?UTF-8?q?=E8=B6=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/embedding_cache/offset_mapper/address_mapper.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h index 3b87e6e6..8b7e4e67 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h @@ -109,7 +109,7 @@ public: fullCv.notify_all(); } - BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 10000) + BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 1000) { for (uint32_t i = 0; i < maxRetry; i++) { if (BufferBin.pop(value)) { @@ -252,8 +252,11 @@ public: FkvState FindAndPutIfNotFound(uint64_t key, uint64_t& value) { FkvState ret = MapperBase::FindAndPutIfNotFound(key, value, [&]() { - if (HM_UNLIKELY(current_size.load() > hostVocabSize)) { - ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "host does not have enough space"); + if (HM_UNLIKELY(current_size.load() >= hostVocabSize)) { + ock::ExternalLogger::PrintLog( + ock::LogLevel::ERROR, + "host does not have enough space, current: " + std::to_string(current_size.load()) + + ", host max size: " + std::to_string(hostVocabSize)); return BeforePutFuncState::BEFORE_NO_SPACE; } return emExpendMemInfoPtr->GetNewValueToBeInserted(value); -- Gitee From db89f0016478fcf3f9bde8481d0ae8ad4a1cb934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Fri, 12 Jul 2024 18:05:31 +0800 Subject: [PATCH 30/37] =?UTF-8?q?=E4=BF=AE=E6=94=B9mxRec=20README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f6bfb828..5a2d9c03 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ mxRec作为面向互联网市场搜索推荐广告的应用使能SDK产品,对 ## 安装方式 -安装前,请参考《CANN 软件安装指南》安装CANN开发套件软件包和TensorFlow适配昇腾插件。 +安装前,请参考[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha003/softwareinst/instg/instg_0022.html)安装CANN开发套件软件包和TensorFlow适配昇腾插件。 CANN软件提供进程级环境变量设置脚本,供用户在进程中引用,以自动完成环境变量设置。用户进程结束后自动失效。可在程序启动的Shell脚本中使用如下命令设置CANN的相关环境变量,也可通过命令行执行如下命令(以root用户默认安装路径“/usr/local/Ascend”为例): ```shell @@ -65,12 +65,34 @@ bash run.sh 将pybind11和securec的压缩包放在与mxRec代码同级的opensource目录下,并且将其分别更名为pybind11-2.10.3.zip、huaweicloud-sdk-c-obs-3.23.9.zip。如果没有opensource目录,则需要在mxRec同级的目录下手动创建opensource目录,然后将pybind11和securec的压缩包放在opensource目录下。 -为了构建多个版本的whl包,编译脚本在python虚拟环境完成对应tensorflow版本的安装。用户可以根据实际情况调整编译脚本,指定tensorflow的安装路径。编译方法: +由于构建脚本需要适配内部构建工程,所以在脚本中存在适配代码,但是这些代码可能对于用户来说不需要,所以在编译之前需要做如下处理: + +在build目录中存在build_tf1.sh和build_tf2.sh,其中分别存在如下代码: +```shell +# 配置tf1路径 +source /opt/buildtools/tf1_env/bin/activate +tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core +deactivate tf1_env +``` +```shell +# 配置tf2路径 +source /opt/buildtools/tf2_env/bin/activate +tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow +deactivate tf2_env +``` + +可以看到,上述代码中都有激活Python虚拟环境的步骤,因此用户有两种选择: + +1. 根据需要在/opt/buildtools/目录下(没有此目录需要先创建)创建tf1_env和tf2_env两个Python虚拟环境,并在虚拟环境中安装对应版本的Tensorflow +2. 将source /opt/buildtools/tf1_env/bin/activate和deactivate tf1_env注释掉或者删除或者source /opt/buildtools/tf2_env/bin/activate和deactivate tf2_env注释掉或者删除 + + +编译方法: 进入mxRec代码目录: -- setup.py:执行脚本setup.py,比如:**python3.7 setup.py**完成tf1和tf2版本whl包的构建和打包,构建成功后,whl包在build/mindxsdk-mxrec/目录下,其中tf1_whl和tf2_whl目录下存在对应的whl包。执行脚本前,请参考build/build_tf1.sh、build/build_tf2.sh创建对应的虚拟环境,在虚拟环境中完成对应tensorflow版本的安装,并修改对应的激活命令。 -- setup_tf1.py:执行脚本setup_tf1.py,比如:**python3.7 setup_tf1.py bdist_wheel**完成tf1版本whl包的构建,构建成功后,whl包在build/mindxsdk-mxrec/tf1_whl子目录下。执行脚本前,请参考build/build_tf1.sh创建tf1虚拟环境,在虚拟环境中完成tensorflow 1.15.0版本的安装,并修改对应的激活命令。 -- setup_tf2.py:执行脚本setup_tf2.py,比如:**python3.7 setup_tf2.py bdist_wheel**完成tf2版本whl包的构建,构建成功后,whl包在build/mindxsdk-mxrec/tf2_whl子目录下。执行脚本前,请参考build/build_tf2.sh创建tf2虚拟环境,在虚拟环境中完成tensorflow 2.6.5版本的安装,并修改对应的激活命令。 +- setup.py:此脚本供内部使用,用于同时构建tf1和tf2的mxRec包,用户通常只需要其中一个,所以建议使用下面两个脚本构建。 +- setup_tf1.py:执行脚本setup_tf1.py,比如:**python3.7 setup_tf1.py bdist_wheel**完成tf1版本whl包的构建,构建成功后,whl包在build/mindxsdk-mxrec/tf1_whl子目录下。 +- setup_tf2.py:执行脚本setup_tf2.py,比如:**python3.7 setup_tf2.py bdist_wheel**完成tf2版本whl包的构建,构建成功后,whl包在build/mindxsdk-mxrec/tf2_whl子目录下。 如需使用动态扩容功能,进入“./cust_op/cust_op_by_addr”目录中。参考以下命令编译并安装动态扩容算子包。 ```shell -- Gitee From c2d469d400a520846808d08ad7cb1016c0e462ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Fri, 12 Jul 2024 18:20:22 +0800 Subject: [PATCH 31/37] =?UTF-8?q?Little=20demo=E6=A8=A1=E5=9E=8Bestimator?= =?UTF-8?q?=E6=A8=A1=E5=BC=8FDDR=E4=BF=9D=E5=AD=98=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=EF=BC=9B=E9=97=A8=E7=A6=81=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=E4=BF=AE=E6=94=B9=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/saver.py | 9 +++++++++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 10 +++++++++- tests/mx_rec/saver/test_saver.py | 6 +++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index a6362506..f7ba8f03 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -127,6 +127,15 @@ class Saver(object): save_path = save_path if save_path else self._prefix_name directory, base_name = os.path.split(save_path) + # skip save in step-0, cause host skip save in step-0 EmbeddingDDR::Save SyncLatestEmbedding + try: + step_in_name = int(base_name.split("-")[-1]) + if step_in_name == 0: + return + except ValueError as err: + raise ValueError(f"The base_name {base_name} needs to include save_step message " + f"eg: mode-100") from err + if global_step: if not isinstance(global_step, compat.integral_types): global_step = int(sess.run(global_step)) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 3eb99685..bcc3a2a5 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -499,7 +499,15 @@ void HybridMgmt::EvalTask(TaskType type) hybridMgmtBlock->IsNeedWaitSave()); std::unique_lock checkSaveLocker(saveMutex); cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); - hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); + + if (hybridMgmtBlock->pythonBatchID[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) + { + hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); + } else { + std::this_thread::sleep_for(SLEEP_MS); + continue; + } + LOG_DEBUG("wake TrainTask"); hybridMgmtBlock->DoBlock(channelId); } diff --git a/tests/mx_rec/saver/test_saver.py b/tests/mx_rec/saver/test_saver.py index bcfa0948..53066038 100644 --- a/tests/mx_rec/saver/test_saver.py +++ b/tests/mx_rec/saver/test_saver.py @@ -61,18 +61,18 @@ class TestSaver(unittest.TestCase): self.saver = Saver() with tf.compat.v1.Session(graph=self.graph) as sess: - embedding_directory = "./sparse-model/test_table/embedding" + embedding_directory = "./sparse-model-1/test_table/embedding" data_file = os.path.join(embedding_directory, "slice.data") attribute_file = os.path.join(embedding_directory, "slice.attribute") sess.run(tf.global_variables_initializer()) origin_embedding = sess.run(self.var)[[0, 1, 4, 6, 8], :] - self.saver.save(sess) + self.saver.save(sess, save_path="model-1") self.assertTrue(os.path.exists(embedding_directory), "embedding目录已创建") self.assertTrue(os.path.exists(data_file), "embedding的data文件存储成功") self.assertTrue(os.path.exists(attribute_file), "embedding的attribute文件存储成功") - tf.io.gfile.rmtree("./sparse-model") + tf.io.gfile.rmtree("./sparse-model-1") def build_graph(self): self.graph = tf.compat.v1.Graph() -- Gitee From 6775ab93f0b004a6bbe15ce0d56f58da5df35745 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Fri, 12 Jul 2024 18:35:54 +0800 Subject: [PATCH 32/37] =?UTF-8?q?=E6=8B=BC=E5=86=99=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index bcc3a2a5..737cdb1d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -500,7 +500,7 @@ void HybridMgmt::EvalTask(TaskType type) std::unique_lock checkSaveLocker(saveMutex); cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); - if (hybridMgmtBlock->pythonBatchID[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) + if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); } else { -- Gitee From cb43c6a8da89f2a25118df6d68631eec9549998d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Fri, 12 Jul 2024 19:05:56 +0800 Subject: [PATCH 33/37] =?UTF-8?q?cleancode=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 737cdb1d..93954401 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -500,9 +500,8 @@ void HybridMgmt::EvalTask(TaskType type) std::unique_lock checkSaveLocker(saveMutex); cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); - if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) - { - hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); + if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { + hybridMgmtBlockgi->Wake(TRAIN_CHANNEL_ID); } else { std::this_thread::sleep_for(SLEEP_MS); continue; -- Gitee From d2ba56b47391194d99c807df5bd8879437cb6418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E6=9C=9B?= <1244372993@qq.com> Date: Fri, 12 Jul 2024 19:20:39 +0800 Subject: [PATCH 34/37] =?UTF-8?q?=E6=8B=BC=E5=86=99=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 93954401..cab348ba 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -501,7 +501,7 @@ void HybridMgmt::EvalTask(TaskType type) cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { - hybridMgmtBlockgi->Wake(TRAIN_CHANNEL_ID); + hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); } else { std::this_thread::sleep_for(SLEEP_MS); continue; -- Gitee From 33c03cadd19b48fb11daaac4045925bd13a4236f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 15 Jul 2024 09:49:14 +0800 Subject: [PATCH 35/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E5=81=B6=E5=8F=91=E6=8A=A5=E9=94=99=E7=A9=BA=E9=97=B4?= =?UTF-8?q?=E4=B8=8D=E8=B6=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 41 ++++++++++++++++++---------- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 3eb99685..73c30e13 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -959,30 +959,43 @@ void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo& info) } // DDR -void HybridMgmt::LookUpSwapAddrs(const string& embName, const string& swapStr) +void HybridMgmt::LookUpSwapAddrs(const string& embName) { int id = 0; - std::string swapName = embName + swapStr; + std::string swapInName = embName + SWAP_IN_STR; + std::string swapOutName = embName + SWAP_OUT_STR; + vector addrs; while (isRunning && lookupAddrSuccess) { - std::vector keys = HBMSwapKeyQue[swapName].WaitAndPop(); if (!isRunning) { return; } - vector addrs; - TimeCost lookupAddrsTC; + // swap in + std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); + TimeCost lookupAddrsInTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { lookupAddrSuccess = false; throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); } - LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", embName, - swapStr, keys.size(), addrs.size(), id, lookupAddrsTC.ElapsedMS()); - HBMSwapAddrsQue[swapName].Pushv(addrs); - if (swapStr == SWAP_IN_STR) { - lookUpSwapInAddrsPushId[embName]++; - LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", embName, id, - lookUpSwapInAddrsPushId[embName]); + LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsInTC(ms):{}", embName, + SWAP_IN_STR, keys.size(), addrs.size(), id, lookupAddrsInTC.ElapsedMS()); + HBMSwapAddrsQue[swapInName].Pushv(addrs); + + lookUpSwapInAddrsPushId[embName]++; + LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", embName, id, + lookUpSwapInAddrsPushId[embName]); + + // swap out + keys = HBMSwapKeyQue[swapOutName].WaitAndPop(); + TimeCost lookupAddrsOutTC; + rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + if (rc != H_OK) { + lookupAddrSuccess = false; + throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); } + LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsOutTC(ms):{}", embName, + SWAP_OUT_STR, keys.size(), addrs.size(), id, lookupAddrsOutTC.ElapsedMS()); + HBMSwapAddrsQue[swapOutName].Pushv(addrs); id++; } } @@ -1242,9 +1255,7 @@ void HybridMgmt::InitDataPipelineForDDR(const string& embName) // 初始化lookup线程 lookUpSwapInAddrsPushId[embName]; // 此处初始化,避免多线程竞争导致计数错误 lookUpSwapInAddrsThreads.emplace_back( - std::async(std::launch::async, [=] { LookUpSwapAddrs(embName, SWAP_IN_STR); })); - lookUpSwapOutAddrsThreads.emplace_back( - std::async(std::launch::async, [=] { LookUpSwapAddrs(embName, SWAP_OUT_STR); })); + std::async(std::launch::async, [=] { LookUpSwapAddrs(embName); })); LOG_DEBUG("data pipeline for ddr init"); } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index ab34b19f..57a7ddd1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -187,7 +187,7 @@ public: void LookUpAndRemoveAddrs(const EmbTaskInfo& info); // L3Storage, synchronous - void LookUpSwapAddrs(const std::string& embName, const std::string& swapStr); // DDR, asynchronous + void LookUpSwapAddrs(const std::string& embName); // DDR, asynchronous void EmbeddingTask(); -- Gitee From 5592a8e616f1ca0e98873b2ac84ec94fdeb20fc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 15 Jul 2024 17:19:33 +0800 Subject: [PATCH 36/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E5=81=B6=E5=8F=91=E6=8A=A5=E9=94=99=E7=A9=BA=E9=97=B4?= =?UTF-8?q?=E4=B8=8D=E8=B6=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 9921fe27..f8ad9216 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -971,7 +971,7 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName) int id = 0; std::string swapInName = embName + SWAP_IN_STR; std::string swapOutName = embName + SWAP_OUT_STR; - vector addrs; + std::vector addrs; while (isRunning && lookupAddrSuccess) { if (!isRunning) { return; -- Gitee From 45f3fe4365341c6024d52c589c77bb9af41e5248 Mon Sep 17 00:00:00 2001 From: penghuiyang <1060916628@qq.com> Date: Mon, 15 Jul 2024 22:20:31 +0800 Subject: [PATCH 37/37] =?UTF-8?q?=E3=80=90FIX=E3=80=91LazyAdam=E8=9E=8D?= =?UTF-8?q?=E5=90=88=E7=AE=97=E5=AD=90=E6=96=B0=E7=89=88=E6=9C=ACCANN?= =?UTF-8?q?=E7=BC=96=E8=AF=91=E5=A4=B1=E8=B4=A5=E4=BF=AE=E6=94=B9=EF=BC=9B?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E9=80=BB=E8=BE=91=E5=90=8C=E6=AD=A5py?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../aclnn_lazy_adam_test/scripts/gen_data.py | 2 +- cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py index 6e07f836..6e8c9251 100644 --- a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py @@ -121,7 +121,7 @@ def _gen_golden_data(): update_v = beta2 * old_v_slice + (1 - beta2) * np.square(gradient) out_v = _scatter_nd_update(input_v, indices, update_v) - denominator_slice = np.sqrt(update_v) + epsilon + denominator_slice = np.sqrt(np.abs(update_v)) + epsilon update_var = np.divide(-lr * update_m, denominator_slice) out_var = _scatter_nd_add(input_var, indices, update_var) diff --git a/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp b/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp index 76164e50..e0ad8e45 100644 --- a/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp +++ b/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp @@ -176,6 +176,7 @@ private: this->updateV = localVSlice + this->updateV; // 计算Var + Abs(this->updateV, this->updateV, row * this->dim2); Sqrt(this->updateVar, this->updateV, row * this->dim2); Adds(this->updateVar, this->updateVar, this->epsilon, row * this->dim2); Muls(this->temp, this->updateM, -this->lr, row * this->dim2); @@ -233,5 +234,10 @@ extern "C" __global__ __aicore__ void lazy_adam(GM_ADDR gradient, GM_ADDR indice tiling_data.row, tiling_data.indicesAllocSize, tiling_data.otherAllocSize, tiling_data.batch, tiling_data.loopCount, tiling_data.rowLeft, tiling_data.loopCountTail, tiling_data.rowLeftTail, tiling_data.coreNum); +#ifdef KERNEL_TASK_TYPE_DEFAULT + // Set kernel type with new versions of CANN to avoid matmul error during compiling. + // In previous versions of CANN, avoid matmul error by using '#ifndef __GET_CODE_CHANNEL__'. + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); +#endif op32.Process(); } \ No newline at end of file -- Gitee