使用 PyTorch 实现 AlexNet 进行 MNIST 图像分类

news2024/11/15 18:00:03

AlexNet 是一种经典的深度学习模型,它在 2012 年的 ImageNet 图像分类比赛中大放异彩,彻底改变了计算机视觉领域的格局。AlexNet 的核心创新包括使用深度卷积神经网络(CNN)来处理图像,并采用了多个先进的技术如 ReLU 激活函数、Dropout 正则化等。

本文将介绍如何使用 PyTorch 框架实现 AlexNet,并在 MNIST 数据集上进行训练。MNIST 是一个简单但经典的数据集,常用于初学者测试机器学习算法。

文末附完整项目。

一、AlexNet 网络结构

AlexNet 的结构大致可以分为两部分:特征提取部分(卷积层)和分类部分(全连接层)。下面是 AlexNet 的简要结构:

  • 卷积层:五个卷积层用于提取特征。每个卷积层后面都有一个激活函数(ReLU)和一个池化层。
  • 全连接层:三个全连接层,第一个和第二个全连接层后有 Dropout 层,防止过拟合。
  • 输出层:使用 Softmax 激活函数输出 1000 个类别的概率。

二、使用 PyTorch 实现 AlexNet

训练部分

1. 导入必要的库

首先,我们需要导入一些必要的库,包括 PyTorch 和一些数据处理工具。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.cuda.amp import GradScaler, autocast  # 导入混合精度训练
from tqdm import tqdm
import multiprocessing
import matplotlib.pyplot as plt

2. 定义 AlexNet 模型

接下来,我们定义一个类 AlexNet,继承自 nn.Module,并在其中实现 AlexNet 的结构。

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # 定义卷积层和池化层
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),  # 第一个卷积层
            nn.ReLU(inplace=True),  # 激活函数
            nn.MaxPool2d(kernel_size=3, stride=2),  # 池化层
            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 第二个卷积层
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # 第三个卷积层
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 第四个卷积层
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 第五个卷积层
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)  # 池化层
        )
        # 定义全连接层
        self.classifier = nn.Sequential(
            nn.Dropout(),  # Dropout 层,防止过拟合
            nn.Linear(256 * 6 * 6, 4096),  # 第一个全连接层
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),  # 第二个全连接层
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10)  # 输出层,10 个类别
        )

    def forward(self, x):
        # 前向传播
        x = self.features(x)  # 卷积层
        x = x.view(x.size(0), -1)  # 展平数据
        x = self.classifier(x)  # 全连接层
        return x

3. 数据预处理和加载

在进行训练之前,我们需要对 MNIST 数据集进行预处理。AlexNet 要求输入的图像大小为 227x227,因此我们需要调整图像的大小。

# 使用 torchvision.transforms 对图像进行一系列的预处理操作
transform = transforms.Compose([
    transforms.Resize((227, 227)),  # 调整输入图像的大小为 227x227 (符合 AlexNet 的要求)
    transforms.ToTensor(),  # 将图像转换为 Tensor 格式
    transforms.Normalize((0.5,), (0.5,))  # 标准化操作,均值0.5,标准差0.5
])

# 下载并加载 MNIST 数据集,数据集已经被预处理
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 使用 DataLoader 加载训练集和测试集,设置 batch size 和多线程加载
trainloader = DataLoader(trainset, batch_size=128, num_workers=2, pin_memory=True)
testloader = DataLoader(testset, batch_size=128, num_workers=2, pin_memory=True)

4. 定义损失函数和优化器

使用交叉熵损失函数和 Adam 优化器来训练模型。

#创建模型实例并将其移动到 GPU 上
model = AlexNet().to(device)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 分类问题常用的损失函数
optimizer = optim.AdamW(model.parameters(), lr=0.001)  # 使用 AdamW 优化器

# 用于保存训练过程中的损失和准确率
train_losses = []
train_accuracies = []

5. 训练模型

现在我们可以开始训练模型了。我们会对训练集进行多轮训练,并每轮输出损失和准确率。

def train_model():
    epochs = 5  # 训练周期数
    accumulation_steps = 4  # 梯度累积的步骤数(当前未使用)
    scaler = GradScaler()  # 初始化混合精度训练的 GradScaler
    for epoch in range(epochs):
        model.train()  # 设置模型为训练模式
        running_loss = 0.0  # 初始化当前 epoch 的损失
        correct = 0  # 记录正确的预测个数
        total = 0  # 记录总的样本数
        print(f"Epoch [{epoch + 1}/{epochs}] started.")

        # 使用 tqdm 包装 trainloader 以显示进度条
        for i, (inputs, labels) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch + 1}/{epochs}", ncols=100), 1):
            inputs, labels = inputs.to(device), labels.to(device)  # 将数据和标签移动到 GPU 上

            optimizer.zero_grad()  # 清空优化器中的梯度信息

            with autocast():  # 启用混合精度训练
                outputs = model(inputs)  # 获取模型输出
                loss = criterion(outputs, labels)  # 计算损失

            scaler.scale(loss).backward()  # 反向传播计算梯度
            if (i + 1) % accumulation_steps == 0:  # 每 accumulation_steps 次更新一次梯度(目前无效)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()  # 累加当前 batch 的损失
            _, predicted = torch.max(outputs, 1)  # 获取模型的预测结果
            total += labels.size(0)  # 更新总样本数
            correct += (predicted == labels).sum().item()  # 更新正确的预测个数

        # 计算本轮训练的平均损失和准确率
        epoch_loss = running_loss / len(trainloader)
        epoch_accuracy = correct / total * 100
        train_losses.append(epoch_loss)  # 保存当前 epoch 的损失
        train_accuracies.append(epoch_accuracy)  # 保存当前 epoch 的准确率
        print(f"Epoch [{epoch + 1}/{epochs}] finished. Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%\n")

    # 6. 保存模型
    torch.save(model.state_dict(), 'alexnet_mnist.pth')  # 保存模型权重
    print("Model saved successfully!")

测试部分

6. 数据预处理和加载

首先,我们需要对 MNIST 数据集进行预处理,确保图像的尺寸符合 AlexNet 的输入要求。AlexNet 的标准输入尺寸为 227x227,因此我们需要调整 MNIST 图像的尺寸,并将其转换为张量格式进行处理。

# 定义对图像的转换操作:调整大小、转换为Tensor、标准化
transform = transforms.Compose([
    transforms.Resize((227, 227)),  # 调整输入图像的大小为 227x227 (符合 AlexNet 的要求)
    transforms.ToTensor(),  # 将图像转换为 Tensor 格式
    transforms.Normalize((0.5,), (0.5,))  # 标准化操作,均值0.5,标准差0.5
])

# 加载 MNIST 数据集(训练集),并应用定义的图像转换
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)  # 使用 DataLoader 批量加载数据

        数据预处理:我们使用 transforms.Compose 来组合多个数据转换操作。首先将图像调整为 227x227 像素,以符合 AlexNet 的输入要求,然后将图像转换为 Tensor 格式,并进行标准化处理。

        加载数据:通过 DataLoader 加载训练数据,设定批处理大小为 64,并禁用数据打乱(因为我们并不进行训练,仅展示前几个图像)。

7. 定义 AlexNet 模型结构

接下来,我们实现 AlexNet 的卷积层和全连接层。这里我们将使用灰度图像作为输入,因此输入通道数为 1(而非 3)。

# 这个模型是基于经典的 AlexNet 结构,只不过输入是灰度图像(1通道),而非 RGB 图像(3通道)
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # 特征提取部分(卷积层 + 激活函数 + 池化层)
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),  # 输入1通道(灰度图),输出64通道
            nn.ReLU(inplace=True),  # 激活函数 ReLU
            nn.MaxPool2d(kernel_size=3, stride=2),  # 池化层
            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 第二个卷积层
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 池化层
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # 第三个卷积层
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 第四个卷积层
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 第五个卷积层
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)  # 池化层
        )
        # 分类器部分(全连接层)
        self.classifier = nn.Sequential(
            nn.Dropout(),  # Dropout 层,防止过拟合
            nn.Linear(256 * 6 * 6, 4096),  # 全连接层1,输入大小 256*6*6,输出 4096
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),  # 全连接层2,输出 4096
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10)  # 输出层,10个类别(MNIST 0-9)
        )

    def forward(self, x):
        # 前向传播过程
        x = self.features(x)  # 通过卷积层提取特征
        x = x.view(x.size(0), -1)  # 展平数据(展平成一维)
        x = self.classifier(x)  # 通过全连接层进行分类
        return x

3. 加载预训练模型

在实际使用中,我们通常会保存训练好的模型权重并进行加载。这里,我们假设已经训练好模型并将其保存为 alexnet_mnist.pth

# 此函数用于加载保存的模型权重
def load_model(model_path='alexnet_mnist.pth'):
    model = AlexNet().to(device)  # 创建模型实例并移动到设备(GPU/CPU)
    model.load_state_dict(torch.load(model_path))  # 加载模型权重
    model.eval()  # 设置模型为评估模式
    print("Model loaded successfully!")
    return model

4. 批量进行图片预测

接下来,我们将编写一个函数,用于批量预测图像,并返回其对应的预测结果。

# 该函数从 dataloader 中获取指定数量的图像,并进行预测
def batch_predict_images(model, dataloader, num_images=6):
    predictions = []  # 用于保存预测结果
    images = []  # 用于保存输入图像
    labels = []  # 用于保存实际标签

    # 不计算梯度以提高效率
    with torch.no_grad():
        for i, (input_images, input_labels) in enumerate(dataloader):
            if i * 64 >= num_images:  # 控制处理的图像数量
                break

            input_images = input_images.to(device)  # 将图像数据转移到 GPU 上
            input_labels = input_labels.to(device)  # 将标签数据转移到 GPU 上

            # 通过模型进行预测
            outputs = model(input_images)
            _, predicted = torch.max(outputs, 1)  # 获取预测的类别

            predictions.extend(predicted.cpu().numpy())  # 保存预测结果到 CPU 上
            images.extend(input_images.cpu().numpy())  # 保存输入图像到 CPU 上
            labels.extend(input_labels.cpu().numpy())  # 保存真实标签到 CPU 上

    return images[:num_images], labels[:num_images], predictions[:num_images]

三、可视化

为了更好地了解模型的训练情况,我们可以通过绘制图表来展示训练过程中的损失和准确率。

def plot_training_progress():
    plt.figure(figsize=(12, 6))  # 创建一个宽12英寸、高6英寸的图形窗口

    # 绘制训练损失的子图
    plt.subplot(1, 2, 1)  # 1行2列的第一个子图
    plt.plot(range(1, 6), train_losses, marker='o', label='Train Loss')
    plt.title('Train Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    # 绘制训练准确率的子图
    plt.subplot(1, 2, 2)  # 1行2列的第二个子图
    plt.plot(range(1, 6), train_accuracies, marker='o', label='Train Accuracy')
    plt.title('Train Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True)

    # 调整布局,避免子图重叠
    plt.tight_layout()

    # 显示图形窗口
    plt.show()

我们编写一个函数来可视化批量图像及其对应的预测结果。我们将使用 matplotlib 来绘制图像。

# 该函数显示图像和预测结果
def visualize_images(images, labels, predictions):
    fig, axes = plt.subplots(2, 3, figsize=(10, 7))  # 创建 2x3 的图像子图
    axes = axes.ravel()  # 将子图展平成一维数组

    for i in range(6):
        image = images[i].squeeze()  # 去掉多余的维度(例如[1, 227, 227] -> [227, 227])
        ax = axes[i]
        ax.imshow(image, cmap='gray')  # 显示图像,使用灰度图
        ax.set_title(f"Pred: {predictions[i]} | Actual: {labels[i]}")  # 显示预测标签和实际标签
        ax.axis('off')  # 关闭坐标轴显示

    plt.tight_layout()  # 调整子图之间的间距
    plt.show()  # 显示图像

四、启动

#训练
if __name__ == '__main__':
    # 设置 multiprocessing 的启动方法为 'spawn'(Windows 需要)
    multiprocessing.set_start_method('spawn')

    # 开始训练模型
    train_model()

    # 绘制训练过程图
    plot_training_progress()



#测试
if __name__ == '__main__':
    # 设置 multiprocessing 的启动方法为 'spawn',用于兼容不同操作系统(Windows需要)
    multiprocessing.set_start_method('spawn')

    # 加载训练好的模型
    model = load_model()

    # 获取前6张图像及其预测结果
    images, labels, predictions = batch_predict_images(model, trainloader, num_images=6)

    # 可视化这些图像及其预测结果
    visualize_images(images, labels, predictions)

五、总结

在本文中,我们介绍了如何使用 PyTorch 实现 AlexNet 并在 MNIST 数据集上进行训练。通过这个过程,你可以了解如何构建卷积神经网络、加载数据集、训练模型并进行评估。AlexNet 的结构在计算机视觉任务中仍然具有重要意义,尤其是在图像分类任务中。

PyTorch 使得实现和训练深度学习模型变得更加简便和灵活,你可以通过对本文代码的修改来尝试不同的模型或数据集,从而加深对深度学习的理解。

五、参考资料

  • PyTorch 官方文档
  • AlexNet 论文
  • AlexNet: 使用 PyTorch 实现 AlexNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/alex-net
  • GitHub - qxd-ljy/AlexNet: 使用 PyTorch 实现 AlexNet 进行 MNIST 图像分类使用 PyTorch 实现 AlexNet 进行 MNIST 图像分类. Contribute to qxd-ljy/AlexNet development by creating an account on GitHub.icon-default.png?t=O83Ahttps://github.com/qxd-ljy/AlexNet

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2240981.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于图的去中心化社会推荐过滤器

🏡作者主页:点击! 🤖编程探索专栏:点击! ⏰️创作时间:2024年11月11日19点20分 点击开启你的论文编程之旅https://www.aspiringcode.com/content?id17176636216843&uideba758a1550b46bb…

深度学习模型评价指标介绍

模型评价指标 模型评价指标1.混淆矩阵2.Overall Accuracy3.Average accuracy4.Kappa系数5.Recall6.Precision7.F18.PR曲线9.置信度10.IOU11.AP12.mAP 模型评价指标 在我们学习机器学习以及深度学习,甚至在计算机视觉领域,我们不可避免的要利用一些指标评…

k8s 1.28.2 集群部署 docker registry 接入 MinIO 存储

文章目录 [toc]docker registry 部署生成 htpasswd 文件生成 secret 文件 生成 registry 配置文件创建 service创建 statefulset创建 ingress验证 docker registry docker registry 监控docker registry ui docker registry dockerfile docker registry 配置文件 S3 storage dr…

【自用】0-1背包问题与完全背包问题的Java实现

引言 背包问题是计算机科学领域的一个经典优化问题,分为多种类型,其中最常见的是0-1背包问题和完全背包问题。这两种问题的核心在于如何在有限的空间内最大化收益,但它们之间存在一些关键的区别:0-1背包问题允许每个物品只能选择…

Zookeeper的安装与使用

一、简介 1.1、概念 ZooKeeper 是一个开源的分布式协调服务,主要用于解决分布式系统中的数据一致性问题。它提供了一种可靠的机制来管理和协调分布式系统的各个节点。ZooKeeper 的设计目标是简化分布式应用的开发,提供简单易用的接口和高性能、高稳定性…

【模块一】kubernetes容器编排进阶实战之etcd的介绍与使用

etcd进阶 etcd简介:  etcd是CoreOS团队于2013年6月发起的开源项目,它的目标是构建一个高可用的分布式键值(key-value)数据库。etcd内部采用raft协议作为一致性算法,etcd基于Go语言实现。 官方网站:https://etcd.io/  gith…

【机器学习】如何配置anaconda环境(无脑版)

马上就要上机器学习的实验,这里想写一下我配置机器学习的anaconda环境的二三事 一、首先,下载安装包: Download Now | Anaconda 二、打开安装包,一直点NEXT进行安装 这里要记住你要下载安装的路径在哪,后续配置环境…

矩阵中的路径(dfs)-acwing

题目 23. 矩阵中的路径 - AcWing题库 代码 class Solution { public://以每一个坐标作为dfs起点bool hasPath(vector<vector<char>>& matrix, string str) {for (int i 0; i < matrix.size(); i )for (int j 0; j < matrix[i].size(); j )if (dfs(…

WEB攻防-通用漏洞SQL注入sqlmapOracleMongodbDB2等

SQL注入课程体系&#xff1a; 1、数据库注入-access mysql mssql oracle mongodb postgresql 2、数据类型注入-数字型 字符型 搜索型 加密型&#xff08;base64 json等&#xff09; 3、提交方式注入-get post cookie http头等 4、查询方式注入-查询 增加 删除 更新 堆叠等 …

android studio 更改gradle版本方法(备忘)

如果出现类似以下&#xff1a; Your build is currently configured to use Java 17.0.11 and Gradle 6.1.1. 或者类似&#xff1a; Failed to calculate the value of task ‘:app:compileDebugJavaWithJavac‘ property ‘options.generatedSo 消息时需要修改gradle版本&…

设计模式之装饰器模式(SSO单点登录功能扩展,增加拦截用户访问方法范围场景)

前言&#xff1a; 两个本想描述一样的意思的词&#xff0c;只因一字只差就让人觉得一个是好牛&#xff0c;一个好搞笑。往往我们去开发编程写代码时也经常将一些不恰当的用法用于业务需求实现中&#xff0c;但却不能意识到。一方面是由于编码不多缺少较大型项目的实践&#xff…

日志:中文 URI 参数乱码之 encodeURI、encodeURIComponent、escape 作为 Ajax 中文参数编码给 ASP 的记录

前面提到的了 ASP 输出 UTF-8 编码的中文不定时出现乱码的解决方案&#xff1a;ASP页面改为UTF-8编码后&#xff0c;刷新页面中文输入输出不定时乱码终极解决方案 今天遇到的则是输入 UTF-8 编码中文 URI 参数乱码的问题&#xff0c;第一次可以&#xff0c;刷新后取得的输入参…

Intern大模型训练营(八):Llamaindex RAG 实践

1. 基于 LlamaIndex 构建自己的 RAG 知识库 首先在Intern Studio中申请30% A100的开发机。 进入开发机后&#xff0c;创建新的conda环境&#xff0c;命名为 llamaindex&#xff0c;在命令行模式下运行&#xff1a; conda create -n llamaindex python3.10 复制完成后&#…

leetcode104:二叉树的最大深度

给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3示例 2&#xff1a; 输入&#xff1a;root [1,null,2] 输出…

Unity插件-Smart Inspector 免费的,接近虚幻引擎的蓝图Tab管理

习惯了虚幻的一张蓝图&#xff0c;关联所有Tab &#xff08;才发现Unity&#xff0c;的Component一直被人吐槽&#xff0c;但实际上是&#xff1a;本身结构Unity 的GameObject-Comp结构&#xff0c;是好的不能再好了&#xff0c;只是配上 smart Inspector就更清晰了&#xff0…

RDIFramework.NET CS敏捷开发框架 V6.1发布(.NET6+、Framework双引擎、全网唯一)

RDIFramework.NET C/S敏捷开发框架V6.1版本迎来重大更新与调整&#xff0c;全面重新设计业务逻辑代码&#xff0c;代码量减少一半以上&#xff0c;开发更加高效。全系统引入全新字体图标&#xff0c;整个界面焕然一新。底层引入最易上手的ORM框架SqlSugar&#xff0c;让开发更加…

运行springBlade项目历程

框架选择 官网地址&#xff1a;https://bladex.cn 使用手册&#xff1a;https://www.kancloud.cn/smallchill/blade 常见问题&#xff1a;https://sns.bladex.cn/article-14966.html 问答社区&#xff1a;https://sns.bladex.cn 环境配置 存在jdk8的情况下安装jdk17 jdk17gi…

图形 2.7 LDR与HDR

LDR与HDR B站视频&#xff1a;图形 2.7 LDR与HDR 文章目录 LDR与HDR基本概念LDRHDR为什么需要HDR不同显示屏的差异 Unity中的HDRCamera HDR 设置Lightmap HDR设置拾色器 HDR设置优缺点 HDR与Bloom通常Bloom渲染步骤渲染出原图获取图像中较亮部分高斯模糊叠加 Unity中Bloom渲染…

单片机设计智能翻译手势识别系统

目录 前言 一、本设计主要实现哪些很“开门”功能&#xff1f; 二、电路设计原理图 电路图采用Altium Designer进行设计&#xff1a; 三、实物设计图 四、程序源代码设计 五、获取资料内容 前言 在全球化的浪潮下&#xff0c;语言的多样性也为人们的交流带来了不小的挑战…

Python调用API翻译Excel中的英语句子并回填数据

一、问题描述 最近遇到一个把Excel表中两列单元格中的文本读取&#xff0c;然后翻译&#xff0c;再重新回填到单元格中的案例。大约有700多行&#xff0c;1400多个句子&#xff0c;一个个手动复制粘贴要花费不少时间&#xff0c;而且极易出错。这时&#xff0c;我们就可以请出…