{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 多分类问题 - 手写数字识别\n",
    "\n",
    "## 数据集\n",
    "- MNIST 数据集(手写数字数据集)\n",
    "\n",
    "## 激活函数\n",
    "- softmax\n",
    "\n",
    "## 损失函数\n",
    "- 交叉熵\n",
    "\n",
    "## 优化器\n",
    "- 梯度下降(Adam)\n",
    "\n",
    "## 模型\n",
    "- 全连接层\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Sequential\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.losses import SparseCategoricalCrossentropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TEST: minimal fully connected classifier skeleton\n",
    "model = Sequential([\n",
    "    # Declare the input shape (a flattened 28x28 image) so the model is built\n",
    "    # and summary() can report output shapes and parameter counts.\n",
    "    tf.keras.Input(shape=(28 * 28,)),\n",
    "    Dense(units=25, activation='relu'),\n",
    "    Dense(units=15, activation='relu'),\n",
    "    Dense(units=10, activation='softmax'),\n",
    "])\n",
    "model.compile(loss=SparseCategoricalCrossentropy())\n",
    "# Show the model. summary() prints the table itself and returns None,\n",
    "# so wrapping it in print() would emit a stray \"None\".\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Official TensorFlow example\n",
    "\n",
    "import tensorflow as tf\n",
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "# Scale pixel values from [0, 255] to [0, 1]\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "    tf.keras.layers.Dense(128, activation='relu'),\n",
    "    tf.keras.layers.Dropout(0.2),\n",
    "    tf.keras.layers.Dense(10, activation='softmax')\n",
    "])\n",
    "\n",
    "# Labels are integer class ids, hence the sparse variant of cross-entropy\n",
    "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.fit(x_train, y_train, epochs=5)\n",
    "model.evaluate(x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch implementation\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Preprocessing: convert images to tensors, then normalize with mean 0.5 and std 0.5\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "# Load the datasets\n",
    "trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n",
    "\n",
    "testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)\n",
    "\n",
    "# Define the model\n",
    "class SimpleNet(nn.Module):\n",
    "    \"\"\"Fully connected classifier: Flatten -> Linear(784, 128) -> ReLU -> Dropout(0.2) -> Linear(128, 10).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc1 = nn.Linear(28 * 28, 128)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw class logits for a batch of images.\n",
    "\n",
    "        No softmax is applied here because nn.CrossEntropyLoss expects\n",
    "        unnormalized logits.\n",
    "        \"\"\"\n",
    "        x = self.flatten(x)\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "model = SimpleNet()\n",
    "\n",
    "# Pick the first available backend: CUDA, then Apple MPS, then CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "# Define the loss function and the optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Train the model\n",
    "epochs = 5\n",
    "for epoch in range(epochs):\n",
    "    model.train()  # make sure dropout is active while training\n",
    "    running_loss = 0.0\n",
    "    for images, labels in trainloader:\n",
    "        images, labels = images.to(device), labels.to(device)  # move the batch to the device\n",
    "        optimizer.zero_grad()\n",
    "        output = model(images)\n",
    "        loss = criterion(output, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "    print(f\"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(trainloader)}\")\n",
    "\n",
    "# Evaluate the model\n",
    "# Switch to eval mode: otherwise dropout keeps zeroing activations at test\n",
    "# time, which makes the reported accuracy noisy and lower than it should be.\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in testloader:\n",
    "        images, labels = images.to(device), labels.to(device)  # move the batch to the device\n",
    "        output = model(images)\n",
    "        _, predicted = torch.max(output, 1)  # index of the largest logit = predicted class\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f\"Accuracy: {100 * correct / total}%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ail",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}