31 Star 146 Fork 232

Ascend/mindsdk-referenceapps

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
TestAscendIndexVStar.cpp 4.42 KB
一键复制 编辑 原始数据 按行查看 历史
wwwttt001 提交于 7个月前 . change name
/*
* Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <cmath>
#include <random>
#include <iostream>
#include <cfloat>
#include <gtest/gtest.h>
#include <faiss/ascend/AscendIndexVStar.h>
#include <cstring>
#include <sys/time.h>
#include <cstdlib>
#include <memory>
using namespace faiss::ascend;
namespace {
const int MILLI_SECOND = 1000;
inline double GetMillisecs()
{
struct timeval tv = {0, 0};
gettimeofday(&tv, nullptr);
return tv.tv_sec * 1e3 + tv.tv_usec * 1e-3;
}
void Generate(size_t ntotal, std::vector<float> &data, int seed = 5678)
{
int maxValue = 255;
int offset = 128;
std::default_random_engine e(seed);
std::uniform_real_distribution<float> rCode(0.0f, 1.0f);
data.resize(ntotal);
for (size_t i = 0; i < ntotal; ++i) {
data[i] = static_cast<float>(maxValue * rCode(e) - offset);
}
}
void Norm(std::vector<float> &data, int dim)
{
float square = 0.0;
int nTotal = (dim == 0) ? 0 : static_cast<int>(data.size() / dim);
for (int i = 0; i < nTotal; ++i) {
square = 0.0;
for (int j = 0; j < dim; ++j) {
square += pow(data[i * dim + j], 2); // 2是先求平方,后续开根
}
square = sqrt(square);
if (fabs(square) < FLT_EPSILON) {
std::cerr << "Error: Invalid square value." << std::endl;
return;
}
for (int j = 0; j < dim; ++j) {
data[static_cast<size_t>(i) * dim + j] /= square;
}
}
}
/**
* AKMode需要提前生成算子和码本
* 码本和算子参数根据实际情况调整, dim nlistL1 subDimL1 要与创建的索引一致
* 算子:python3 vstar_generate_models.py --dim 1024 --nlist1 1024 --subDimL1 32
* 码本:python3 vstar_train_codebook.py --dataPath {实际base数据路径} --dim 1024 --codebookPath {实际码本输出路径}
--nListL1 1024 --subDimL1 32 --device 0
*/
TEST(TestAscendIndexVstar, Test_Search_Func)
{
int dim = 1024;
size_t ntotal = 1e5;
int nlist = 256;
int subSpaceDim = 128;
std::vector<int> devices = {0};
bool verbose = false;
AscendIndexVstarInitParams aParams(dim, subSpaceDim, nlist, devices, verbose);
auto index = std::make_shared<AscendIndexVStar>(aParams);
// 添加码本 需要提前生成好码本路径
std::string codebook = "/home/work/codebook_256_1024_128/codebook_l1_l2.bin";
auto ret = index->AddCodeBooksByPath(codebook);
EXPECT_EQ(ret, 0);
// 生成base底库数据
std::vector<float> data(ntotal);
Generate(ntotal * dim, data);
// 标准化
Norm(data, dim);
// add底库
index->Add(data);
size_t total = 0;
index->GetNTotal(total);
EXPECT_EQ(total, ntotal);
// search检索
int topk = 100;
int warmUpTimes = 10;
size_t nq = 9000;
std::vector<float> distsWarm(nq * topk);
std::vector<int64_t> labelsWarm(nq * topk);
// warm up
for (int i = 0; i < warmUpTimes; ++i) {
AscendIndexSearchParams searchParamsWarm {100, data, topk, distsWarm, labelsWarm};
index->Search(searchParamsWarm);
}
// search
std::vector<size_t> searchNum = {1, 8, 16, 32, 64, 128, 256};
int loopTimes = 100;
for (auto n : searchNum) {
std::vector<float> queryData(data.begin(), data.begin() + n * dim);
std::vector<float> dists(n * topk, 0);
std::vector<int64_t> labels(n * topk, 0);
double ts = GetMillisecs();
for (int i = 0; i < loopTimes; ++i) {
AscendIndexSearchParams searchParams {n, queryData, topk, dists, labels};
index->Search(searchParams);
}
double te = GetMillisecs();
printf("base:%zu, dim:%d, search num:%zu, QPS:%.4f\n",
ntotal, dim, n, MILLI_SECOND * n * loopTimes / (te - ts));
}
}
} // namespace
int main(int argc, char **argv)
{
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/ascend/mindsdk-referenceapps.git
git@gitee.com:ascend/mindsdk-referenceapps.git
ascend
mindsdk-referenceapps
mindsdk-referenceapps
master

搜索帮助