qxd/AlexNet

AlexNet is a classic deep learning model. It made a big splash in the 2012 ImageNet image-classification competition and fundamentally changed the landscape of computer vision. Its core innovations include using a deep convolutional neural network (CNN) to process images and adopting several then-advanced techniques such as the ReLU activation function and Dropout regularization.

This article shows how to implement AlexNet with the PyTorch framework and train it on the MNIST dataset. MNIST is a simple but classic dataset that is often used to test machine learning algorithms.

一、AlexNet network structure

AlexNet can be split into two parts: a feature-extraction part (the convolutional layers) and a classification part (the fully connected layers). In brief (a quick calculation of the resulting feature-map sizes follows the list):

Convolutional layers: five convolutional layers extract features; each is followed by a ReLU activation, and three of them by a max-pooling layer.
Fully connected layers: three fully connected layers; Dropout is used with the first two to reduce overfitting.
Output layer: a Softmax over the 1000 ImageNet classes (the MNIST version below uses 10 classes).
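
As a quick aid (this calculation is not in the original write-up), the spatial size after a convolution or pooling layer is floor((input + 2*padding - kernel) / stride) + 1. A small sketch, using the layer settings from the implementation below with a 227x227 input, shows why the first fully connected layer expects 256 * 6 * 6 features:

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

s = 227                                   # input size used later in this article
s = conv_out(s, 11, stride=4, padding=2)  # conv1 -> 56
s = conv_out(s, 3, stride=2)              # pool1 -> 27
s = conv_out(s, 5, padding=2)             # conv2 -> 27
s = conv_out(s, 3, stride=2)              # pool2 -> 13
# conv3, conv4 and conv5 use kernel 3 with padding 1, so the size stays 13
s = conv_out(s, 3, stride=2)              # pool3 -> 6
print(s)  # 6, which is why the first fully connected layer takes 256 * 6 * 6 inputs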

二、Implementing AlexNet with PyTorch

Training

1. Import the required libraries

First, we import the libraries we need, including PyTorch and some data-handling utilities.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.cuda.amp import GradScaler, autocast  # mixed-precision training utilities
from tqdm import tqdm
import multiprocessing
import matplotlib.pyplot as plt

# Select the device: use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

2. Define the AlexNet model

Next, we define a class AlexNet that inherits from nn.Module and implements the AlexNet architecture.

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # Feature extractor: convolutional and pooling layers
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),  # first convolutional layer
            nn.ReLU(inplace=True),                                   # ReLU activation
            nn.MaxPool2d(kernel_size=3, stride=2),                   # pooling layer
            nn.Conv2d(64, 192, kernel_size=5, padding=2),            # second convolutional layer
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),           # third convolutional layer
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),           # fourth convolutional layer
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),           # fifth convolutional layer
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)                    # pooling layer
        )
        # Classifier: fully connected layers
        self.classifier = nn.Sequential(
            nn.Dropout(),                  # Dropout to reduce overfitting
            nn.Linear(256 * 6 * 6, 4096),  # first fully connected layer
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),         # second fully connected layer
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10)            # output layer, 10 classes
        )

    def forward(self, x):
        # Forward pass
        x = self.features(x)       # convolutional feature extraction
        x = x.view(x.size(0), -1)  # flatten to a vector per sample
        x = self.classifier(x)     # fully connected classifier
        return x

3. Data preprocessing and loading

Before training, we need to preprocess the MNIST dataset. AlexNet expects 227x227 input images, so we resize the images accordingly.

We use torchvision.transforms to apply a series of preprocessing operations to the images:

transform = transforms.Compose([
    transforms.Resize((227, 227)),        # resize the input images to 227x227 (as AlexNet expects)
    transforms.ToTensor(),                # convert the images to tensors
    transforms.Normalize((0.5,), (0.5,))  # normalize with mean 0.5 and standard deviation 0.5
])

Download and load the MNIST dataset; the transform above is applied to every image:

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

Use DataLoader to load the training and test sets, setting the batch size and multi-process data loading:

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)  # shuffle the training data each epoch
testloader = DataLoader(testset, batch_size=128, num_workers=2, pin_memory=True)
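
As a quick sanity check (this snippet is not part of the original script), you can pull one batch from trainloader and confirm the shapes and the value range produced by the transform. On Windows, run it inside the if __name__ == '__main__' guard shown later, because of the worker processes.

images, labels = next(iter(trainloader))
print(images.shape)   # torch.Size([128, 1, 227, 227]): 128 grayscale 227x227 images
print(labels.shape)   # torch.Size([128])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0 after Normalize((0.5,), (0.5,))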

4. Define the loss function and optimizer

We train the model with the cross-entropy loss and the AdamW optimizer.

# Create the model instance and move it to the selected device
model = AlexNet().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()                      # the usual loss for classification
optimizer = optim.AdamW(model.parameters(), lr=0.001)  # AdamW optimizer
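
As an optional check (not in the original code), a single forward pass on a dummy batch confirms that the flattened feature size of 256 * 6 * 6 matches what the classifier expects and that the loss can be computed:

dummy = torch.randn(2, 1, 227, 227, device=device)  # two fake grayscale 227x227 images
dummy_logits = model(dummy)                          # forward pass; shape: [2, 10]
dummy_loss = criterion(dummy_logits, torch.tensor([3, 7], device=device))  # made-up target labels
print(dummy_logits.shape, dummy_loss.item())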

Lists used to record the loss and accuracy during training:

train_losses = []
train_accuracies = []

5. Train the model

Now we can train the model. We run several epochs over the training set and report the loss and accuracy after each epoch.

def train_model():
    epochs = 5              # number of training epochs
    accumulation_steps = 4  # number of batches to accumulate gradients over
    scaler = GradScaler()   # GradScaler for mixed-precision training
    for epoch in range(epochs):
        model.train()       # put the model in training mode
        running_loss = 0.0  # accumulated loss for the current epoch
        correct = 0         # number of correct predictions
        total = 0           # total number of samples seen
        print(f"Epoch [{epoch + 1}/{epochs}] started.")
        optimizer.zero_grad()  # clear gradients before accumulation starts

        # Wrap trainloader with tqdm to show a progress bar
        for i, (inputs, labels) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch + 1}/{epochs}", ncols=100), 1):
            inputs, labels = inputs.to(device), labels.to(device)  # move data and labels to the device

            with autocast():  # enable mixed-precision training
                outputs = model(inputs)            # forward pass
                loss = criterion(outputs, labels)  # compute the loss

            # Scale the loss for accumulation, then backpropagate
            scaler.scale(loss / accumulation_steps).backward()
            if i % accumulation_steps == 0:  # update the weights every accumulation_steps batches
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()                    # accumulate the batch loss
            _, predicted = torch.max(outputs, 1)           # take the class with the highest score
            total += labels.size(0)                        # update the sample count
            correct += (predicted == labels).sum().item()  # update the number of correct predictions

        # Average loss and accuracy for this epoch
        epoch_loss = running_loss / len(trainloader)
        epoch_accuracy = correct / total * 100
        train_losses.append(epoch_loss)          # record the epoch loss
        train_accuracies.append(epoch_accuracy)  # record the epoch accuracy
        print(f"Epoch [{epoch + 1}/{epochs}] finished. Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%\n")

    # 6. Save the model
    torch.save(model.state_dict(), 'alexnet_mnist.pth')  # save the trained weights
    print("Model saved successfully!")

Testing

1. Data preprocessing and loading

First, we preprocess the MNIST dataset so that the images meet AlexNet's input requirements. AlexNet's standard input size is 227x227, so we resize the MNIST images and convert them to tensors.

Define the image transforms: resize, convert to tensor, normalize:

transform = transforms.Compose([
    transforms.Resize((227, 227)),        # resize the input images to 227x227 (as AlexNet expects)
    transforms.ToTensor(),                # convert the images to tensors
    transforms.Normalize((0.5,), (0.5,))  # normalize with mean 0.5 and standard deviation 0.5
])

Load the MNIST dataset (training split) and apply the transforms defined above:

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)  # load the data in batches with DataLoader

Data preprocessing: transforms.Compose chains several operations. The images are first resized to 227x227 pixels to match AlexNet's input requirements, then converted to tensors and normalized.

Loading the data: DataLoader loads the training data with a batch size of 64 and shuffling disabled, since we are not training here and only want to display the first few images.

2. Define the AlexNet model structure

Next, we implement AlexNet's convolutional and fully connected layers. The input here is a grayscale image, so the number of input channels is 1 (rather than 3).

The model follows the classic AlexNet structure, except that the input is a grayscale image (1 channel) rather than an RGB image (3 channels):

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # Feature extractor (convolutions + activations + pooling)
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),  # 1 input channel (grayscale), 64 output channels
            nn.ReLU(inplace=True),                                   # ReLU activation
            nn.MaxPool2d(kernel_size=3, stride=2),                   # pooling layer
            nn.Conv2d(64, 192, kernel_size=5, padding=2),            # second convolutional layer
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # pooling layer
            nn.Conv2d(192, 384, kernel_size=3, padding=1),           # third convolutional layer
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),           # fourth convolutional layer
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),           # fifth convolutional layer
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)                    # pooling layer
        )
        # Classifier (fully connected layers)
        self.classifier = nn.Sequential(
            nn.Dropout(),                  # Dropout to reduce overfitting
            nn.Linear(256 * 6 * 6, 4096),  # fully connected layer 1: 256 * 6 * 6 inputs, 4096 outputs
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),         # fully connected layer 2: 4096 outputs
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10)            # output layer: 10 classes (MNIST digits 0-9)
        )

    def forward(self, x):
        # Forward pass
        x = self.features(x)       # extract features with the convolutional layers
        x = x.view(x.size(0), -1)  # flatten to a vector per sample
        x = self.classifier(x)     # classify with the fully connected layers
        return x

3. Load the trained model

In practice, we save the trained model weights and load them when needed. Here we assume the model has already been trained and saved as alexnet_mnist.pth.

This function loads the saved model weights:

def load_model(model_path='alexnet_mnist.pth'):
    model = AlexNet().to(device)  # create the model instance and move it to the device (GPU/CPU)
    model.load_state_dict(torch.load(model_path, map_location=device))  # load the saved weights onto the same device
    model.eval()  # switch the model to evaluation mode
    print("Model loaded successfully!")
    return model

4. Predict a batch of images

Next, we write a function that runs prediction on a batch of images and returns the corresponding results.

This function takes the requested number of images from the dataloader and runs them through the model:

def batch_predict_images(model, dataloader, num_images=6):
    predictions = []  # predicted labels
    images = []       # input images
    labels = []       # ground-truth labels

    # Disable gradient computation for efficiency
    with torch.no_grad():
        for i, (input_images, input_labels) in enumerate(dataloader):
            if i * 64 >= num_images:  # stop once enough images have been collected (64 is the DataLoader batch size)
                break

            input_images = input_images.to(device)  # move the images to the device
            input_labels = input_labels.to(device)  # move the labels to the device

            # Run the model
            outputs = model(input_images)
            _, predicted = torch.max(outputs, 1)  # predicted class for each image

            predictions.extend(predicted.cpu().numpy())  # move the predictions back to the CPU and store them
            images.extend(input_images.cpu().numpy())    # store the input images
            labels.extend(input_labels.cpu().numpy())    # store the ground-truth labels

    return images[:num_images], labels[:num_images], predictions[:num_images]

三、Visualization

To get a better picture of how training went, we plot the loss and accuracy recorded during training.

def plot_training_progress():
    plt.figure(figsize=(12, 6))  # create a 12x6-inch figure

    # Subplot for the training loss
    plt.subplot(1, 2, 1)  # first subplot in a 1x2 grid
    plt.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Train Loss')
    plt.title('Train Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    # Subplot for the training accuracy
    plt.subplot(1, 2, 2)  # second subplot in a 1x2 grid
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, marker='o', label='Train Accuracy')
    plt.title('Train Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True)

    # Adjust the layout so the subplots do not overlap
    plt.tight_layout()

    # Show the figure
    plt.show()

We also write a function that visualizes a batch of images together with their predictions, using matplotlib.

This function displays the images and the prediction results:

def visualize_images(images, labels, predictions):
    fig, axes = plt.subplots(2, 3, figsize=(10, 7))  # create a 2x3 grid of subplots
    axes = axes.ravel()  # flatten the grid into a 1D array of axes

    for i in range(6):
        image = images[i].squeeze()  # drop the channel dimension (e.g. [1, 227, 227] -> [227, 227])
        ax = axes[i]
        ax.imshow(image, cmap='gray')  # show the image in grayscale
        ax.set_title(f"Pred: {predictions[i]} | Actual: {labels[i]}")  # show the predicted and actual labels
        ax.axis('off')  # hide the axes

    plt.tight_layout()  # adjust the spacing between subplots
    plt.show()          # display the figure

四、Running the code

# Training
if __name__ == '__main__':
    # Set the multiprocessing start method to 'spawn' (required on Windows)
    multiprocessing.set_start_method('spawn')

    # Train the model
    train_model()

    # Plot the training curves
    plot_training_progress()

# Testing
if __name__ == '__main__':
    # Set the multiprocessing start method to 'spawn' for compatibility across operating systems (required on Windows)
    multiprocessing.set_start_method('spawn')

    # Load the trained model
    model = load_model()

    # Get the first 6 images and their predictions
    images, labels, predictions = batch_predict_images(model, trainloader, num_images=6)

    # Visualize the images and their predictions
    visualize_images(images, labels, predictions)

五、Summary

In this article we showed how to implement AlexNet in PyTorch and train it on the MNIST dataset. Along the way you have seen how to build a convolutional neural network, load a dataset, train the model, and evaluate it. The AlexNet architecture is still an important reference point in computer vision, especially for image classification.

PyTorch makes implementing and training deep learning models simple and flexible. You can modify the code in this article to try other models or datasets and deepen your understanding of deep learning.

六、References

PyTorch official documentation
The AlexNet paper: "ImageNet Classification with Deep Convolutional Neural Networks" (Krizhevsky et al., 2012)
https://gitee.com/qxdlll/alex-net
https://github.com/qxd-ljy/AlexNet

License: Apache-2.0