182 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			182 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
{
 | 
						||
 "cells": [
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 7,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 引入库\n",
 | 
						||
    "import numpy as np\n",
 | 
						||
    "import matplotlib.pyplot as plt\n",
 | 
						||
    "import os"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": null,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 检查os位置\n",
 | 
						||
    "print(os.getcwd())"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 9,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 生成数据\n",
 | 
						||
    "def generate_data():\n",
 | 
						||
    "    w = 1.35\n",
 | 
						||
    "    b = 2.89\n",
 | 
						||
    "    x_min = 0\n",
 | 
						||
    "    x_max = 10\n",
 | 
						||
    "    x = np.linspace(x_min, x_max, 100)\n",
 | 
						||
    "    y = w * x + b\n",
 | 
						||
    "    y += np.random.normal(scale=0.5, size=y.shape)\n",
 | 
						||
    "    data = np.column_stack((x, y))\n",
 | 
						||
    "    return data\n",
 | 
						||
    "\n",
 | 
						||
    "# 保存数据\n",
 | 
						||
    "def save_data(filename, data):\n",
 | 
						||
    "    np.savetxt(filename, data, delimiter=',')\n",
 | 
						||
    "    print(f\"{filename} 已成功创建并写入数据。\")\n",
 | 
						||
    "\n",
 | 
						||
    "# 生成并保存数据\n",
 | 
						||
    "data = generate_data()\n",
 | 
						||
    "#save_data('./1_data.txt', data)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 10,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 读取数据\n",
 | 
						||
    "#points = np.genfromtxt(\"./1_data.txt\", delimiter=',')\n",
 | 
						||
    "\n",
 | 
						||
    "points = data\n",
 | 
						||
    " \n",
 | 
						||
    "x = points[:, 0]\n",
 | 
						||
    "y = points[:, 1]"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "损失函数: \n",
 | 
						||
    "$$J(w,b) = \\frac{1}{2m} \\sum_{i=1}^{m} (y_{w,b}(x^{(i)}) - y^{(i)})^2$$\n",
 | 
						||
    "\n",
 | 
						||
    "梯度下降:\n",
 | 
						||
    "\n",
 | 
						||
    "分别对w和b求偏导数,然后更新w和b\n",
 | 
						||
    "$$\n",
 | 
						||
    "w = w - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial w}\n",
 | 
						||
    "$$\n",
 | 
						||
    "\n",
 | 
						||
    "$$\n",
 | 
						||
    "b = b - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial b}\n",
 | 
						||
    "$$"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 11,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 定义损失函数\n",
 | 
						||
    "def compute_loss(w,b):\n",
 | 
						||
    "    return np.sum((y-w*x-b)**2)/2*len(x)\n",
 | 
						||
    "\n",
 | 
						||
    "# 等效\n",
 | 
						||
    "def compute_loss_equivalent(w,b):\n",
 | 
						||
    "    sum = 0\n",
 | 
						||
    "    for i in range(len(x)):\n",
 | 
						||
    "        sum += (y[i] - (w*x[i]+b))**2\n",
 | 
						||
    "    return sum/(2*len(x))\n",
 | 
						||
    "\n",
 | 
						||
    "# 定义梯度下降\n",
 | 
						||
    "def gradient_descent(w,b,alpha,num_iter):\n",
 | 
						||
    "    m = len(x)\n",
 | 
						||
    "    for _ in range(num_iter):\n",
 | 
						||
    "        # 计算梯度\n",
 | 
						||
    "        dw = -np.sum(x*(y-w*x-b))/m\n",
 | 
						||
    "        db = -np.sum(y-w*x-b)/m\n",
 | 
						||
    "        # 更新w和b\n",
 | 
						||
    "        w = w - alpha*dw\n",
 | 
						||
    "        b = b - alpha*db\n",
 | 
						||
    "    return w,b"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": null,
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# 主函数\n",
 | 
						||
    "if __name__ == \"__main__\":\n",
 | 
						||
    "    # 初始化w和b\n",
 | 
						||
    "    w,b = 0,0\n",
 | 
						||
    "    # 设置学习率\n",
 | 
						||
    "    alpha = 0.01\n",
 | 
						||
    "    # 设置迭代次数\n",
 | 
						||
    "    num_iter = 1000\n",
 | 
						||
    "    # 进行梯度下降\n",
 | 
						||
    "    w,b = gradient_descent(w,b,alpha,num_iter)\n",
 | 
						||
    "    print(\"w:\", w)\n",
 | 
						||
    "    print(\"b:\", b)\n",
 | 
						||
    "    # 计算损失\n",
 | 
						||
    "    loss = compute_loss(w,b)\n",
 | 
						||
    "    print(\"loss:\", loss)\n",
 | 
						||
    "\n",
 | 
						||
    "    plt.figure(dpi=600)\n",
 | 
						||
    "    #plt.switch_backend('Agg')  # 使用 Agg 渲染器\n",
 | 
						||
    "    # 绘制数据点\n",
 | 
						||
    "    plt.scatter(x, y, color='blue', label='original data')\n",
 | 
						||
    "\n",
 | 
						||
    "    # 绘制回归直线\n",
 | 
						||
    "    plt.plot(x, w*x + b, color='red', label='regression line')\n",
 | 
						||
    "\n",
 | 
						||
    "    # 添加标题和标签\n",
 | 
						||
    "    plt.title('linear regression')\n",
 | 
						||
    "    plt.xlabel('x')\n",
 | 
						||
    "    plt.ylabel('y')\n",
 | 
						||
    "\n",
 | 
						||
    "    # 显示图例\n",
 | 
						||
    "    plt.legend()\n",
 | 
						||
    "\n",
 | 
						||
    "    # 显示图像\n",
 | 
						||
    "    plt.show()"
 | 
						||
   ]
 | 
						||
  }
 | 
						||
 ],
 | 
						||
 "metadata": {
 | 
						||
  "kernelspec": {
 | 
						||
   "display_name": "pt",
 | 
						||
   "language": "python",
 | 
						||
   "name": "python3"
 | 
						||
  },
 | 
						||
  "language_info": {
 | 
						||
   "codemirror_mode": {
 | 
						||
    "name": "ipython",
 | 
						||
    "version": 3
 | 
						||
   },
 | 
						||
   "file_extension": ".py",
 | 
						||
   "mimetype": "text/x-python",
 | 
						||
   "name": "python",
 | 
						||
   "nbconvert_exporter": "python",
 | 
						||
   "pygments_lexer": "ipython3",
 | 
						||
   "version": "3.10.14"
 | 
						||
  }
 | 
						||
 },
 | 
						||
 "nbformat": 4,
 | 
						||
 "nbformat_minor": 2
 | 
						||
}
 |