29 Star 141 Fork 228

Ascend/mindsdk-referenceapps

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
TestAscendIndexIVFSQT.cpp 5.26 KB
一键复制 编辑 原始数据 按行查看 历史
/*
* Copyright(C) 2020. Huawei Technologies Co.,Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// 需要生成ivfsqt算子(-d 256 -c 16384)+flat_at算子(-d 256 -c 16384)+flat_at_int8算子(-d 256 -c 16384)
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <sys/time.h>
#include <faiss/ascend/custom/AscendIndexIVFSQT.h>
#include <faiss/ascend/AscendCloner.h>
#include <faiss/index_io.h>
using namespace std;
/**
 * L2-normalizes, in place, n row vectors of length dim stored contiguously in data.
 *
 * Each row i occupies data[i * dim] .. data[i * dim + dim - 1] and is divided by its
 * Euclidean norm. A row whose norm is smaller than FLT_EPSILON cannot be normalized:
 * an error is reported on std::cerr and the row is left unchanged instead of being
 * divided by (near) zero, which would fill it with NaN/Inf values.
 *
 * @param data  Buffer of at least n * dim floats; modified in place.
 * @param n     Number of row vectors.
 * @param dim   Number of components per row vector.
 */
void Norm(float *data, size_t n, size_t dim)
{
#pragma omp parallel for if(n > 100)
    for (size_t i = 0; i < n; ++i) {
        float *row = data + i * dim;
        float l2norm = 0;
        for (size_t j = 0; j < dim; ++j) {
            l2norm += row[j] * row[j];
        }
        l2norm = std::sqrt(l2norm);
        if (std::fabs(l2norm) < FLT_EPSILON) {
            std::cerr << "Error: Invalid l2norm value." << std::endl;
            // Skip the division: dividing by a zero norm would poison the row with NaN/Inf.
            continue;
        }
        for (size_t j = 0; j < dim; ++j) {
            row[j] = row[j] / l2norm;
        }
    }
}
// Test configuration shared by main() and TestAscendToCpu().

// Dimension of the generated vectors; passed to Norm() and as the first dimension argument of the index constructor.
const size_t DIM_IN = 256;
// Second dimension argument of the AscendIndexIVFSQT constructor.
const size_t DIM_OUT = 64;
// Value passed to AscendIndexIVFSQT::setThreshold().
const float THRESHOLD = 1.5;
// Number of vectors generated and added to the index.
const size_t ADD_TOTAL = 6000000;
// Value passed to AscendIndexIVFSQT::setFuzzyK().
const int FUZZYK = 3;
// Number of results requested per query in search().
const size_t K = 100;
// Number of query vectors passed to search(); the queries are taken from the start of the generated data.
const size_t QUERY_NUM = 500000;
// Device id list used for the index configuration and for index_cpu_to_ascend().
const std::vector<int> DEVICE = {0};
// Largest value FastRand() can return (15 bits).
const int FAST_RAND_MAX = 0x7FFF;
// Generator state; every FastRand() call advances it.
unsigned int g_seed = 5678;
/**
 * Small linear congruential pseudo-random generator.
 *
 * Advances g_seed as g_seed * 214013 + 2531011 (unsigned wrap-around) and returns
 * bits 16..30 of the new state, i.e. a value in [0, FAST_RAND_MAX].
 */
inline int FastRand(void)
{
    constexpr int kMultiplier = 214013;
    constexpr int kIncrement = 2531011;
    constexpr int kShift = 16;
    g_seed = kMultiplier * g_seed + kIncrement;
    const unsigned int upperBits = g_seed >> kShift;
    return upperBits & FAST_RAND_MAX;
}
/**
 * Returns the current time reported by gettimeofday(), expressed in milliseconds.
 *
 * The seconds part is scaled by 1e3 and the microseconds part by 1e-3, so the
 * result keeps sub-millisecond resolution as a fractional value.
 */
inline double GetMillisecs()
{
    struct timeval now {};
    gettimeofday(&now, nullptr);
    const double fromSeconds = now.tv_sec * 1e3;
    const double fromMicros = now.tv_usec * 1e-3;
    return fromSeconds + fromMicros;
}
/**
 * Round-trips an Ascend IVFSQT index through its CPU representation and searches the reloaded copy.
 *
 * Steps: convert index with index_ascend_to_cpu, write it to "IVFSQT.faiss", read the file back,
 * convert it with index_cpu_to_ascend on DEVICE, then search QUERY_NUM queries for K results each.
 * All intermediate indexes are owned by std::unique_ptr so they are released even when a step throws.
 *
 * @param index  Ascend index to export.
 * @param data   Query vectors; search() reads QUERY_NUM vectors from its start.
 * @param dist   Receives the search distances (main() sizes it as QUERY_NUM * K).
 * @param label  Receives the search labels (main() sizes it as QUERY_NUM * K).
 * @throws faiss::FaissException if the reloaded index is not an AscendIndexIVFSQT.
 */
void TestAscendToCpu(faiss::ascend::AscendIndexIVFSQT *index, std::vector<float>& data,
    std::vector<float>& dist, std::vector<faiss::idx_t>& label)
{
    const char *globalFileName = "IVFSQT.faiss";

    std::cout << "Test index_ascend_to_cpu For IVFSQT, result save to " << globalFileName << std::endl;
    std::unique_ptr<faiss::Index> cpuIndex(faiss::ascend::index_ascend_to_cpu(index));
    faiss::write_index(cpuIndex.get(), globalFileName);
    // Release the CPU copy before loading the file again, as the original code did.
    cpuIndex.reset();
    std::cout << "Test index_ascend_to_cpu For IVFSQT finished" << std::endl;

    std::cout << "Test index_cpu_to_ascend For IVFSQT, read from " << globalFileName << std::endl;
    std::unique_ptr<faiss::Index> initIndex(faiss::read_index(globalFileName));
    std::unique_ptr<faiss::Index> ascendIndex(faiss::ascend::index_cpu_to_ascend(DEVICE, initIndex.get()));
    auto *realIndex = dynamic_cast<faiss::ascend::AscendIndexIVFSQT *>(ascendIndex.get());
    if (realIndex == nullptr) {
        // dynamic_cast yields nullptr for an unexpected index type; fail loudly instead of dereferencing it.
        throw faiss::FaissException("index_cpu_to_ascend did not return an AscendIndexIVFSQT");
    }
    std::cout << "Test index_cpu_to_ascend For IVFSQT, finished" << std::endl;

    realIndex->search(QUERY_NUM, data.data(), K, dist.data(), label.data());
    // ascendIndex is destroyed before initIndex (reverse declaration order), matching the original delete order.
}
int main(int argc, char **argv)
{
const size_t nlist = 16384;
const size_t trainTotal = 2000000;
const int niter = 16;
const int centroids = 256;
std::vector<float> data(DIM_IN * ADD_TOTAL);
cout << "generate data" << endl;
for (size_t i = 0; i < data.size(); i++) {
data[i] = 1.0 * FastRand() / FAST_RAND_MAX;
}
Norm(data.data(), ADD_TOTAL, DIM_IN);
faiss::ascend::AscendIndexIVFSQT *index = nullptr;
try {
cout << "index start" << endl;
faiss::ascend::AscendIndexIVFSQTConfig conf({DEVICE});
conf.cp.niter = niter;
conf.useKmeansPP = true;
conf.cp.max_points_per_centroid = centroids;
cout << "index init" << endl;
index = new faiss::ascend::AscendIndexIVFSQT(DIM_IN, DIM_OUT, nlist,
faiss::ScalarQuantizer::QuantizerType::QT_8bit, faiss::MetricType::METRIC_INNER_PRODUCT, conf);
index->verbose = true;
index->setFuzzyK(FUZZYK);
index->setThreshold(THRESHOLD);
cout << "train start" << endl;
index->train(trainTotal, data.data());
cout << "add start" << endl;
index->add(ADD_TOTAL, data.data());
cout << "update start" << endl;
index->update();
cout << "search start" << endl;
std::vector<float> dist(QUERY_NUM * K, 0.0);
std::vector<faiss::idx_t> label(QUERY_NUM * K, 0);
double start = GetMillisecs();
index->search(QUERY_NUM, data.data(), K, dist.data(), label.data());
double end = GetMillisecs();
cout << "search finished successfully" << endl;
cout << "search time cost:" << end - start << " ms" << endl;
TestAscendToCpu(index, data, dist, label);
delete index;
} catch (faiss::FaissException& e) {
cout << "Exception caught!" << e.what() << endl;
if (index == nullptr) {
delete index;
}
return -1;
}
return 0;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/ascend/mindsdk-referenceapps.git
git@gitee.com:ascend/mindsdk-referenceapps.git
ascend
mindsdk-referenceapps
mindsdk-referenceapps
master

搜索帮助