8月18日(周六)成都源创会火热报名中,四位一线行业大牛与你面对面,探讨区块链技术热潮下的冷思考。
Watch Star Fork

Cynhard85 / MachineLearningTutorial

加入码云
与超过 300 万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
暂无描述
一键复制 编辑 原始数据 按行查看 历史
机器学习 - 梯度下降.ipynb 33.81 KB liuxinyang 提交于 2018-07-06 16:43 . modify names
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 准备数据"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import pandas as pd\n",
"import datetime"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAHCZJREFUeJzt3X+wXGV9x/H3N5cEoqaGkGuISW4SSoYxphb0Dj8KdBAIBKRErTiBjkRgJkMLg07rCMgMUixTrFOtTBUmLYzQAQJUbTKIxYCg4BgwCSAkIeYGCFxDQ5CfDj+S3PvtH+fccFj2x9k9P/acPZ/XTCZ3z57d89y9O9999vt8n+cxd0dERHrfuG43QERE8qGALyJSEQr4IiIVoYAvIlIRCvgiIhWhgC8iUhEK+CIiFaGALyJSEQr4IiIVsU+3GxA1depUnzNnTrebISJSKuvWrXvR3ftbnVeogD9nzhzWrl3b7WaIiJSKmW2Lc55SOiIiFaGALyJSEQr4IiIVoYAvIlIRCvgiIhWhgC8iUhGFKssUEellqzfu4IEtOzl2Xj8L50/L/foK+CIiOVi9cQcX3foIb+4e4Y61w5x7zFxef2t3rsFfAV9EJCPRHv0DW3by5u4RAN7cPcJ1v9jKyKjnGvwV8EVEMlCvRz9xfB9v7h6hz2Bk1IH3Bv9rzjwss6CvQVsRkSZWb9zB5SufYPXGHW09rrZH//pbu7nmzMM4+6jZnH/cwUwc3wfwnuD/wJad6f4CEan18M2sD1gL/N7dTzOzucAKYAqwHviCu+9K63oiIlmr7aW30/s+dl4/d6wd5s3dI0wc37c3XTP2+ENnTeaBLTuZtN94bnjw6Xedl5U0UzpfAjYBfxLe/ibwHXdfYWbXAecB16Z4PRGRTNX20h/YsjN2wF84fxrXnHlYw6qcesG/FDl8M5sJfAq4Cvh7MzPgeOCs8JQbgStQwBeREqnXS29HNKincV5SafXw/w34KjApvH0A8Iq77wlvDwMzUrqWiEguWvXSyyZxwDez04AX3H2dmR03drjOqd7g8cuAZQADAwNJmyMikqq8et95SKNK52jgdDN7hmCQ9niCHv9kMxv7QJkJbK/3YHdf7u6D7j7Y35/dYIWISNUlDvjufqm7z3T3OcAS4Ofu/jfAfcDnwtOWAiuTXktERDqXZR3+xQQDuEMEOf3rM7yWiIi0kOpMW3e/H7g//Pkp4PA0n19EJCvdXtgsD1paQUR6XqtgnmSCVZznLwotrSAipdLuUgdjwfymX2/jolsfqfu4ehOs2mlPq+cvCgV8ESmNToJrnGB+7Lz+vWvbTBzfx6T9xrf8UBn74LnloW0df1jkTSkdESmNTpY6iDNbNjrBKrq2TaP0TjQFNKFvHBP6xrFrZDTztXCSUsAXkdLoZKmDuLNlxyZYXb7yiZYfKtEPnl0jo3zykH5mTXlf4XP4CvgiUhqdLnUQnS3baoA1zodK7TlnHTG70IF+jLnXXfGgKwYHB33t2rXdboaI9KhoKmbi+L6G1Thxqm6KVJljZuvcfbDVeerhi0hlxB0DiLN+ThnX2FGVjohURm01TrsDrJ3uflUU6uGLSGUkWe446eSsIlDAF5FK6TQVk2T3q6JQSkdEJIak6aAiUA9fRCSGuOmgIlXv1FLAFxGJqVU6qOh5fqV0RKSy0q66SbIIWx4U8EWkkrJY5bLoef7EAd/M9jOzh83sMTPbYGb/GB6fa2YPmdkWM7vNzCYkb66ISDqy6I2P5fnPPmp24dI5kE4P/23geHf/c+BQYJGZHQl8E/iOu88DXgbOS+FaIlJBWUx4yqo3vnD+NK5cvKBwwR5SGLT1YDGeP4Y3x4f/HDgeOCs8fiNwBXBt0uuJSPFkWZmS1UBokklYZZVKlY6Z9QHrgIOB7wFbgVfcfU94yjAwI41riUixZF2ZkuWEpzKuh5NEKoO27j7i7ocCMwk2Lv9IvdPqPdbMlpnZWjNbu3NnsUa0RaS1rCtTOk29lH3dmyykWqXj7q8A9wNHApPNbOwbxExge4PHLHf3QXcf7O8v1oi2iLSWdWVKJwOhZdpnNk+JUzpm1g/sdvdXzGwicCLBgO19wOeAFcBSYGXSa4lI8eSRC2839dIL695kIY0c/nTgxjCPPw643d3vNLONwAoz+yfgEeD6FK4lIl3UaHC2UUDu1jIDnWyFWAXa8UpEYom7W1Sn5ydpV70PlSKvaZO2uDteaaatiMTS7uBsHssMNMvVF7kevlsU8EUklnYHZ+Oe36iaJk6VTe2Hyi0PbVNlThNK6YhIbO2mSVqd3yjt085m42PnTegL+q+7RkYzTSEVkTYxF5HUtVst0+r8RtU07Ww2PlYh9NxLb3Df5p0tH1NlSumISNc0Svu0kz4ay9WfdcTsQq9UWQRK6YhIV6VZZVOlypyouCkdBXwRyUVVg3EeVJYpIoWhpQ6KQYO2IpK5euWT6u3nTz18EclcdBB2Qt84fjX0B/X2u0ABX0QyF13x8uiDD2DXyChQzI2+e5lSOiKSi7Ga/NUbd7DmqZe0sFkXKOCL9LiiLS5Wxa0Fi0JlmSI9LOnSBVIOKssUkYYrVrazkqW2CuwdCvgiPSzp0gXN6uf1QVA+aWxxOAu4CTgQGAWWu/t3zWwKcBswB3gG+Ly7v5z0eiISX6N8edw8eqNFzKIpoTvWDislVBJpDNruAf7B3deb2SRgnZmtBr4I3OvuV5vZJcAlwMUpXE9E2tBoxco4K1822ipQe8aWU+KUjrs/7+7rw59fBzYBM4DFwI3haTcCn056LRHJRqP0TLR+PtqLb3czFCmGVKt0zGwO8EtgAfCsu0+O3Peyu+/f7PGq0hHJX6cVO1oMrThy3wDFzD4A/BD4sru/ZmZxH7cMWAYwMDCQVnNEJKZO0zPtboYi3ZdKlY6ZjScI9je7+4/CwzvMbHp4/3TghXqPdffl7j7o7oP9/fpaKJKGdipolJ6pjjSqdAy4Htjk7t+O3LUKWApcHf6/Mum1RKS1ditoNPO1OtJI6RwNfAF43MweDY99jSDQ325m5wHPAmekcC0RaaGTFI3SM9WQOOC7+4NAo4T9CUmfX0Ta06iUUkSLp4n0mDKnaFT5ky0FfJEeVMYUjWbvZk9r6YhIIbSzoJt0RgFfpEKKvOCZykOzp5SOSEGlnc9uJ2XSjVx6mcceykIBX6SAsshnxy3X7GYuvYxjD2WilI5IAWWRz46bMlEuvXcp4IsUUBb57EYrX+ZxbSkG7WkrUlDdrElXPXy5xF0tUwFfRKTktIm5iIi8iwK+iEhFqCxTJCfKi0u3KeCLZCQa4AGtEyNdp4AvkoHayUtHHjQl9qSnVt8C9E1BOqUcvkgGaicvAS1r28c+JG769TYuuvWRuuvdxDlHpJG09rS9wcxeMLMnIsemmNlqM9sS/r9/GtcSKYPayUtnHTG75aSnODNcNQtWkkirh/8DYFHNsUuAe919HnBveFukEurNal04fxpXLl6QaIarZsFKEqlNvDKzOcCd7r4gvL0ZOM7dnzez6cD97n5Is+fQxCupul7I4Re9fb0o95m2dQL+K+4+OXL/y+7+nrSOmS0DlgEMDAx8Ytu2bam0R0TyFx2snji+T9VIOSnNTFt3X+7ug+4+2N+vr6dSblltMFLkjUuiNMZQbFkG/B1hKofw/xcyvJZI12VVQVOmyhyNMRRblgF/FbA0/HkpsDLDa4l0XVa92zL1muMuwSzdkcrEKzO7FTgOmGpmw8DXgauB283sPOBZ4Iw0riVSVMfO6+eOtcN789dp9W6bPW/tbN4iDJZq16ri0vLIIinqpEKl08qc6ADphL7gy/qukVENllZQ3EFbLa0gkqJ2e7dx94+t97zRVM+ukdG9x5st3SDV1vUqHZEqS5Kfjw6QTugbt7eXr8FSaUQ9fBG6N1koSd5/bIC0aDl8KS7l8KXyuj1ZSDNTJSnl8EViqpdWyTPwqqpF8qIcvlRe3MlCZZntKtKIevhSebW58EYbk2jHKik7BXwRWqdV0k77KG8v3aCUjkgMaa4RU6a1caS3qIcvUqNe7ztO2ieubg8SS3Up4ItENMvVp1VNk9WaOyKtKOCLROTR+07z24JIOxTwRSLi9r6TDrqq9l66QQFfJKKTEs1zj5nL62/tVm9dCk8BX6RGuyWa1/1iKyOjrvp8KTyVZYq0KVqi2WcwMhqsR1X03ahEMg/4ZrbIzDab2ZCZXZL19USyFt3G7/zjDtYerlIamaZ0zKwP+B6wEBgGfmNmq9x9Y5bXFclaNO1z6KzJqriRUsg6h384MOTuTwGY2QpgMaCALz1DFTdSFlmndGYAz0VuD4fH9jKzZWa21szW7typ/KeISFayDvhW59i7dlxx9+XuPujug/39yn9Kc1qiWKRzWQf8YWBW5PZMYHvG15QepUXHRJLJOuD/BphnZnPNbAKwBFiV8TWlRyXZ8FtEMg747r4HuBC4G9gE3O7uG7K8pvSuNJcoBqWHpHq0ibkkludmHmldq9sbl4ukSZuYSy7y3vovWgKZJPhrTXqpIi2tIIl0K69ebwC3nRRN2ukhkTJQD18S6dZmHrUfNLc8tI01T70UewVLrUkvVaQcviSWdQ6/3vPX5uCPPGgK921+59tF3zhjZNTfk5/X5uHSi+Lm8BXwpdCaDa5Ggzew97w+g5HI2/rso2Zz5eIFGqiVnqVBW+kJzQZXa9ewGUvRTNpvPDc8+PR70kwaqJWqU8CXXHSaSmlnjKDVCpbaPFyqTikdyVzSVEqaeXfl8KUXKaUjhZE0lZLm8sPtPpc+IKSXqA5fMtes5r3IyxtosTbpNerhS+Ya1bznPUu3XRrklV6jHr7kYuH8aVy5eMG7AmbRV7/UbFzpNerhS9cUvWpGs3Gl16hKR7pKg6IiyalKR0ohjw3A9aEiEkiUwzezM8xsg5mNmtlgzX2XmtmQmW02s5OTNVOkM6q0EXlH0kHbJ4DPAr+MHjSz+QTbGX4UWAR838z6El5LpG1FHxgWyVOigO/um9x9c527FgMr3P1td38aGAIOT3ItkU6o0kbkHVnl8GcAayK3h8NjIrlSpY3IO1oGfDO7Bziwzl2XufvKRg+rc6xuOZCZLQOWAQwMDLRqjkjb8hgYFimDlgHf3U/s4HmHgVmR2zOB7Q2efzmwHIKyzA6uJQWiihiR4spqpu0qYImZ7Wtmc4F5wMMZXUsKQhUxIsWWtCzzM2Y2DBwF/MTM7gZw9w3A7cBG4H+BC9x9JGljpdhUESNSbEmrdH7s7jPdfV93n+buJ0fuu8rd/9TdD3H3nyZvqhSdKmJEik0zbSU1aVbEaCxAJH0K+PIuSQNtGhUxRV82WaSstDxyiaW9eUhRBl01FiCSDQX8ksoiOBcl0GosQCQbCvgllUVwLkqgHRsLOPuo2UrniKRIOfySymLzkLyWIYgzTqDZsSLp0wYoJZZ3JUsa14sOyE4c36cevEgKtAFKBeTZC66tnDn3mLm8/tbupsG/3geENgYX6R7l8CWW2kB93S+2Nh0wbjSoXJRxApEqUsCvqHZLOqOBus9gZDRIBTYaMG40qKwBWZHuUUqngjqZ2BQd0J2033huePDppgPGzQaVG6WiNLtWJFsK+BXUTh69NgiPnXforMlNg3O7FT+aXSuSPQX8Copb0tksCMcZMG5nUFmDuSLZUw6/guLm0fOceavBXJHsqYdfUXF631lM7mrWHu09K5ItTbzqQWkOfmogVaT4cpl4ZWbfAv4K2AVsBc5x91fC+y4FzgNGgIvc/e4k15J40h781BIHIr0jaQ5/NbDA3T8G/A64FMDM5gNLgI8Ci4Dvm1lfwmtJDEVZ8VJEiifpFoc/c/c94c01wMzw58XACnd/292fBoaAw5NcS+LR4KeINJLmoO25wG3hzzMIPgDGDIfHJGPNBj+VjxeptpYB38zuAQ6sc9dl7r4yPOcyYA9w89jD6pxfd3TYzJYBywAGBgZiNFlaqZd318QmEWkZ8N39xGb3m9lS4DTgBH+n5GcYmBU5bSawvcHzLweWQ1ClE6PNUker3rsmNolIohy+mS0CLgZOd/c3InetApaY2b5mNheYBzyc5FpV087iZnG2O2yW2097b1wRKaakOfx/B/YFVpsZwBp3P9/dN5jZ7cBGglTPBe4+kvBaldFu+iVO771Rbl+pHpHqSBTw3f3gJvddBVyV5Pmrqt30S9wZsfVy+0r1iFSH1tIpoHZLK5OsMa8yTpHq0NIKBZVnCaXKNUXKLe7SCgr4IiIlFzfgK6UjIlIRWh65htIbItKr1MOPiFPPLiJSVgr4EVppUkR6mQJ+RC+WKGoWrYiMUQ4/Is1t9oowFqBZtCISpYBfI40dnooSaDWLVkSiKpXSSSu90ep50h4L6LTdvZiiEpHOVaaHn1avO87zxF3bJut2p5miEpHyq0zA7yS9US8Pn2RlyrjXSNru2rYo0IsIVCil0256o1FNftznWTh/GlcuXtAy2CdZx15EpB2V6eG3m95o1LNOM02S9rcFEZFmKhPwob30RrM8fFppkk7WsS9CuaeIlFOi1TLN7BvAYmAUeAH4ortvt2D7q+8CpwJvhMfXt3q+oq2WmUdwjV4DaHq96ADuxPF9qqsXESC/1TK/5e4fc/dDgTuBy8PjpxDsYzsPWAZcm/A6XREnD5/WNYCW+Xwt/SAiSSQK+O7+WuTm+4GxrwuLgZs8sAaYbGbTk1yr18UJ5hrAFZEkEufwzewq4GzgVeCT4eEZwHOR04bDY88nvV6RJUkBxcnnawBXRJJomcM3s3uAA+vcdZm7r4ycdymwn7t/3cx+Avyzuz8Y3ncv8FV3X1fn+ZcRpH0YGBj4xLZt2zr+Zbopjfy6BmRFpBNxc/gte/jufmLMa94C/AT4OkGPflbkvpnA9gbPvxxYDsGgbcxrFcZYkH7upTcSr1ujSVIikqVEKR0zm+fuW8KbpwNPhj+vAi40sxXAEcCr7t5z6Zxor35C3zgm9I1j18io8usiUkhJc/hXm9khBGWZ24Dzw+N3EZRkDhGUZZ6T8DodyzJNEh1o3TUyyicP6WfWlPcpJSMihZQo4Lv7Xzc47sAFSZ47DVkvU1w70HrWEbMV6EWksHp6pm2a68HX+6agqhkRKZOeDvhpLVPc7JuCBlpFpCx6OuCn1QPXzlEi0gt6OuBDvB54q4HdNDc0ERHplp4I+EkqceIM7Kadq9cEKxHphtIH/E4rcdqdMJVWrr4oG5yLSPWUfserTlaQjO409auhPzChL3gZ8kjXaMVLEemW0gf8TlaQrJ0wdfTBB3D2UbNz6W1rxUsR6ZbSp3Q6ya/nPWGqNmev2n0R6YZEO16lLc8dr/IaONUuVSKStdRWyyyjOME8rwlTquEXkaIofQ6/VnRAttFWgXlSzl5EiqLnevhF61ErZy8iRdFzAb+Is2K13o6IFEHPBXz1qEVE6uu5gA/qUYuI1JPKoK2ZfcXM3MymhrfNzK4xsyEz+62ZfTyN64iISOcSB3wzmwUsBJ6NHD4FmBf+WwZcm/Q6IiKSTBo9/O8AXwWiM7gWAzd5YA0w2cymp3AtERHpUKKAb2anA79398dq7poBPBe5PRweExGRLmk5aGtm9wAH1rnrMuBrwEn1HlbnWN01HMxsGUHah4GBgVbNERGRDrUM+O5+Yr3jZvZnwFzgMTMDmAmsN7PDCXr0syKnzwS2N3j+5cByCNbSaafxIiISX2qLp5nZM8Cgu79oZp8CLgROBY4ArnH3w2M8x05gW4dNmAq82OFjs1bUtqld7VG72qN2tSdJu2a7e8tZplnV4d9FEOyHgDeAc+I8KE6DGzGztXFWi+uGorZN7WqP2tUetas9ebQrtYDv7nMiPztwQVrPLSIiyfXcapkiIlJfLwX85d1uQBNFbZva1R61qz1qV3syb1ehdrwSEZHs9FIPX0REmihVwDezM8xsg5mNmtlgzX2Xhou1bTazkxs8fq6ZPWRmW8zsNjObkEEbbzOzR8N/z5jZow3Oe8bMHg/Py2UjXzO7wsx+H2nfqQ3OWxS+jkNmdknGbfqWmT0ZLrL3YzOb3OC8XF6vVr+7me0b/o2HwvfSnKzaErnmLDO7z8w2he//L9U55zgzezXyt70863ZFrt30b5P3YopmdkjkdXjUzF4zsy/XnJPb62VmN5jZC2b2ROTYFDNbHcai1Wa2f4PHLg3P2WJmSxM3xt1L8w/4CHAIcD9Bzf/Y8fnAY8C+BJPBtgJ9dR5/O7Ak/Pk64G8zbu+/Apc3uO8ZYGrOr98VwFdanNMXvn4HARPC13V+hm06Cdgn/PmbwDe79XrF+d2BvwOuC39eAtyWw99tOvDx8OdJwO/qtOs44M48309x/zYEJdo/JZiBfyTwUI5t6wP+j6BOvSuvF/CXwMeBJyLH/gW4JPz5knrve2AK8FT4//7hz/snaUupevjuvsndN9e5azGwwt3fdvenCer/3zXRy4LpwMcD/x0euhH4dFZtDa/3eeDWrK6RkcOBIXd/yt13ASsIXt9MuPvP3H1PeHMNwazsbonzuy8meO9A8F46IfxbZ8bdn3f39eHPrwObKNfaVN1cTPEEYKu7dzqhMzF3/yXwUs3h6PuoUSw6GVjt7i+5+8vAamBRkraUKuA3EWextgOAVyLBJesF3Y4Fdrj7lgb3O/AzM1sXrieUlwvDr9U3NPga2c2F784l6AnWk8frFed333tO+F56leC9lYswhXQY8FCdu48ys8fM7Kdm9tG82kTrv00331NLaNzp6tbrBTDN3Z+H4AMd+FCdc1J/3Qq345U1WazN3Vc2elidY7XlR7EXdGslZhvPpHnv/mh3325mHwJWm9mTYU8gkWZtI9iX4BsEv/c3CFJO59Y+RZ3HJirlivN6mdllwB7g5gZPk8nrVdvUOscyex+1y8w+APwQ+LK7v1Zz93qCtMUfw7GZ/yHYjyIPrf42XXnNwjG604FL69zdzdcrrtRft8IFfG+wWFsLcRZre5Hgq+Q+Yc+s4YJuSdtoZvsAnwU+0eQ5tof/v2BmPyZIJyQOYHFfPzP7D+DOOnfFXvgurTaFg1GnASd4mLys8xyZvF414vzuY+cMh3/nD/Ler+upM7PxBMH+Znf/Ue390Q8Ad7/LzL5vZlPdPfM1Y2L8bVJ/T8V0CrDe3XfU3tHN1yu0w8ymu/vzYXrrhTrnDBOMNYyZSTB+2bFeSemsApaEFRRzCT6pH46eEAaS+4DPhYeWAo2+MSR1IvCkuw/Xu9PM3m9mk8Z+Jhi4fKLeuWmqyZt+psE1fwPMs6CiaQLBV+JVGbZpEXAxcLq7v9HgnLxerzi/+yqC9w4E76WfN/qQSks4RnA9sMndv93gnAPHxhIsWLF2HPCHLNsVXivO32YVcHZYrXMk8OpYOiNjDb9ld+v1ioi+jxrForuBk8xs/zD9elJ4rHN5jFKn9Y8gSA0DbwM7gLsj911GUGGxGTglcvwu4MPhzwcRfBAMAXcA+2bUzh8A59cc+zBwV6Qdj4X/NhCkNvJ4/f4LeBz4bfiGm17btvD2qQSVIFuzblv4t3gOeDT8d11tm/J8ver97sCVBB9IAPuF752h8L10UA5/t2MIvsr/NvI6nQqcP/Y+I1iddkP4Gq0B/iKn91Tdv01N2wz4XviaPk6kwi7Ddr2PIIB/MHKsK68XwYfO88DuMH6dRzDucy+wJfx/SnjuIPCfkceeG77XhoBzkrZFM21FRCqiV1I6IiLSggK+iEhFKOCLiFSEAr6ISEUo4IuIVIQCvohIRSjgi4hUhAK+iEhF/D/+LuspspnoWAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x22a54e79f28>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(42)\n",
"\n",
"m = 100\n",
"w = np.array([3, 4]).reshape([-1, 1])\n",
"x = np.linspace(-10, 10, m, [-1, 1])\n",
"X = np.c_[np.ones([m, 1]), x]\n",
"y = X.dot(w) + np.random.normal(0, 5, m).reshape([-1, 1])\n",
"\n",
"plt.scatter(x, y, s=10)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(data=list(zip(x, y.flatten())), columns=['x', 'y'])\n",
"df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 批量梯度下降"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.56385135]\n",
" [4.03448317]]\n"
]
}
],
"source": [
"def gradient_descent(epoches, eta, save):\n",
" np.random.seed(42)\n",
" w = 10 * np.random.randn(2, 1) + 10\n",
" ws = []\n",
" mses = []\n",
"\n",
" # train\n",
" for epoch in range(epoches):\n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append(mse)\n",
" ws.append(w.flatten())\n",
" gradients = X.T.dot(X.dot(w) - y)\n",
" w -= eta * gradients\n",
" \n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append(mse) # last mse\n",
" ws.append(w.flatten()) # we need the last w\n",
" print(w)\n",
"\n",
" # save data\n",
" if save:\n",
" df = pd.DataFrame(data=ws, columns=['w0', 'w1'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
" df = pd.DataFrame(data=mses, columns=['mse'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
"\n",
"gradient_descent(epoches=1000, eta=0.00005, save=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.48076741]\n",
" [4.03448317]]\n"
]
}
],
"source": [
"w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)\n",
"print(w)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl8VNX5x/HPQwBBXFhERNkVLREVS0SporjErVVq3VB/lSqKG6hQN8S6oLbgSkEBUalarbsWqrgEV8Q1gAgEERQRBFlUEGULyfn9cSY4hJnMTGbuzCTzfb9evMyduXPvySQ+c/Kcc55jzjlERKT2q5PpBoiISHoo4IuI5AgFfBGRHKGALyKSIxTwRURyhAK+iEiOUMAXEckRCvgiIjlCAV9EJEfUzXQDwu2yyy6uXbt2mW6GiEiNMm3atFXOueaxzsuqgN+uXTuKi4sz3QwRkRrFzBbFc55SOiIiOUIBX0QkRyjgi4jkCAV8EZEcoYAvIpIjFPBFRHKEAr6ISKZt3gzl5YHfRgFfRCRNikqWc+OE2RSVLP/1wU8/hYMPhjFjAr9/Vi28EhGpTYpKljNl/kp6dPSLYC9/cgbrS8t4tngJF3bbna6P3UePF8dTp6wMNm6Eiy+GvLzA2qOALyISgKKS5VsF+EM6NGV9aRkAnRbO5uQxF7LX90sox5h+8jlMOmsAB89bRWF+i8DapIAvIlKF8F56IsF4yvyVWwJ8xX+buU30f+Nf9Jn2EnVwfNm0FdeeMIAZbTpT9ukqnpjzIyPPOjCwoK+ALyISReVeeiLBuEfH5jxbvIT1pWU0rJdH/80LyX/8KhouXUJ5nTweOOQ07ul+Jpvr1aes3AH+g2HK/JWBBfyUDdqaWZ6ZzTCzl0LH7c3sIzObb2ZPm1n9VN1LRCQdKvfSp8xfGfdrC/NbMPKsA+m3X2Mmz36Erv1603DpEujShTrFn9DhwX9yZo+OXNxzLxrW83n7hvXytuT7g5DKHv4VwFxgp9DxcOBe59xTZjYW6AsEPwwtIpIilXvpiQbjws+nUnjVZfDdd7DddnDTTXDVVVCvHoWwpSffpXXjaqWNEmXOueQvYtYKeBS4HRgEnASsBHZzzm02s+7Azc6546q6TkFBgVN5ZBHJJtXK4X/3HfTvD88/748PPRQeegh+85tA2mhm05xzBbHOS1UPfwRwDbBj6LgZsNo5tzl0vATYI9ILzawf0A+gTZs2KWqOiEhqFOa3iD/QOwePPQYDB8KPP0KjRjBsGFx6KdTJ/LKnpFtgZn8AVjjnpoU/HOHUiH9KOOfGOecKnHMFzZsHl7sSEQnUokVwwgnwl7/4YH/ccTBnju/pZ0Gwh9T08A8FTjazE4EG+Bz+CKCxmdUN9fJbAUtTcC8RkexSXg6jR8N118Evv0CTJnDvvXDuuWCR+r6Zk/THjnNusHOulXOuHdAbeNM5dw7wFnBa6LQ+wIRk7yUiklXmzYPDD4cBA3ywP+00KCmBPn2yLthDsLV0rgUGmdkCfE7/4QDvJSKSlIh1bqIpLYV//AMOOACmToUWLfwA7bPPwm67Bd/Yakrpwivn3NvA26GvvwK6pfL6IiLVEWumTUILrGbMgL59/X8BzjsP7r7bp3KyXHaMJIiIxCmhnji/BvPHPljE5U/OiPi6uBZYbdgA118PBx3kg33btvDaazB+PEXLNiXUpkxRwBeRGiOe4F1ZPMG8R8fmW6123bFBva0D+NSp0KWLT+OUl8Pll/Pms29y4/rdufO1eQm3KVNUS0dEaoxIwTvWHPl4VstWlEGYMn8lOzaox/j3FrK+tIyXp37BPt++TJsn/+Xn2P/mN/DQQxQ12WtLCijPoCw06TzoWjjJUg9fRGqMyj3xeEodVATzc7u3rTI3X5jfgqG9OrN2QynrS8s4/KtpTHjgYtr8Z7yfRz9kiE/lHHroVh88ZQ7y6lhCbcoU9fBFpMYI74knUuogkdWyRzavy4GvjOCUzyYD8FOn/djpP4/5lE5I5b8azj+sPWs3lAZeCydZKamlkyqqpSMiQatyxs7zz8Nll8Hy5ZTWq8/X/a+m4x03Q91t+8bVrZMfhHhr6Sjgi0jOCJ9+2bBe3q8pnmXLfAmEF17wJx52mC92ts8+mW1wnOIN+Mrhi0jO2GbQ94sV8MgjkJ/vg/0OO8D998M770QM9olOCc02CvgikjPCB333/GUVV9w5wC+cWr0ajj8eZs+OWtmyOlNCs40CvojkjML8Fow88wDu//49Xnv4Upq9/w40bepLGk+a5BdTRZHM7lfZQrN0RCR3zJ1L4UUXwPvv++PTT4dRo3wtnBiS3f0qGyjgi0jtV1oKd94Jt9wCmzb5AmejR8Mpp8R9iXinhGbT7J3KFPBFpHabPt0XO/v0U398/vlw113VKnYWaz5/QkXYMkA5fBGpndavh8GDoVs3H+zbtYOiInj44S3BPtWzbrI9z6+ALyK1z3vv+ZWxw4b5YmdXXAGzZsExx2w5JYhZN9Up/ZBOSad0zKwB8C6wXeh6zznnbjKz9sBTQFNgOvBn59ymZO8nIhLV2rW+V3///f64Uyffo+/efZtTq1OILZbqln5Il1T08DcCRznnDgC6AMeb2SHAcOBe51xH4EegbwruJSI5KK7Uy6uvQufOPtjXrQs33OCLnUUI9hBcb7yiCFu2BXtIQQ/f+doMP4cO64X+OeAo4OzQ448CNwNjkr2fiGSfIGemxBwI/f57GDTIz6UH6NrV9+oPOKDK62Z7bzwIKZmlY2Z5wDRgL+B+4EtgtXNuc+iUJcAeqbiXiGSXoGemRE29OPdrsbMVK6BBAz/tctCgiMXOIkmkimZtkJJBW+dcmXOuC9AKv49tp0inRXqtmfUzs2IzK165MrtGtEUktqBnpkRMvSxbBqee6hdOrVgBPXrAzJlwzTVbgn1Nr3sThJTO0nHOrcZvYn4I0NjMKj5mWwFLo7xmnHOuwDlX0Lx5do1oi0hsQc9M2WoDk95dKPxoki929uKLvtjZ6NHw9tuw995bXlMb6t4EIRWzdJoDpc651WbWEDgGP2D7FnAafqZOH2BCsvcSkcyKlKtPRy68ML8FhQ3XQb8/w2S/MQknnABjx0KbNtucH8QMnNogFTn8lsCjoTx+HeAZ59xLZlYCPGVmtwEzgIdTcC8RyZCqcvXRcuEpGcwtK/MzbwYPhnXrfLGzf/4TzjkHzCK+pDbUvQlCKmbpfAYcGOHxr/D5fBGpBRLtNadkMHfuXF8W4YMP/PEZZ/hiZ7vuutV9MvFXR02kWjoiEpdEe81JpVVKS+GOO2DoUF/srGVLn6v/4x+3Oq3yh0rlvWUV6LemgC8icUm01xzvB8Q2PfTp032Bs5kzAfj21LPZ46H7oXHjbV5b+UNl7DtfUlbusrJwWTZQwBeRuCXSa47nAyK8hz7xgy95cdVk2j8yBsrKWNx4N649rj8zOnZl5NKNFG4b77f6UMkzKCv3s781UBuZAr6IBCbWB0RFD/2gxbMZ9uoo2v/wLZjx/kn/R9+9/sj6+g2giuAd/qGyY4N6jH9voQZqq6CALyIZ03P3hvxm8ljOnvYSAD/vuTc7PP4ov+zUHp6cAXEE7/APlS6tG2ugtgrmS+Fkh4KCAldcXJzpZohIALbJ1b/yClx0ESxeTFleXRZdOIAOI/4B220X+XyJysymOecKYp2nHr6IBC48V1/0bgn7LniW3Sc+55/s2pW88ePpsP/+W71Gs2xSTwFfRAI3Zf5K1m/azInzpnJL0Viar1vti50NHQoDB8Zd7EySo3dZRAJ3dONyDp/wd46Z5xdQLd6vgOcvuZl9jyigUME+bfROi0hwnIN//YsjBg2CNWvY0LAR715wNVfudBDrFjkaPjlD8+XTSHvaikgwFi6EY4/1pRHWrIETT6TBvLm8d/SprNu89Xx5SQ8FfBFJrbIyX9ysc2df2bJZM3j8cXjpJWjdOus3+q7NlNIRqeWiTW8MZNpjSQlccMGvxc569/bBP6zYmQqbZY7m4YvUYuHTIRvWy9uSL4/2eLWVlsLw4XDrrb7Y2e67w5gxcPLJqftmJKp45+ErpSNSi0XbfjCRbQljbhU4bRoUFMDf/uaD/YUXwpw5CvZZSAFfpBaLli+PN49e1VaBb0z/mil/Oh/XrRt89hl06ABvvAHjxkWsbCmZl4otDlsDjwG7AeXAOOfcP82sKfA00A74GjjDOfdjsvcTkfhFy5fHm0ePVtP+k8f+y54D+9Puh28pszos6XMRbe+/Gxo1Stv3JolLxaDtZuCvzrnpZrYjMM3MioC/AG8454aZ2XXAdcC1KbifiCQgWomCeEoXVK5p37NlA7j0Ug4aMwaAL5q14ZoTr2D/U45lqIJ91kvFFofLgGWhr9ea2VxgD6AX0DN02qPA2yjgi2SlaDN2wv8SOGXZTA48rR8sXkx53bqM7n4mI7udRl7DBlymqZU1Qkpn6ZhZO+BdoDPwjXOucdhzPzrnmkR4TT+gH0CbNm26Llq0KGXtEZHYYs7YWbXK17t5/HF/XFAA48dTlLerplZmibRXyzSzHYDngSudcz9ZlN3kK3POjQPGgZ+Wmar2iOSyRObYR9171jl49lno3x9WrvTFzm69Fa68EurWpRAU6GuYlMzSMbN6+GD/hHPuhdDDy82sZej5lsCKVNxLRKpW1cyaSCLO2Fm6FE45Bc480wf7I46AWbPgqqtU2bIGSzrgm+/KPwzMdc7dE/bURKBP6Os+wIRk7yUisSUyxx5+zdOf270tI3t3ofCDlyA/HyZMgB13hLFj4c03Ya+90tF8CVAqPqoPBf4MzDKzT0OPXQ8MA54xs77AN8DpKbiXiMRQeWZNPLVqCvNbUNjgF7jwHB/cAX7/ex/sW7UKuMWSLqmYpfMeEC1hf3Sy1xeRxCRcq6asDEaNgiFDYN06X+xs5Eg46yyIcyxOagYl40Rqobi3B5wzx5cv/ugjf9y7tw/2zTMzzVL72AZLpRVEctGmTX7GzYEH+mC/++4+Z//kkxkN9okMNkviFPBFckhRyXLG3vUUa/frAjfe6KtcXnihL2uc4WJniQ42S+KU0hHJUqlOb7wxbSHfDLiGCz98gTxXzrrWbdn+kfFw1FGB3zse1RlslsSoHr5IFkp5vfp33uH7s86l2bJvKLM6jC84me+uup6/nXFQ8PdOgHL41ZP2lbYikjpRV78m6qef4NprYexYmgHzm7fl6uMvZ17bfEZ2bhPsvash7sFmqRbl8EWyUEr2fX35Zdh3Xz+Xvl49uOkmFr0+hf1PPbbKXrv2nK29lNIRyVLVTm+sWuXr3TzxhD8+6CB4+GHYb7/g7y0ZEW9KRwFfpLZwDp5+GgYM8EG/YcNfi53l5WW6dRIg5fBFcsm338Kll8LEif64Z0948EHVv5GtKIcvUpM55wN7fr4P9jvtBA884PeWVbCXStTDF6mpvvzSL5p66y1//Ic/wJgxKnYmUamHLxKQopLl3Dhh9pYSAZWPq62sDO65xw/CvvUW7LIL/Oc/voevYC9VUA9fJADhi5eeLV7C+Ye1Z/x7C7ccR5sWGXN2zOzZrDn7XHaeNcMfn302jBiRsfo3UrOohy8SgMqLlyaXfBezTkyVxcM2bYJbbqH8wN+y86wZLNuhGZeccRNFQ+5RsJe4pWqLw/FmtsLMZoc91tTMisxsfui/22xgLlJbVV68dEz+bjEXM0UtHvbJJ9C1K9x8M3U2l/JEl+M59oLRvNL+IBUYk4SkKqXzCHAf8FjYY9cBbzjnhpnZdaHja1N0P5GsFmkTki6tG1eZrqlcPOyIVo38HrL33gvl5bDnnhQPGc5tXzVSgTGplpQtvDKzdsBLzrnOoeN5QE/n3LLQJuZvO+f2qeoaWnglua4ih3/yj19QcNs1fiZOnTowaBDccgtsv33Wr4LN9vbVRtmw8KqFc24ZQCjo7xrgvURqhcI9GlD4z1Ewbpx/oHNnXxahW7dfz8niAmOVB6vTWWlTYsv4oK2Z9TOzYjMrXrlS+Uip2ZKaevnSS77Y2bhxvtjZLbfAtGnQrVvqpnQGTJuYZLcgA/7yUCqH0H9XRDrJOTfOOVfgnCtortkGUoNVe4u+lSv99MqTTvIlErp1g+nT/Y5U9evXqK3/VGkzuwUZ8CcCfUJf9wEmBHgvkYxLuHfrnN9DNj/f/7dhQ7+g6v33fSqnutfNoIrB6nO7t1U6Jwulalrmk8AHwD5mtsTM+gLDgEIzmw8Uho5Faq2EerdLlvg9ZM8+21e2PPJImDULBg7cprJlvNfNlrRPYX4LhvbqrGCfhVQeWSSFYs5QKS+Hhx6Cq6/2u1HttBMlf72Jp/YvpMfeu0YNktGuW/H4jg3qbVnJm+5tCSXzVA9fJNssWOCLnb39tj8+6STeHTiUi978rlqBOnxGTJ5BWdj/yud2b8vQXp2jv1hqlXgDfsZn6YjUemVlcPfdsP/+Ptg3bw5PPQUTJjD5p7rVzs+H5/bLHOTVMUCDpRKdiqeJEOBiodmz4fzzfXkEgHPO8cXOdtkF2HZ1bSKBuvJrzz+sPWs3lGrBk0SllI7kvPDUSMry35s2wd//7v+VlvqyxWPHwu9/H/H+1f2w0apWgexYaStSI0Sa9phU8Pz4Y9+rnzPHH198MQwf7nejiiCZlbPZvOpWso9y+JLzUjbtcd06+OtfoXt3H+z32svn7MeMiRrsRdJJPXzJeZEqW1YWs0bMW2/BBRfAV1/5YmdXXw033wzbb5++b0QkBgV8EWKnRqKmfdas8cH9wQf9ifvtB+PHQ0HV6VTl3iUTlNIRiUPEtM///ufLIjz4oC92NnQoFBfHFexrSm0cqV3UwxepJFLvOzztc1Qzo+etV/q59AAHH+xLGO+7b1zXT/kgsUic1MMXCVNV77uw064M/eUzev7xCB/st9/e70Y1dWrcwR5UUVIyRz18kTBRe9+LF8Mll8DLL/sTjz7a163v0CHhe8QzSCwSBAV8kTDbrHzdsxk88IAfmF27FnbeGe6+m6JDfs+UWavosWF5tQK25s9LJijgi4QJ730fV38th/Y/C955xz/ZqxeMHk3R6rytpmiqpIHUFAr4IpUU7t2Mwkn/hr/9DTZs8MXO7rsPTj8dzJjyyeyt0j5j3/mSsnKnPVwl6wU+aGtmx5vZPDNbYGbXBX0/kaR89plfKXv11T7Y/9//wdy5cMYZYL4aZfiga55BWbmvR5Xtu1GJBNrDN7M84H78jldLgE/MbKJzriTI+4okbOPGX4udbd7si5098ACceOI2p4anfSpvPKIZN5LNgk7pdAMWOOe+AjCzp4BegAK+ZI8PP4S+faEk9Gt5ySUwbFiV9W/CB127tG6sGTdSIwQd8PcAFocdLwEODvieIvH55Refpx8xwm8o3rGj337w8MMTuoxm3EhNEXQO3yI8tlUBfjPrZ2bFZla8cqXyn5Imb7zh697ce6/PzV9zDcycmXCwF6lJgg74S4DWYcetgKXhJzjnxjnnCpxzBc2bK/8pVYtZojiW1av9vrLHHAMLF/ptBz/6yNerb9gwtY0VyTJBB/xPgI5m1t7M6gO9gYkB31NqqaSLjk2Y4IudPfQQ1K8Pt94aV7Ezkdoi0IDvnNsM9AdeA+YCzzjn5gR5T6m9IpU9iMuKFdC7N/zxj7BsGRxyCMyYQdGfLuTGSfNUrVJyRuDz8J1zk5xzezvn9nTO3R70/aT2SrjomHPw+OPQqRM8/bQvdjZiBLz3HkU0U4liyTlaaStJS9dmHgkVHVu82O8lO2mSPz7mGF/srH17QCWKJTcp4EtSYm79l2LhUyAjftCUl/sFU9de+2uxs3vugfPO27JSFiIUSdOCKckBCviSlEz1lCN90Gy/6Cv2uHoA7eZMA6Dk4KNYNeweDu95wDavV4liyUUK+JKUTPWUwz9oNm3cxI8330aPF8fRYPMmVm3fmKHHXcLEjr+j4RvLGLnrbhEDuhZMSa5RwJekpKOnHCl1U/FB0+7b+dz56ig6L5sPwPOdj+LWoy5gdUNfFqHyXx3aPFxymQK+JC3InnK0MYLCPRvz8srXafvYKPLKNrN+tz24oudFvN72t9TPq0N9YFNZ+VZ/daR7vEEk2yjgS1pUt2cdcYxgzVfQty8d5s71J112GQ3/8Q9OX7yO3UL3qHht+P00M0dynQK+BC6ZnnX4GEFTV0rf50bCEw/9Wuzs4YehRw8ACvN33Oq6le+hmTmS6xTwJXDJ9Kwrxgi+fW4iZ467lYbfLoa8PL9ByY03JlT/pjrjDcr5S22igC+BS6pnvXo1hXdfD+PH++MDDvC9+q5dq9WWRMYblPOX2kYBXwJXVc+6yh70f/8Ll17q69/Urw833eR79vXqpaXdyvlLbaOAL2kRqWcdtQe9fDkMGADPPutP7N7d9+o7dUprm5Xzl9pGAV8yZpse9BcrKJz2Olx5JfzwAzRqBP/4h+/l5+WlvX1ajSu1jQK+ZEx4D7rDuu+5/J57YMqb/snCQl/srF27jLZRq3GlNlHAl4wpzG/ByDMPYMOo+znh8RHUXfcLNG7stx3s02erYmcikjwFfMmcL76g8NILYMoUf3zKKXD//dCyZUpvo6mVIl5SG6CY2elmNsfMys2soNJzg81sgZnNM7Pjkmum1CqbN/s9ZPff3wf7Fi3guefghRcCCfba6ETES3bHq9nAn4B3wx80s3z8/rX7AscDo80s/aNukn1mzoSDD4brroONG33qpqQETj01kNtVe1tEkVooqYDvnJvrnJsX4alewFPOuY3OuYXAAqBbMveSGm7DBrjhBr9h+PTp0KYNvPoqPPIING0a2G0T3hZRpBYLKoe/B/Bh2PGS0GOSi95/H/r2hc8/98f9+8Pf/w477hj4rTW1UuRXMQO+mU0Gdovw1BDn3IRoL4vwmIty/X5AP4A2bdrEao7UJD//DEOGwKhRvtjZPvvAQw/BYYeltRmaWinixQz4zrljqnHdJUDrsONWwNIo1x8HjAMoKCiI+KEgNUfFjJheK0voevu18PXXftHUNdf4YmcNGmS6iSI5K6iUzkTgP2Z2D7A70BH4OKB7SZYoKlnODePfZdDrD9J1VpF/sEsXXxbht7/NbONEJLmAb2anAKOA5sDLZvapc+4459wcM3sGKAE2A5c558qSb65ks1X/for/jb2FXX/5kY159Xj3rEspHH9n2oqdiUjVkgr4zrkXgRejPHc7cHsy15caIlTs7KxQsbPiPTpx40kDGTjgZAV7kSyilbZSfc7Bv//ti539+CM0asTnV1zP/wr+wMB9khso1epYkdRTwK/BggiKcV/zm2/goov8XHqAY4+FBx7gN+3acUsK2qCNR0RSL9mVtpIhQZQMiOua5eW+3s2++/pg36SJXzz16qspq2yp1bEiwVDAr6GCCIoxrzlvHhxxhF849fPPvhxCSUnKK1tqdaxIMBTwa6gggmLUa27eDMOG+f1k33vv12Jnzz0Hu0Vak1e1opLl3DhhdtS/SipWx57bva3SOSIpZM5lz1qngoICV1xcnOlm1BhpyeF/+qkvizB9OgDTj+rFmluHceTvflPt61fk5xvWy1NAF0kBM5vmnCuIdZ4GbWuwIEoGbLnmhg2+LMLw4VBWxvrdW9H/iIt5o00XGr7yNSMbN4l570gfSNoYXCRzlNKRbb3/Phx4oC9wVl4OAwZw993P80abLoAP1P/5aFGVaZloA8DKz4tkjgJ+joqYR//5Z7j8cl/c7PPPfbGzKVNg5EgO3r/tlkBdP68OUxd8X+VsnmgDwMrPi2SOUjo5KOI89yUzoV8/WLTIFzu79lr429+2FDsLLzO8+Id1vDXPB/BoaZnwDcor9+RVvVIkMxTwc1B477ve2jU0G3ARvBmqdH3ggTB+vC96VklFoC4qWc6HX/0QMZiHn5toHXqtrhUJlmbp5KCKHv7hs6dw2+QxNP/5R9huO7j5ZvjrX7eqfxMtCKc6OGv2jkj1aZaORFXY1PH2R/fRouhl/8Bhh/mNSfbZZ6vzqipxkOq0jGbviARPg7a5xDl49FHIz/fBfocd4L774J13tgn2kN4SB5q9IxI89fBroYjplkWLfLGz117zx8cdBw88AG3bRr1OVQOvqaa9Z0WCl1QO38zuBE4CNgFfAuc551aHnhsM9AXKgMudc6/Fup5y+MnbJhd+5gEUvv08XHcd/PKLL3Y2YgT8+c9x1b/RQKpI9ktXDr8IGOyc22xmw4HBwLVmlg/0BvbFb3E42cz21q5XwQtPw7T8bhF7n3E9zJ3hnzztNJ/CaRF/4NYUSpHaI6kcvnPudefc5tDhh/jNygF6AU855zY65xYCC4BuydxL4tOjY3N2rOO49INneOVfA2g7d4YP8M8/D88+m1CwF5HaJZU5/POBp0Nf74H/AKiwJPSYBKxw41Km/vd6dpo7yz9w3nlw993QpInSMyI5LmbAN7PJQKQauEOccxNC5wzBb1b+RMXLIpwfcbDAzPoB/QDatGkTR5Mlog0bYOhQuOMOdior84Ox48b5najQLlIiEkfAd84dU9XzZtYH+ANwtPt1BHgJ0DrstFbA0ijXHweMAz9oG0ebpbKpU/nlz3+h0cIFODPs8svh9tv9tMsQzXMXkaRy+GZ2PHAtcLJzbl3YUxOB3ma2nZm1BzoCHydzr1wTa5MQANauhQEDcD160GjhAhY0bcU5595J0UXXbxXsoep57nHdS0RqvGRz+PcB2wFF5qf4feicu9g5N8fMngFK8KmeyzRDJ35xpV9ee80XO/vmG8rr5DH6kNO473dnsrFuffaK0HuPNs9dqR6R3JFUwHfO7VXFc7cDtydz/VxVZfrlhx9g0CC/YhbgwAP5eMidjP60lI0xFkhFmmKpVI9I7lBphSwUNf3y/POQn++D/Xbb+X1mP/6Y7qceXe0a8yppIJI7VC0zS201hbJJOfTvDy+84J+MUuwsJfdS716kxol3pa0CfjarKHY2cCCsXu0HYocPh4svhjr640xEPJVHrum+/toPyhYV+ePjj4exY6ssdiYiUhV1EyvJ+BTF8nIYNQo6d/bBvmlTeOwxmDRJwV5EkqIefpiMT1GcOxcuuADef98fn366D/6qfyMiKaAefphunyfAAAANRklEQVR0bvixldJS+Pvf/T6y778Pu+3mB2ifeUbBXkRSRgE/TEamKE6fDt26wZAhsGkTnH8+lJTAKaek5PIZT1GJSNZQSidMKnddijnVcf16X+zszjuhrAzatYMHH4RjqixdlHAbtIpWRCoo4FeSig0/Ygba996Dvn3hiy/8rlNXXAG33bZN/ZtkaRWtiITLqZROqtIbsa4TdSxg7Vq/gKpHDx/sO3WCqVP9loNVBPvqtluraEUkXM4svNpmr9dqpjfiuU7Ec76Z4TcR/+YbqFvX7zF7ww2+REKA7dYqWpHaTwuvKqlOeiNSsIznOuFjAUc1z6Pn8Gv8XHqArl3h4YfhgAOi3iPZdlduiwK9iEAOpXQSTW9U9Kwf+2ARlz85Y0s6Jd7rFHbalaGln9PzlJ4+2Ddo4MsifPjhVsE+0j2SabeISDQ508NPdAZOtJ51XNdZtgwuuwxefNEf9+jhi53tvXdc96iq3QA3TpitFI2IJCxnAj4klt7o0bE5zxYv2ZI7D+9ZR72Oc/DII75efUWxszvu8Ln7CMXOqrpHpHZrmqWIJCOpgG9mtwK9gHJgBfAX59xS89tf/RM4EVgXenx6so1Np4Tn5C9c6IudTZ7sj084wRc7q2Jj9kR775pmKSLJSDaHf6dzbn/nXBfgJeDG0OMn4Pex7Qj0A8YkeZ+MKMxvwdBenasOqmVlMHKkL3Y2ebIvdvbvf8PLL1cZ7CvfA1A+X0QClewWhz+FHTYCKuZ49gIec37O54dm1tjMWjrnliVzv6wzd65fQPXBB/74jDN8sbNdd034UtXJ56t3LyKJSDqHb2a3A+cCa4AjQw/vASwOO21J6LHaEfBLS31ufuhQX/+mZUsYPZqivbsz5YMV9OjoEg7GiebzRUQSFTOlY2aTzWx2hH+9AJxzQ5xzrYEngP4VL4twqYgrvMysn5kVm1nxypVpqk6ZjGnToKDAL5ratMn38EtKKNq7e8yUTFUqeu/V2ZdWRCQeMXv4zrl4q3n9B3gZuAnfo28d9lwrYGmU648DxoFfaRvnvdJv/Xq45Ra46y6ft2/fHh58kKKWnZnyzhIW/7Au6QFV9d5FJEjJztLp6JybHzo8Gfg89PVEoL+ZPQUcDKyp0fn7d9/1G5PMn++LnV15Jdx2G0WLft4yTbJ+Xh3q59VhU1m5BlRFJCslm8MfZmb74KdlLgIuDj0+CT8lcwF+WuZ5Sd6n2pKqJfPTTzB4MIwe7Y/z831ZhEMOAWDK/IVbevWbyso5cp/mtG66vQZURSQrJTtL59QojzvgsmSunQpJLVR65RW/YGrxYl/sbPBgv0lJWLGzygOtZx/cVoFeRLJWrV5pW62FSt9/DwMH+rn04IudjR9PUd0WTHl1/la9d02TFJGapFYXT0tooZJzfg/ZTp18sG/QwE+9/PBDiuq2iDoDJ67FWSIiWaBW9/Dj7oEvXeqLnf33v/748MN9sbOOHQGVNBCR2qFWB3yIMdXRORg/ntKBg6i39ic2N9qBunfd6WvihBU7i3dRVLy0KYmIZEKtCPjVCqBffeUD+xtvUA94s0MBt/7+cq4//FgKK1W2TPXm5qp4KSKZUOMDfsIBtKwMRo2i7PrryVu/nrU77MwNPS9gQn5PMIuarknVoiilh0QkU2r8oG3UDcMjKSmBww6DgQPJW7+eiZ0Op7DvGF7Z/2gwS8uCKVW8FJFMqfE9/Ljy65s2+e0Fb7sNNm3ip6a7MqjnRUzueDAAR+7VLG0LpjSVU0QypcYH/JgBtLjYFzj77DN/fOGFTDv/Kqa+/BWkacFU5TEG1cwRkUwwvyg2OxQUFLji4uLUXGz9erjpJrj7bigvhw4d4MEH4aijgPTNlAkfY2hYL0+DtCKScmY2zTlXEOu8Gt/Dj6T40RdpP/hKmi37xk+vHDTI165v1GjLOenqZWuQVkSyRY0ftN3KTz+xuHcfCv7yJ5ot+4YFu7Thoyde8r38sGCfThqkFZFsUXt6+JMmwUUX0XrJEkrr5HF/9zMYfcgZ9G7YmoMz2CwN0opItqgdAX/gQBgxAoA1+x7Anw+9iM+atMmaHrUGaUUkG9SOlM6hh/piZ3fdxc6fFjPgij9pq0ARkUpSMkvHzK4C7gSaO+dWmZkB/8RvgrIO+Itzbnqs6yQ1S2fpUth99+q9VkSkBot3lk7SPXwzaw0UAt+EPXwC0DH0rx8wJtn7xKRgLyJSpVSkdO4FrgHC/1ToBTzmvA+BxmbWMgX3EhGRakoq4JvZycC3zrmZlZ7aA1gcdrwk9JiIiGRIzFk6ZjYZ2C3CU0OA64FjI70swmMRBwvMrB8+7UObNm1iNUdERKopZsB3zh0T6XEz2w9oD8z0Y7S0AqabWTd8j7512OmtgKVRrj8OGAd+0DaRxouISPyqndJxzs1yzu3qnGvnnGuHD/K/dc59B0wEzjXvEGCNc25ZaposIiLVEdTCq0n4KZkL8NMyzwvoPiIiEqeUBfxQL7/iawdclqpri4hI8rKqPLKZrQQWVfPluwCrUticVMrWtqldiVG7EqN2JSaZdrV1zsWsI5NVAT8ZZlYcz0qzTMjWtqldiVG7EqN2JSYd7aodtXRERCQmBXwRkRxRmwL+uEw3oArZ2ja1KzFqV2LUrsQE3q5ak8MXEZGq1aYevoiIVKFGBXwzO93M5phZuZkVVHpusJktMLN5ZnZclNe3N7OPzGy+mT1tZvUDaOPTZvZp6N/XZvZplPO+NrNZofOquQlAwm272cy+DWvfiVHOOz70Pi4ws+sCbtOdZva5mX1mZi+aWeMo56Xl/Yr1vZvZdqGf8YLQ71K7oNoSds/WZvaWmc0N/f5fEeGcnma2Juxne2PQ7Qq7d5U/m9CK+5Gh9+wzM/ttwO3ZJ+x9+NTMfjKzKyudk7b3y8zGm9kKM5sd9lhTMysKxaIiM2sS5bV9QufMN7M+STfGOVdj/gGdgH2At4GCsMfzgZnAdvj6Pl8CeRFe/wzQO/T1WOCSgNt7N3BjlOe+BnZJ8/t3M3BVjHPyQu9fB6B+6H3ND7BNxwJ1Q18PB4Zn6v2K53sHLgXGhr7uDTydhp9bS3zZEoAdgS8itKsn8FI6f5/i/dngV92/gi+qeAjwURrblgd8h5+nnpH3Czgc+C0wO+yxO4DrQl9fF+n3HmgKfBX6b5PQ102SaUuN6uE75+Y65+ZFeKoX8JRzbqNzbiG+pEO38BNCu3AdBTwXeuhR4I9BtTV0vzOAJ4O6R0C6AQucc1855zYBT+Hf30A45153zm0OHX6IL7SXKfF8773wvzvgf5eODv2sA+OcW+ZCO8Y559YCc6lZ5cYzuT/G0cCXzrnqLuhMmnPuXeCHSg+H/x5Fi0XHAUXOuR+ccz8CRcDxybSlRgX8KsRTf78ZsDosuARdo78HsNw5Nz/K8w543cymhUpEp0v/0J/V46P8GZnJvQzOx/cEI0nH+xXP977lnNDv0hr871ZahFJIBwIfRXi6u5nNNLNXzGzfdLWJ2D+bTP5O9SZ6pytT7xdACxcqKBn6764Rzkn5+xZU8bRqsyrq7zvnJkR7WYTHKk8/irtGfyxxtvEsqu7dH+qcW2pmuwJFZvZ5qCeQlKraht9q8lb8930rPuV0fuVLRHhtUlO54nm/zGwIsBl4IsplAnm/Kjc1wmOB/R4lysx2AJ4HrnTO/VTp6en4tMXPobGZ/+K3GE2HWD+bjLxnoTG6k4HBEZ7O5PsVr5S/b1kX8F2U+vsxxFN/fxX+T8m6oZ5Z1Br9ybbRzOoCfwK6VnGNpaH/rjCzF/HphKQDWLzvn5k9CLwU4am49zJIVZtCg1F/AI52oeRlhGsE8n5VEs/3XnHOktDPeWe2/XM95cysHj7YP+Gce6Hy8+EfAM65SWY22sx2cc4FXjMmjp9Nyn+n4nQCMN05t7zyE5l8v0KWm1lL59yyUHprRYRzluDHGiq0wo9fVlttSelMBHqHZlC0x39Sfxx+QiiQvAWcFnqoDxDtL4ZkHQN87pxbEulJM2tkZjtWfI0fuJwd6dxUqpQ3PSXKPT8BOpqf0VQf/yfxxADbdDxwLXCyc25dlHPS9X7F871PxP/ugP9dejPah1SqhMYIHgbmOufuiXLObhVjCeY3IaoDfB9ku0L3iudnk6n9MaL+lZ2p9ytM+O9RtFj0GnCsmTUJpV+PDT1WfekYpU7VP3yQWgJsBJYDr4U9NwQ/w2IecELY45OA3UNfd8B/ECwAngW2C6idjwAXV3psd2BSWDtmhv7Nwac20vH+/RuYBXwW+oVrWbltoeMT8TNBvgy6baGfxWLg09C/sZXblM73K9L3DgzFfyABNAj97iwI/S51SMPP7TD8n/Kfhb1PJwIXV/yeAf1D781M/OD379L0OxXxZ1OpbQbcH3pPZxE2wy7Adm2PD+A7hz2WkfcL/6GzDCgNxa+++HGfN4D5of82DZ1bADwU9trzQ79rC4Dzkm2LVtqKiOSI2pLSERGRGBTwRURyhAK+iEiOUMAXEckRCvgiIjlCAV9EJEco4IuI5AgFfBGRHPH/wBrpnIM4oTMAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x22a569e2ac8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x, y, s=10)\n",
"\n",
"xs = np.array([-10, 10])\n",
"Xs = np.c_[np.ones([2,1]), xs]\n",
"ys = Xs.dot(w)\n",
"plt.plot(xs, ys, 'r', linewidth=2)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学习率过小"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[13.77884944]\n",
" [ 4.18644247]]\n"
]
}
],
"source": [
"gradient_descent(epoches=1000, eta=0.000001, save=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学习率过大"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.50642535]\n",
" [244.65957398]]\n"
]
}
],
"source": [
"gradient_descent(epoches=100, eta=0.0006, save=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 随机梯度下降"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.28077622]\n",
" [3.77353814]]\n"
]
}
],
"source": [
"def stochastic_gradient_descent(epoches, eta, save=False):\n",
" np.random.seed(42)\n",
" w = 10 * np.random.randn(2, 1) + 10\n",
" ws = []\n",
" mses = []\n",
" \n",
" for epoch in range(epoches):\n",
" for sample in range(m):\n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append([epoch*m+sample, mse])\n",
" ws.append(w.flatten())\n",
" \n",
" index = np.random.randint(m)\n",
" xi = X[index:index+1]\n",
" yi = y[index:index+1]\n",
" gradients = xi.T.dot(xi.dot(w) - yi)\n",
" w -= eta * gradients\n",
"\n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append([epoch*m+sample, mse]) # last mse\n",
" ws.append(w.flatten()) # we need the last w\n",
" print(w)\n",
" \n",
" # save data\n",
" if save:\n",
" df = pd.DataFrame(data=ws, columns=['w0', 'w1'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
" df = pd.DataFrame(data=mses, columns=['iteration', 'mse'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
" \n",
"stochastic_gradient_descent(epoches=10, eta=0.005, save=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 小批量梯度下降"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.7526632 ]\n",
" [4.00159136]]\n"
]
}
],
"source": [
"def minibatch_gradient_descent(epoches, eta, batch_size, save=False):\n",
" np.random.seed(42)\n",
" w = 10 * np.random.randn(2, 1) + 10\n",
" ws = []\n",
" mses = []\n",
" batches = m // batch_size\n",
" for epoch in range(epoches):\n",
" shuffled_indices = np.random.permutation(m)\n",
" X_suffled = X[shuffled_indices]\n",
" y_suffled = y[shuffled_indices]\n",
" for batch in range(batches):\n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append([epoch*batches+batch, mse])\n",
" ws.append(w.flatten())\n",
" \n",
" X_batch = X_suffled[batch:batch*batch_size]\n",
" y_batch = y_suffled[batch:batch*batch_size]\n",
" gradients = X_batch.T.dot(X_batch.dot(w) - y_batch)\n",
" w -= eta * gradients\n",
"\n",
" mse = np.mean((X.dot(w) - y)**2)\n",
" mses.append([epoch*batches+batch, mse]) # last mse\n",
" ws.append(w.flatten()) # we need the last w\n",
" print(w)\n",
" \n",
" # save data\n",
" if save:\n",
" df = pd.DataFrame(data=ws, columns=['w0', 'w1'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
" df = pd.DataFrame(data=mses, columns=['iteration', 'mse'])\n",
" df.to_csv(\"{}.csv\".format(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')), index=False)\n",
" \n",
"minibatch_gradient_descent(epoches=100, eta=0.0001, batch_size=10, save=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}

评论 ( 0 )

你可以在登录后,发表评论