## Cynhard85 / MachineLearningTutorial

Cynhard 提交于 2018-07-14 10:09 . add 局部加权线性回归
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 准备"
]
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"source": [
"# 构造数据"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x2251479da58>"
]
},
"output_type": "display_data"
}
],
"source": [
"# Synthetic 1-D regression data: a linear trend plus a sinusoid plus\n",
"# Gaussian noise.\n",
"m = 100  # number of training samples\n",
"\n",
"np.random.seed(42)  # fixed seed so every run draws the same data\n",
"# Inputs are drawn first and sorted in ascending order; the noise is drawn\n",
"# afterwards, so the random stream is consumed in a fixed order.\n",
"x_values = np.sort(np.random.rand(m, 1) * 100, axis=0)\n",
"y = 7 * np.sin(0.12 * x_values) + x_values + 2 * np.random.randn(m, 1)\n",
"\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(x_values, y, \"b.\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Design matrix: a leading column of ones (intercept term) beside the inputs.\n",
"X = np.hstack([np.ones((m, 1)), x_values])"
]
},
{
"cell_type": "markdown",
"source": [
"# k 值对权重的影响"
]
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x2251479df60>"
]
},
"output_type": "display_data"
}
],
"source": [
"# Plot the kernel weight of every training sample relative to the middle\n",
"# sample X[m//2], one panel per bandwidth k.\n",
"plt.figure(figsize=(12,8))\n",
"\n",
"ks = [100, 5, 1]\n",
"center = X[m//2]\n",
"# Squared distances do not depend on k, so compute them once for all panels.\n",
"sq_dist = np.sum(np.square(X - center), axis=1)\n",
"\n",
"for panel, bandwidth in enumerate(ks, start=1):\n",
"    ws = np.exp(- sq_dist / (2 * bandwidth**2))\n",
"\n",
"    plt.subplot(len(ks), 1, panel)\n",
"    plt.plot(x_values, ws)\n",
"    plt.title(\"k={}\".format(bandwidth))\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"source": [
"# 预测"
]
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[48.041602]\n"
]
}
],
"source": [
"def calculate_theta(x_test, k, X_train=None, y_train=None):\n",
"    \"\"\"Fit locally weighted linear regression around one query point.\n",
"\n",
"    Every training row i receives the Gaussian-kernel weight\n",
"    exp(-||X[i] - x_test||^2 / (2 * k^2)); theta then solves the weighted\n",
"    normal equations (X^T W X) theta = X^T W y.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x_test : array-like\n",
"        Query point given as a feature row laid out like the rows of the\n",
"        design matrix, i.e. including the leading 1 of the intercept column.\n",
"    k : float\n",
"        Kernel bandwidth. Must not be zero.\n",
"    X_train, y_train : array-like, optional\n",
"        Design matrix and targets to fit on. When omitted, the notebook-level\n",
"        ``X`` and ``y`` are used.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        Fitted coefficients, one row per column of the design matrix.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If ``k`` is zero.\n",
"    numpy.linalg.LinAlgError\n",
"        If the weighted normal matrix is singular, for example when every\n",
"        weight underflows to zero.\n",
"    \"\"\"\n",
"    if k == 0:\n",
"        raise ValueError(\"k must be non-zero\")\n",
"\n",
"    features = X if X_train is None else np.asarray(X_train, dtype=float)\n",
"    targets = y if y_train is None else np.asarray(y_train, dtype=float)\n",
"\n",
"    # Kernel weights of all rows at once; these are the diagonal entries of W.\n",
"    sq_dist = np.sum(np.square(features - x_test), axis=1)\n",
"    weights = np.exp(- sq_dist / (2 * k**2))\n",
"\n",
"    # Scaling the columns of X^T by the weights equals X^T W, without\n",
"    # building the dense diagonal matrix.\n",
"    xtw = features.T * weights\n",
"\n",
"    # Solve the normal equations directly instead of forming an explicit inverse.\n",
"    theta = np.linalg.solve(xtw.dot(features), xtw.dot(targets))\n",
"\n",
"    return theta\n",
"\n",
"def predict(x_test, k):\n",
"    \"\"\"Predict the target at a single raw input value.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x_test : float\n",
"        Raw input value to predict at (without the intercept term).\n",
"    k : float\n",
"        Kernel bandwidth forwarded to ``calculate_theta``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        The local linear prediction ``theta[0] + x_test * theta[1]``.\n",
"    \"\"\"\n",
"    # calculate_theta compares the query with rows of X, which are laid out\n",
"    # as [1, x]. Passing the bare scalar would broadcast it against the\n",
"    # intercept column too and add (1 - x_test)^2 to every squared distance,\n",
"    # which drives all weights to zero for small k.\n",
"    x_query = np.append(1.0, x_test)\n",
"    theta = calculate_theta(x_query, k)\n",
"    y_pred = theta[0] + x_test * theta[1]\n",
"    return y_pred\n",
"\n",
"print(predict(50, 5))"
]
},
{
"cell_type": "markdown",
"source": [
"# 过拟合与欠拟合"
]
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x22514801dd8>"
]
},
"output_type": "display_data"
}
],
"source": [
"# Evenly spaced query inputs from 0 to 100, both endpoints included.\n",
"test_count = 50\n",
"x_test_values = np.linspace(start=0, stop=100, num=test_count)\n",
"\n",
"def lwlr(x_test_values, k):\n",
"    \"\"\"Plot the training data together with one local linear fit per query.\n",
"\n",
"    For every query input a separate theta is fitted with\n",
"    ``calculate_theta`` and evaluated at the two ends of a short segment\n",
"    centred on the query. All segment ends are joined into a single red\n",
"    polyline drawn on the current axes, over the blue training points.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x_test_values : array-like\n",
"        1-D collection of query inputs.\n",
"    k : float\n",
"        Kernel bandwidth; also shown in the plot title.\n",
"    \"\"\"\n",
"    x_test_values = np.asarray(x_test_values, dtype=float).ravel()\n",
"    n_tests = x_test_values.size\n",
"\n",
"    # Half-width of each segment: the span of the queries shared evenly among\n",
"    # them. It is derived from the argument itself; the notebook-level m and\n",
"    # test_count only gave the right value because m equals the width of the\n",
"    # x range and test_count equals the number of queries.\n",
"    if n_tests > 0:\n",
"        half_width = (x_test_values.max() - x_test_values.min()) / n_tests / 2\n",
"    else:\n",
"        half_width = 0.0\n",
"\n",
"    x_plots = []\n",
"    y_plots = []\n",
"\n",
"    for x_query in x_test_values:\n",
"        # The query is passed as a feature row [1, x], like the rows of X.\n",
"        theta = calculate_theta(np.array([1.0, x_query]), k)\n",
"\n",
"        # Evaluate the local line at both ends of this query's segment.\n",
"        segment_ends = np.array([x_query - half_width, x_query + half_width])\n",
"        X_ends = np.c_[np.ones(2), segment_ends]\n",
"\n",
"        x_plots.extend(segment_ends)\n",
"        y_plots.extend(X_ends.dot(theta).ravel())\n",
"\n",
"    plt.plot(x_values, y, \"b.\")\n",
"    plt.plot(x_plots, y_plots, 'r-', linewidth=2)\n",
"    plt.title(\"k={}\".format(k))\n",
"\n",
"# One stacked panel per bandwidth, from the widest kernel to the narrowest.\n",
"ks = [100, 5, 1]\n",
"plt.figure(figsize=(12, 8))\n",
"for panel, bandwidth in enumerate(ks, start=1):\n",
"    plt.subplot(len(ks), 1, panel)\n",
"    lwlr(x_test_values, bandwidth)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": []
}
],
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"toc": {
"base_numbering": 1,
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}