VGG16模型实现MNIST图像分类

news2024/10/10 4:58:49

MNIST图像数据集

MNIST(Modified National Institute of Standards and Technology)是一个经典的机器学习数据集,常用于训练和测试图像处理和机器学习算法,特别是在数字识别领域。该数据集包含了大约 7 万张手写数字图片,其中 6 万张是用于训练,1 万张用于测试。每张图片都是 28x28 像素的灰度图像,展示了从 0 到 9 的手写数字。这些图像已经被处理过,以使得数字在图像中居中且尺寸一致。

MNIST 数据集是一个广泛被用于测试新的机器学习算法的基准,因为它相对较小,易于理解,且可以用于快速验证算法的有效性。许多人使用 MNIST 作为开始学习深度学习的入门数据集,因为它提供了一个简单但具有挑战性的任务,即将手写数字图像分类为相应的数字。

尽管 MNIST 已经存在了很长时间,但它仍然是一个重要的基准数据集,特别是对于新的机器学习研究和算法的初步测试。MINIST数据集中部分图片如下所示:

下载MNIST数据集

由于MINIST作为经典数据集,已经被内嵌在torchvision库中的dataset中了,所以直接使用代码datasets.MNIST进行下载即可。

下载后的文件格式如下图所示。

搭建VGG16图像分类模型

class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

定义VGG网络结构如上所示,在上面代码中我定义了一个基于 VGG16 架构分类器的模型。VGG16 是一种经典的卷积神经网络模型,由 16 层深度的卷积层和全连接层组成,所构建的 VGGClassifier 类的网络结构包含两个主要部分:

特征提取器(features):这部分使用了预训练的 VGG16 模型的特征提取器。通过调用 models.vgg16(pretrained=True).features 来加载 VGG16 的特征提取器部分。然后,将第一层卷积层的输入通道数从 3 修改为 1,以适应 MNIST 数据集的灰度图像格式。

分类器(classifier):这部分是自定义的分类器,用于对提取的特征进行分类。首先,通过几个全连接层将特征图展平成一维张量,然后通过一系列的线性层和激活函数对特征进行处理。具体来说,包括:一个包含 256 个神经元的全连接层,输入维度为 512x7x7(经过 VGG16 的特征提取器后的输出尺寸),使用 ReLU 激活函数。一个 Dropout 层,用于防止过拟合,随机关闭一些神经元。一个包含 256 个神经元的全连接层,使用 ReLU 激活函数。再次添加一个 Dropout 层。最后是一个包含 num_classes 个神经元的全连接层,用于输出最终的类别预测结果。

通过上述方式,整个网络结构将 VGG16 的特征提取器和自定义的分类器相结合,以适应 MNIST 数据集的图像分类任务。

构建的VGG网络结构如下图所示:

VGG网络结构图

模型训练

# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

模型参数设置如下表所示(代码见上)

模型超参数

数值

batchsize

16

num_epochs

5

learning_rate

0.001

num_classes

10

由于MINIST数据集样本数量较大,所以对于上述代码训练速度也会较慢,我考虑使用我的笔记本电脑独显进行运算,却发现电脑显存不够,于是我调小batchsize与epoch,并降低学习率learning rate才让GPU勉强能够运行上面代码,并获得到了模型model.pth,最终获得模型在测试集上面的识别精度为96.7%,精度还是比较高的。(由于笔记本电脑性能有限,在处理较大规模数据的小型项目时速度较慢,故上述代码运行了一下午左右的时间才跑完)。

模型测试

使用上面模型进行手写数字识别的检验。绘制一张图片上面含有9张子图,随机选取识别结果的9张进行展示 。识别效果以及运行结果如下图所示。

 

附录:

 VGG训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])

# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)


class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重参数

        if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2195248.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

wsl环境下安装MySQL5.7

安装操作需root权限: 1-通过 sudo su - ,切换到root用户。 2-在每一个命令前加上sudo,临时提升权限 1、下载apt仓库文件 wget https://dev.mysql.com/get/mysql-apt-config_0.8.12-1_all.deb 安装包是.deb的文件2、配置仓库,使…

MyBatis 批量插入方案

MyBatis 批量插入 MyBatis 插入数据的方法有几种: for 循环,每次都重新连接一次数据库,每次只插入一条数据。 在编写 sql 时用 for each 标签,建立一次数据库连接。 使用 MyBatis 的 batchInsert 方法。 下面是方法 1 和 2 的…

Linux防火墙-案例(一)filter表

作者介绍:简历上没有一个精通的运维工程师。希望大家多多关注作者,下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 我们经过上小章节讲了Linux的部分进阶命令,我们接下来一章节来讲讲Linux防火墙。由于目前以云服务器为主&#x…

51单片机的水位检测系统【proteus仿真+程序+报告+原理图+演示视频】

1、主要功能 该系统由AT89C51/STC89C52单片机LCD1602显示模块水位传感器继电器LED、按键和蜂鸣器等模块构成。适用于水位监测、水位控制、水位检测相似项目。 可实现功能: 1、LCD1602实时显示水位高度 2、水位传感器采集水位高度 3、按键可设置水位的下限 4、按键可手动加…

动手学大模型应用开发之大模型简介

动手学大模型应用开发之大模型简介 主要学习目标什么是大语言模型大模型的能力和特点涌现能力作为基座模型支持多元应用的能力支持对话作为统一入口的能力大模型特点 常见大模型ChatGpt通义千问 LangChainLangChain的核心模块 总结相关学习链接 主要学习目标 学习如何进行大模…

【AI知识点】激活函数(Activation Function)

激活函数(Activation Function) 是神经网络中的一个关键组件,负责将输入的线性组合转化为非线性输出。它赋予神经网络模型以复杂的表达能力,使其能够处理非线性问题,比如分类、图像识别和自然语言处理等任务。 1. 激活…

【redis-06】redis的stream流实现消息中间件

redis系列整体栏目 内容链接地址【一】redis基本数据类型和使用场景https://zhenghuisheng.blog.csdn.net/article/details/142406325【二】redis的持久化机制和原理https://zhenghuisheng.blog.csdn.net/article/details/142441756【三】redis缓存穿透、缓存击穿、缓存雪崩htt…

Spring Boot:医院管理的数字化转型

5系统详细实现 5.1 医生模块的实现 5.1.1 病床信息管理 医院管理系统的医生可以管理病床信息,可以对病床信息添加修改删除操作。具体界面的展示如图5.1所示。 图5.1 病床信息管理界面 5.1.2 药房信息管理 医生可以对药房信息进行添加,修改,…

今日指数day8实战补充用户管理模块(下)

ps : 由于前端将userId封装为BigInt类型 , 导致有精度损失, 传入的userId不正确 , 部分功能无法正确实现 , 但是代码已经完善 1.4 更新用户角色信息接口说明 1)原型效果 2)接口说明 功能描述:更新用户角色信息 服务路径:/user/…

基于FPGA的ov5640摄像头图像采集(二)

之前讲过ov5640摄像头图像采集,但是只包了的摄像头驱动与数据对齐两部分,但是由于摄像头输入的像素时钟与HDMI输出的驱动时钟并不相同,所有需要利用DDR3来将像素数据进行缓存再将像素数据从DDR3中读出,对DDR3的读写参考米联客的IP…

别再为日期时间头疼了!Python datetime模块助你高效搞定一切时间问题,让你的代码从此与时间赛跑,快人一步!

博客主页:长风清留扬-CSDN博客系列专栏:Python基础专栏每天更新大数据相关方面的技术,分享自己的实战工作经验和学习总结,尽量帮助大家解决更多问题和学习更多新知识,欢迎评论区分享自己的看法感谢大家点赞&#x1f44…

花半小时用豆包Marscode 和 Supabase免费部署了一个远程工作的导航站

以下是「 豆包MarsCode 体验官」优秀文章,作者谦哥。 🚀 项目地址:remotejobs.justidea.cn/ 🚀 项目截图: 数据处理 感谢开源项目:https://github.com/remoteintech/remote-jobs 网站信息获取&#xff1…

MyBatis 操作数据库入门

目录 前言 1.创建springboot⼯程 2.数据准备 3.配置Mybatis数据库连接信息 4.编写SQL语句,进行测试 前言 什么是MyBatis? MyBatis是⼀款优秀的 持久层 框架,⽤于简化JDBC的开发 Mybatis操作数据库的入门步骤: 1.创建springboot⼯程 2.数…

SOMEIP_ETS_171: SD_Unicast_FindService

测试目的: 验证DUT能够响应Tester发送的多个单播FindService消息,并至少回复一个单播OfferService消息。 描述 本测试用例旨在确保DUT能够正确处理单播FindService消息请求,并为请求的服务提供至少一个单播OfferService消息作为响应。 测…

SpringBootWeb快速入门!详解如何创建一个简单的SpringBoot项目?

在现代Web开发中,SpringBoot以其简化的配置和快速的开发效率而受到广大开发者的青睐。本篇文章将带领你从零开始,搭建一个基于SpringBoot的简单Web应用~ 一、前提准备 想要创建一个SpringBoot项目,需要做如下准备: idea集成开发…

亲身经历告诉你该如何自学编程

我2016年硕士毕业后,从一个纯机械学生开始转行做软件开发,其中少不了要自学编程,这其中经历的到现在看来还历历在目。 我曾经写过一些关于我转行做软件开发经历的文章,如果你感兴趣,可以点击这里的链接(我…

国庆期间的问题,如何在老家访问杭州办公室的网络呢

背景:国庆期间的问题,如何在老家访问杭州办公室的网络呢 实现方案:异地组网 实现语言:Java 环境:三个网络,一台拥有公网IP的服务器、一台杭州本地机房内服务器、你老家所在网络中的一台电脑(…

【Git】TortoiseGitPlink提示输入密码解决方法

问题 克隆仓库,TortoiseGitPlink提示输入密码 解法 1、打开TortoiseGit 下的puttygen工具 位置:C:\Program Files\TortoiseGit\bin\ 2、点击【Load】按钮,载入 C:\Users\Administrator\.ssh\ 文件夹下的id_rsa文件。 3、点击save private …

Python数据分析-远程办公与心理健康分析

一、研究背景 随着信息技术的飞速发展和全球化的推进,远程工作(Remote Work)成为越来越多企业和员工的选择。尤其是在2020年新冠疫情(COVID-19)爆发后,全球范围内的封锁措施使得远程工作模式迅速普及。根据…

Mysql数据库--JDBC编程

文章目录 1.JDBC编程基础2.驱动程序下载3.新建项目3.1导入java包3.2转换为库 4.开始创作4.1准备数据库4.2创建DataSource4.3和数据库建立连接4.4构造sql,准备发送到服务器4.5发送sql,执行sql4.6释放系统资源4.7自行输入的设置4.8插入数据完整源代码4.9查…