PyTorch: 基于【VGG16】处理MNIST数据集的图像分类任务【准确率98.9%+】

news2024/9/23 17:24:08

目录

  • 引言
  • 在Conda虚拟环境下安装pytorch
  • 步骤一:利用代码自动下载mnist数据集
  • 步骤二:搭建基于VGG16的图像分类模型
  • 步骤三:训练模型
  • 步骤四:测试模型
  • 运行结果
  • 后续模型的优化和改进建议
  • 完整代码
  • 结束语

引言

在本博客中,小编将向大家介绍如何使用VGG16处理MNIST数据集的图像分类任务。MNIST数据集是一个常用的手写数字分类数据集,包含60,000个训练样本和10,000个测试样本。我们将使用Python编程语言和PyTorch深度学习框架来实现这个任务。

在Conda虚拟环境下安装pytorch

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.3
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# CUDA 10.2
pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102
# CPU only
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu

步骤一:利用代码自动下载mnist数据集

import torchvision.datasets as datasets  
import torchvision.transforms as transforms  
  
# 定义数据预处理操作  
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])
  
# 下载并加载MNIST数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

步骤二:搭建基于VGG16的图像分类模型

class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构VGG16网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,这样可以有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

步骤三:训练模型

import torch.optim as optim  
from torch.utils.data import DataLoader  
  
# 定义超参数和训练参数  
batch_size = 64  # 批处理大小  
num_epochs = 5  # 训练轮数
learning_rate = 0.01  # 学习率
num_classes = 10  # 类别数(MNIST数据集有10个类别)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用GPU进行训练,否则使用CPU。

# 定义训练集和测试集的数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)  
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)  
  
# 初始化模型和优化器  
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)  
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数  
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)  
  
# 训练模型  
for epoch in range(num_epochs):  
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)  # 将图像数据移动到指定设备  
        labels = labels.to(device)  # 将标签数据移动到指定设备  
          
        # 前向传播  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
          
        # 反向传播和优化  
        optimizer.zero_grad()  # 清空梯度缓存  
        loss.backward()  # 计算梯度  
        optimizer.step()  # 更新权重参数  
          
        if (i+1) % 100 == 0:  # 每100个batch打印一次训练信息  
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))  
              
# 保存模型参数  
torch.save(model.state_dict(), './model.pth')

步骤四:测试模型

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

运行结果

在这里插入图片描述

后续模型的优化和改进建议

  1. 数据增强:通过旋转、缩放、平移等方式来增加训练数据,从而让模型拥有更好的泛化能力。
  2. 调整模型参数:可以尝试调整模型的参数,比如学习率、批次大小、迭代次数等,来提高模型的性能。
  3. 更换网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,来提高模型的性能。
  4. 调整优化器:本次代码采用SGD优化器,但仍可以尝试使用不同的优化器,如Adam、RMSprop等,来找到最适合我们模型的优化器。
  5. 添加正则化操作:为了防止过拟合,可以添加一些正则化项,如L1正则化、L2正则化等。
  6. 代码目前只有等训练完全结束后才能进入测试阶段,后续可以在每个epoch结束,甚至是指定的迭代次数完成后便进入测试阶段。因为训练完全结束的模型很可能已经过拟合,在测试集上不能表现较强的泛化能力。

完整代码

import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])

# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)


class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



# 定义超参数和训练参数
batch_size = 64  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.01  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重参数

        if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1313182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL数据库卸载-Windows

目录 1. 停止MySQL服务 2. 卸载MySQL相关组件 3. 删除MySQL安装目录 4. 删除MySQL数据目录 5. 再次打开服务,查看是否有MySQL卸载残留 1. 停止MySQL服务 winR 打开运行,输入 services.msc 点击 "确定" 调出系统服务。 2. 卸载MySQL相关组…

国标级联/流媒体音视频平台EasyCVR设备录像下载异常该如何解决?

视频监控TSINGSEE青犀视频平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,视频安防监控汇聚平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流,也能支…

一、win10+yolov8+anaconda环境部署

1、安装anaconda (1)打开aonconda下载地址:https://www.anaconda.com/download,点击download下载。 2、下载完成后,双击打开,点击Next,I Agree,选择just me; 3、勾选…

SQL进阶理论篇(五):什么是Hash索引

文章目录 简介MySQL中的Hash索引与B树的区别总结参考文献 简介 hash,即哈希,也被称为是散列函数。 Hash在数据库中的应用,可以帮助我们大幅度提升检索数据的效率。 大名鼎鼎的MD5其实就是Hash函数的一种变体。 Hash算法,是通过…

ArkTS编译时遇到arkts-no-obj-literals-as-types错误【Bug已解决-鸿蒙】

文章目录 项目场景:问题描述原因分析:解决方案:解决方案1解决方案2此Bug解决方案总结项目场景: 在开发鸿蒙项目过程中,遇到了arkts-no-obj-literals-as-types,总结了自己和网上人的解决方案,故写下这篇文章。 遇到问题: rkTS编译时遇到arkts-no-obj-literals-as-type…

操作系统中的作业管理

从用户的角度看,作业是系统为完成一个用户的计算任务(或一次事务处理)所做的工作总和。例如,对于用户编制的源程序,需经过对源程序的编译、连接编辑或连接装入及运行产生计算结果。这其中的每一个步骤,常称…

解锁知识的新大门:自建知识付费小程序的技术指南

在数字化时代,知识付费小程序的崛起为创作者和学习者提供了全新的学习和分享方式。本文将以“知识付费小程序源码”为关键词,从技术角度出发,为你展示如何搭建一个独具特色的知识付费平台。 步骤1:选择适用的知识付费小程序源码…

知识库SEO:提升网站内容质量与搜索引擎排名的策略

随着搜索引擎算法的不断更新和优化,单纯依靠关键词堆砌和外部链接的时代已经过去。现在的SEO(搜索引擎优化)已经转向了以提供高质量、有价值内容为核心的阶段。知识库SEO便是这个新阶段的重要策略之一。 | 一、知识库SEO的概念与意义 1.定义…

《儿童绘本》期刊杂志发表论文投稿

《儿童绘本》杂志是由国家新闻出版管理部门批准,由吉林省舆林报刊发展有限责任公司主管主办,国内外公开发行的全国优秀期刊。办刊宗旨:以“普及绘本知识、推动儿童阅读”为理念,带动家庭亲子阅读,推动阅读教育及图画书…

一文解析数据结构是如何装入 CPU 寄存器的?

我们在之前很多文章的讲解中涉及了CPU与寄存器,然后有同学问了这样一个问题:既然CPU内部的寄存器数量有限,容量有限,那么我们使用的庞大的数据结构是怎样装入寄存器供CPU计算的呢?这篇文章就为你讲解一下这个问题。 内…

交叉销售与场景业务销售运营

交叉销售 交叉销售的定义 交叉销售是一种从横向角度开发产品市场的方式,是营销人员在完成本职工作以后,主动积极的向现有客户、市场等销售其他的、额外的产品或服务。 交叉销售的类型 补充销售 搭配销售个性化推荐奖励推荐 捆绑销售 交叉销售的意义 通过增加客户的转移成本…

VMP泄露编译的一些注意事项

VMP编译教程 鉴于VMP已经在GitHub上被大佬强制开源,特此出一期编译教程。各位熟悉的可以略过,不熟悉的可以参考一下。 环境(软件) Visual Studio 2015 - 2022 (建议使用VS2019,Qt插件只有这个版本及以上…

企业信息建设现状

信息化建设是传统计算机与互联网技术高速发展并融合的产物,现阶段已经成为引领产业创新的决定性技术手段。 随着信息化的不断发展与进步,各行各业都开始了信息化的建设与应用。信息化是未来发展的大趋势,企业运用信息技术可以大幅度提高员工效…

【node】 地址标准化 解析手机号、姓名、行政区

地址标准化 解析手机号、姓名、行政区 实现效果链接源码 实现效果 将东光县科技园南路444号马晓姐13243214321 解析为 东光县科技园南路444号 13243214321 河北省;沧州市;东光县;东光镇 马晓姐 console.log(address, phone, divisions,name);链接 API概览 源码 https://gi…

NPM开发工具的简介和使用方法及代码示例

NPM(Node Package Manager)是Node.js的包管理工具,用于管理和共享被发布到模块仓库的JavaScript代码。本文将介绍NPM的定义、使用方法、代码示例以及总结。 一、NPM的定义 NPM是Node.js的默认包管理工具,它的功能包括安装、管理、…

CSS的盒子模型(重点)

网页布局的三大核心:盒子模型、浮动、定位 网页布局的过程: 1. 先准备好相关的网页元素,网页元素基本都是盒子 Box 。 2. 利用 CSS 设置好盒子样式,然后摆放到相应位置。 3. 往盒子里面装内容.网页布局的核心本质: 就…

【带头学C++】----- 九、类和对象 ---- 9.12 C++之友元函数(9.12.1---12.4)

❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️创做不易,麻烦点个关注❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️ ❤️❤️❤️❤️❤️❤️❤️❤️❤️文末有惊喜!献舞一支!❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️ 目录 9.12…

[论文阅读]Multimodal Virtual Point 3D Detection

Multimodal Virtual Point 3D Detection 多模态虚拟点3D检测 论文网址:MVP 论文代码:MVP 论文简读 方法MVP方法的核心思想是将RGB图像中的2D检测结果转换为虚拟的3D点,并将这些虚拟点与原始的Lidar点云合并。具体步骤如下: (1)…

Course3-Week2-推荐系统

Course3-Week2-推荐系统 文章目录 Course3-Week2-推荐系统1. 推荐机制的问题引入1.1 预测电影评分1.2 数学符号 2. 协同过滤算法2.1 协同过滤算法-线性回归2.2 协同过滤算法-逻辑回归2.3 均值归一化2.4 协同过滤算法的TensorFlow实现2.5 寻找相似的电影、协同过滤算法的缺点2.6…

C++之模板

目录 泛型编程 模板 函数模板 函数模板的实例化 隐式实例化 显示实例化 类模板 我们知道STL(标准模板库)是C学习的精华所在,在学习STL之前我们得先学习一个新的知识点-------模板。那么模板究竟是什么呢?围绕着这个问题&a…