diff --git a/optimal_knn.py b/optimal_knn.py index 53200c3119acb21350831b431d4ccb8cb3b2adc7..bbca8e6e4318ae5d7a47170ad28d2ecdae406b61 100644 --- a/optimal_knn.py +++ b/optimal_knn.py @@ -1,15 +1,48 @@ -# TODO: 导入必要的库和模块 +import numpy as np +from sklearn.datasets import load_digits +from sklearn.model_selection import train_test_split +from sklearn.neighbors import KNeighborsClassifier +from sklearn.metrics import accuracy_score +import pickle +import matplotlib.pyplot as plt +import gradio as gr -# TODO: 加载数字数据集 +# 加载保存的KNN模型 +with open('best_knn_model.pkl', 'rb') as f: + knn = pickle.load(f) -# TODO: 将数据集划分为训练集和测试集 +# 定义预测函数 +def predict(image): + # 将输入的图像转换为与训练数据相同的格式 + image = np.array(image).reshape(1, -1) + # 使用预训练的KNN模型进行预测 + prediction = knn.predict(image) + return prediction[0] -# TODO: 初始化变量以存储最佳准确率,相应的k值和最佳knn模型 +# 创建Gradio接口 +demo = gr.Interface( + predict, + gr.Image(label="输入图像"), + gr.Label(label="预测结果"), + title="手写数字识别", + description="输入一个手写数字图像,模型将预测出相应的数字。" +) -# TODO: 初始化一个列表以存储每个k值的准确率 +# 启动Gradio接口 +if __name__ == "__main__": + demo.launch() +# Print the best accuracy and corresponding k value +print(f"\nBest accuracy: {best_accuracy:.4f}") +print(f"Best k value: {best_k}") -# TODO: 尝试从1到40的k值,对于每个k值,训练knn模型,保存最佳准确率,k值和knn模型 +# Plot the relationship between k values and accuracy +plt.figure(figsize=(10, 6)) +plt.plot(range(1, 41), accuracies, marker='o') +plt.title('Relationship between K value and Accuracy') +plt.xlabel('K value') +plt.ylabel('Accuracy') +plt.grid(True) +plt.savefig('accuracy_plot.pdf') +print("Accuracy plot saved as 'accuracy_plot.pdf'") -# TODO: 将最佳KNN模型保存到二进制文件 - -# TODO: 打印最佳准确率和相应的k值 \ No newline at end of file +print("\nProcess completed successfully!") \ No newline at end of file