"# 准备"
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import pandas as pd\n",
"import numpy as np"
"# 构造数据"
"source": [
"m = 100  # number of training samples\n",
"\n",
"# Fixed seed so the sampled data set is the same on every run.\n",
"np.random.seed(42)\n",
"\n",
"# Feature: m uniform draws scaled to [0, 100), sorted so line plots run left to right.\n",
"x_values = np.sort(100 * np.random.rand(m, 1), axis=0)\n",
"\n",
"# Target: the feature itself plus a sinusoid and standard-normal noise scaled by 2.\n",
"noise = 2 * np.random.randn(m, 1)\n",
"y = 7 * np.sin(0.12 * x_values) + x_values + noise\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"ax.plot(x_values, y, \"b.\")\n",
"plt.show()"
"# Design matrix: a bias column of ones followed by the feature column.\n",
"X = np.hstack((np.ones((m, 1)), x_values))"
"# k 值对权重的影响"
"# Gaussian kernel weight of every sample relative to the middle sample,\n",
"# drawn in one panel per bandwidth k.\n",
"ks = [100, 5, 1]\n",
"center = X[m // 2]\n",
"\n",
"fig, axes = plt.subplots(len(ks), 1, figsize=(12, 8))\n",
"for ax, k in zip(axes, ks):\n",
"    ws = [np.exp(-np.sum(np.square(row - center)) / (2 * k**2)) for row in X]\n",
"    ax.plot(x_values, ws)\n",
"    ax.set_title(\"k={}\".format(k))\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
"# 预测"
"[48.041602]\n"
"def calculate_theta(x_test, k):\n",
"    \"\"\"Fit locally weighted linear regression around one query point.\n",
"\n",
"    Every training row X[i] gets the Gaussian weight\n",
"    exp(-||X[i] - x_test||^2 / (2 * k^2)); theta then solves the weighted\n",
"    normal equations (X^T W X) theta = X^T W y.\n",
"\n",
"    Args:\n",
"        x_test: Query point. Either a full design row (bias term included,\n",
"            as long as a row of X) or the bare feature value(s) without the\n",
"            bias term.\n",
"        k: Bandwidth of the Gaussian kernel; must be non-zero.\n",
"\n",
"    Returns:\n",
"        theta, with one row per column of X.\n",
"\n",
"    Raises:\n",
"        ValueError: If k is zero.\n",
"        numpy.linalg.LinAlgError: If the weighted normal matrix is singular.\n",
"    \"\"\"\n",
"    if k == 0:\n",
"        raise ValueError(\"k must be non-zero\")\n",
"\n",
"    # A bare feature value would broadcast against the bias column as well,\n",
"    # adding the constant (1 - x_test)^2 to every squared distance; with a\n",
"    # small k that underflows all weights to zero. Prepending the bias term\n",
"    # keeps the bias difference at exactly 0.\n",
"    query = np.atleast_1d(np.asarray(x_test, dtype=float)).ravel()\n",
"    if query.size == X.shape[1] - 1:\n",
"        query = np.r_[1.0, query]\n",
"\n",
"    # Diagonal of the weight matrix W, computed for all samples at once.\n",
"    weights = np.exp(-np.sum(np.square(X - query), axis=1) / (2 * k**2))\n",
"\n",
"    # Apply locally weighted linear regression: solve the normal equations\n",
"    # directly instead of forming an explicit inverse.\n",
"    xt_w = X.T * weights  # equals X.T.dot(W) for the diagonal matrix W\n",
"    theta = np.linalg.solve(xt_w.dot(X), xt_w.dot(y))\n",
"\n",
"    return theta\n",
"\n",
"def predict(x_test, k):\n",
"    \"\"\"Predict the target at x_test from a line fitted locally around it.\n",
"\n",
"    Args:\n",
"        x_test: Feature value at which to predict.\n",
"        k: Kernel bandwidth forwarded to calculate_theta.\n",
"\n",
"    Returns:\n",
"        theta[0] + x_test * theta[1] for the locally fitted theta.\n",
"    \"\"\"\n",
"    local_theta = calculate_theta(x_test, k)\n",
"    intercept, slope = local_theta[0], local_theta[1]\n",
"    return intercept + x_test * slope\n",
"\n",
"# Example: prediction at x = 50 with bandwidth k = 5.\n",
"print(predict(50, 5))"
"# 过拟合与欠拟合"
"# Evenly spaced query points at which the local fits are evaluated.\n",
"test_count = 50\n",
"x_test_values = np.linspace(0, 100, num=test_count)\n",
"\n",
"def lwlr(x_test_values, k):\n",
"    \"\"\"Plot the training data and a piecewise LWLR fit on the current axes.\n",
"\n",
"    For every query point a local line is fitted with calculate_theta and\n",
"    evaluated at the two ends of a short segment centred on that point; the\n",
"    segment end points are joined into one red polyline drawn over the blue\n",
"    training points (x_values, y).\n",
"\n",
"    Args:\n",
"        x_test_values: 1-D collection of query feature values, any length.\n",
"        k: Kernel bandwidth passed to calculate_theta.\n",
"    \"\"\"\n",
"    x_test_values = np.asarray(x_test_values, dtype=float).ravel()\n",
"    n_tests = x_test_values.size\n",
"\n",
"    # Half the segment width: the queried span shared evenly among the query\n",
"    # points. It is derived from the argument itself rather than from the\n",
"    # globals m and test_count, so inputs of any length or range work.\n",
"    if n_tests:\n",
"        half_width = (x_test_values.max() - x_test_values.min()) / n_tests / 2\n",
"    else:\n",
"        half_width = 0.0\n",
"\n",
"    x_plots = []\n",
"    y_plots = []\n",
"\n",
"    for center in x_test_values:\n",
"        # Fit around the full design row (bias term included).\n",
"        theta = calculate_theta(np.array([1.0, center]), k)\n",
"\n",
"        # Evaluate the local line at both ends of this point's segment.\n",
"        segment_x = np.array([center - half_width, center + half_width])\n",
"        segment_y = np.c_[np.ones(2), segment_x].dot(theta)\n",
"\n",
"        x_plots.extend(segment_x)\n",
"        y_plots.extend(segment_y.ravel())\n",
"\n",
"    plt.plot(x_values, y, \"b.\")\n",
"    plt.plot(x_plots, y_plots, 'r-', linewidth=2)\n",
"    plt.title(\"k={}\".format(k))\n",
"\n",
"# Compare the fitted curve for each bandwidth k, one panel per value.\n",
"plt.figure(figsize=(12, 8))\n",
"ks = [100, 5, 1]\n",
"panel_count = len(ks)\n",
"for index, bandwidth in enumerate(ks):\n",
"    plt.subplot(panel_count, 1, index + 1)\n",
"    lwlr(x_test_values, bandwidth)\n",
"plt.tight_layout()\n",
"plt.show()"
