From f44f282dc0323cdab1d35edc124ce880d1051e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9D=B0?= <3070524466@qq.com> Date: Thu, 2 Nov 2023 11:00:24 +0000 Subject: [PATCH] update optimal_knn.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 张杰 <3070524466@qq.com> --- optimal_knn.py | 64 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/optimal_knn.py b/optimal_knn.py index 53200c3..001ed8f 100644 --- a/optimal_knn.py +++ b/optimal_knn.py @@ -1,15 +1,49 @@ -# TODO: 导入必要的库和模块 - -# TODO: 加载数字数据集 - -# TODO: 将数据集划分为训练集和测试集 - -# TODO: 初始化变量以存储最佳准确率,相应的k值和最佳knn模型 - -# TODO: 初始化一个列表以存储每个k值的准确率 - -# TODO: 尝试从1到40的k值,对于每个k值,训练knn模型,保存最佳准确率,k值和knn模型 - -# TODO: 将最佳KNN模型保存到二进制文件 - -# TODO: 打印最佳准确率和相应的k值 \ No newline at end of file +# 导入必要的库和模块 +import sklearn +import tqdm +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.neighbors import KNeighborsClassifier +from sklearn.metrics import accuracy_score +import pickle +import matplotlib.pyplot as plt +import numpy as np +#: 加载数字数据集 +digits = datasets.load_digits() +X = digits.data +y = digits.target +#: 将数据集划分为训练集和测试集 +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) +#: 初始化变量以存储最佳准确率,相应的k值和最佳knn模型 +best_acc_accuracy = 0.0 +best_k = 0 +best_model = None +#: 初始化一个列表以存储每个k值的准确率 +arr_accuracy = [] +#: 尝试从1到40的k值,对于每个k值,训练knn模型,保存最佳准确率,k值和knn模型 +for k in tqdm.tqdm(range(1,41)): + knn = KNeighborsClassifier(n_neighbors=k) + knn.fit(X_train, y_train) + acc = knn.score(X_test,y_test) + arr_accuracy.append(acc) + if acc > best_acc_accuracy: + best_acc_accuracy = acc + best_k = k + best_model = knn +#: 将最佳KNN模型保存到二进制文件 +with open('knn_model.pkl', 'wb') as file: + pickle.dump(knn, file) +# : 打印最佳准确率和相应的k值 +print('best acc is:',best_acc_accuracy) +print('best k is:',best_k) +k_value = range(1,41) +Accuracy = arr_accuracy +plt.plot(k_value,Accuracy) +plt.title('Accuracy of different k_value') +plt.text(6.5,0.989,'k=6,Accuracy=0.99',fontsize=12,color='red') +plt.axvline(best_k,color='r') +plt.xlabel('k_Value') +plt.ylabel('Accuracy') +plt.savefig('accuracy_polt.pdf') +plt.show() \ No newline at end of file -- Gitee