1 Star 0 Fork 245

陈庆宇/faiss_dog_cat_question

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
FaissKNeighbors.py 2.39 KB
一键复制 编辑 原始数据 按行查看 历史
陈庆宇 提交于 2024-10-08 11:08 . update FaissKNeighbors.py.
import numpy as np # NumPy是一个用于科学计算的基础包,用于处理大型多维数组和矩阵
import faiss # FAISS库用于高效的相似度搜索和稠密向量的聚类
# 定义FaissKNeighbors类,用于执行基于FAISS的K近邻搜索
class FaissKNeighbors:
# 类初始化函数:初始化k值,FAISS资源对象res,以及用于存储数据的索引
def __init__(self, k=1, res=None):
self.index = None # 用于存储训练数据的索引
self.y = None # 用于存储训练数据的标签
self.k = k # 最近邻个数
self.res = res # FAISS GPU资源对象
# 训练函数:将训练数据加入到FAISS索引中
def fit(self, X, y):
# 初始化 self.index 为一个FAISS索引: IndexFlatL2, 该索引使用欧氏距离进行搜索
if self.res is not None:
self.index = faiss.IndexFlatL2(X.shape[1])
self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)
else:
self.index = faiss.IndexFlatL2(X.shape[1])
# 将训练数据加入到FAISS索引中
self.index.add(X.astype(np.float32))
# 初始化 self.y 为传入的 y
self.y = np.array(y)
# 预测函数:对新的数据集X进行分类预测
def predict(self, X):
# 搜索X中每个向量的k个最近邻
distances, indices = self.index.search(X.astype(np.float32), self.k)
# 根据索引获得最近邻的标签
votes = [self.y[i] for i in indices]
# 通过投票机制得出最终预测的标签
predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
return predictions
# 评分函数:计算预测准确率
def score(self, X, y_true):
# 预测
predictions = self.predict(X)
# 计算准确率
accuracy = np.mean(predictions == y_true)
return accuracy
# 使用示例
if __name__ == "__main__":
# 创建一些随机数据作为示例
X_train = np.random.rand(100, 64).astype('float32')
y_train = np.random.randint(0, 5, 100)
X_test = np.random.rand(10, 64).astype('float32')
y_test = np.random.randint(0, 5, 10)
# 创建模型实例
knn = FaissKNeighbors(k=3)
# 训练模型
knn.fit(X_train, y_train)
# 预测
predictions = knn.predict(X_test)
# 计算准确率
accuracy = knn.score(X_test, y_test)
print("Accuracy:", accuracy)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/qingyu-chen/faiss_dog_cat_question.git
git@gitee.com:qingyu-chen/faiss_dog_cat_question.git
qingyu-chen
faiss_dog_cat_question
faiss_dog_cat_question
main

搜索帮助