From 15d8469aa10ba3cedc4673c6b4aef97b442ff104 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=98=8E=E8=BE=89123456?= <3165095743@qq.com>
Date: Tue, 10 Sep 2024 13:59:29 +0000
Subject: [PATCH] update optimal_knn.py.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 明辉123456 <3165095743@qq.com>
---
 optimal_knn.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/optimal_knn.py b/optimal_knn.py
index 53200c3..0d89b1c 100644
--- a/optimal_knn.py
+++ b/optimal_knn.py
@@ -1,15 +1,61 @@
+import matplotlib.pyplot as plt
+from sklearn.datasets import load_digits
+from sklearn.model_selection import train_test_split
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.metrics import accuracy_score
+from tqdm import tqdm
+import pickle
 # TODO: 导入必要的库和模块
 
 # TODO: 加载数字数据集
+digits = load_digits()
+X = digits.data
+y = digits.target
 
 # TODO: 将数据集划分为训练集和测试集
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)  # 30% held out for testing; fixed seed
 
 # TODO: 初始化变量以存储最佳准确率,相应的k值和最佳knn模型
+best_accuracy = 0
+best_k = 1
+best_model = None
 
 # TODO: 初始化一个列表以存储每个k值的准确率
+accuracies = []
 
 # TODO: 尝试从1到40的k值,对于每个k值,训练knn模型,保存最佳准确率,k值和knn模型
+for k in tqdm(range(1, 41), desc="训练KNN模型"):
+    model = KNeighborsClassifier(n_neighbors=k)
+    model.fit(X_train, y_train)
+    y_pred = model.predict(X_test)
+    accuracy = accuracy_score(y_test, y_pred)  # NOTE(review): k is selected on the test split, so best_accuracy is an optimistic estimate
+
+    # Store the accuracy for each k value
+    accuracies.append(accuracy)
+
+    # Update the best model
+    if accuracy > best_accuracy:  # strict '>' keeps the smallest k when several k tie
+        best_accuracy = accuracy
+        best_k = k
+        best_model = model
 
 # TODO: 将最佳KNN模型保存到二进制文件
+with open('best_knn_model.pkl', 'wb') as f:
+    pickle.dump(best_model, f)  # best_model is still None if no k scored above 0
+
+# TODO: Print the best accuracy and the corresponding k value
+print(f"最佳k值: {best_k}, 最佳准确率: {best_accuracy}")
 
-# TODO: 打印最佳准确率和相应的k值
\ No newline at end of file
+# TODO: Save the accuracy plot to a PDF file
+plt.figure(figsize=(10, 6))
+plt.plot(range(1, 41), accuracies, marker='o', linestyle='-', color='b', label='Accuracy')  # x range must match the loop above
+plt.axvline(x=best_k, color='r', linestyle='--', label=f'Best k value: {best_k}')
+plt.text(best_k, best_accuracy, f'k={best_k}\nAccuracy={best_accuracy:.2f}',
+         verticalalignment='bottom', horizontalalignment='right', color='red')
+plt.xlabel('K Value')
+plt.ylabel('Accuracy')
+plt.title('Accuracy of different k values')
+plt.legend()
+plt.grid(True)
+plt.savefig('accuracy_plot.pdf')
+plt.show()
-- 
Gitee