1 Star 0 Fork 31

Hakula / prml-21-spring

Create your Gitee Account
Explore and code with more than 6 million developers,Free private repositories !:)
Sign up
Clone or download
README.md 2.46 KB
Copy Edit Web IDE Raw Blame History
WillQvQ authored 2021-03-17 19:34 . more details about method fit

作业-1: K 近邻

作业简述

K近邻(k-Nearest Neighbor,KNN) 算法是一个简单的机器学习算法。你需要按照要求实现 K 近邻算法的代码,并进行一些探究实验。

提交内容

  • README.md: 课程作业报告
  • source.py: 课程作业源码

提交方法参见提交指南

具体要求

你需要参考 handout/ 文件夹下的 source.py 来实现一个 KNN 类,使得其它程序可以调用该类的 fit 方法进行模型的训练,使用 predict 方法来进行预测。fit 方法可以包含数据处理,K 值选择等过程

class KNN:

    def __init__(self):
        pass

    def fit(self, train_data, train_label):
        pass

    def predict(self, test_data):
        pass

其中,fit 方法的参数 train_datatrain_label 均为 numpy.ndarray 类型,大小分别为 (N, K)(N,) ,其中 N 为数据的条数,K 为数据的维度。 predict 方法的参数 test_data 类型和 train_data 一致,输出的类型与train_label 的类型一致。

你可以使用 handout/tester_demo.py 来对代码进行简单的测试。

实验探究(80%)

  1. 使用 np.random.multivariate_normal 生成若干个(例如:3个)符合二维高斯分布的集合,给每个集合配上一个标签后混合为 datalabel
  2. 将数据随机划分为 80% 的训练集和 20% 的测试集,共有 train_datatrain_labeltest_datatest_label 四个部分。
  3. 使用自己编写的 KNN 模型,在生成的数据集上进行训练和测试,使用图表分析实验结果,你可能需要使用 matplotlib 库。
  4. 修改数据集的属性(例如:不同高斯分布之间的距离),进行探究性实验。

实验部分的代码也需要写在 source.py 中,这部分的分数由助教根据你的代码和报告给出。

自动测试(20%)

我们会使用自动化的工具来测试你写在 source.py 中的 KNN 类,如果你的代码不能正确运行,你将失去这 20% 的分数。我们会提供多组用于自动测试的 train_datatrain_labeltest_datatest_label ,数据的维度各不相同。模型的准确率不做为评分指标。

测试环境如下:

conda create -n assignment-1 python=3.8 -y
conda activate assignment-1
pip install numpy
pip install matplotlib

模式识别与机器学习 / 复旦大学 / 2021年春

Comment ( 0 )

Sign in for post a comment

1
https://gitee.com/hakula139/prml-21-spring.git
git@gitee.com:hakula139/prml-21-spring.git
hakula139
prml-21-spring
prml-21-spring
master

Search

105716 1d94204e 1850385 105716 2d26be5c 1850385