PyTorch: 基于VGG16处理MNIST数据集的图像分类任务

news2024/7/6 21:17:25

引言

在本博客中,小编将向大家介绍如何使用VGG16处理MNIST数据集的图像分类任务。MNIST数据集是一个常用的手写数字分类数据集,包含60,000个训练样本和10,000个测试样本。我们将使用Python编程语言和PyTorch深度学习框架来实现这个任务。

在Conda虚拟环境下安装pytorch

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.3
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# CUDA 10.2
pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102
# CPU only
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu

步骤一:利用代码自动下载mnist数据集

import torchvision.datasets as datasets  
import torchvision.transforms as transforms  
  
# 定义数据预处理操作  
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])
  
# 下载并加载MNIST数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

步骤二:搭建基于VGG16的图像分类模型

class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构VGG16网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,这样可以有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

步骤三:训练模型

import torch.optim as optim  
from torch.utils.data import DataLoader  
  
# 定义超参数和训练参数  
batch_size = 64  # 批处理大小  
num_epochs = 5  # 训练轮数
learning_rate = 0.01  # 学习率
num_classes = 10  # 类别数(MNIST数据集有10个类别)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用GPU进行训练,否则使用CPU。

# 定义训练集和测试集的数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)  
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)  
  
# 初始化模型和优化器  
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)  
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数  
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)  
  
# 训练模型  
for epoch in range(num_epochs):  
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)  # 将图像数据移动到指定设备  
        labels = labels.to(device)  # 将标签数据移动到指定设备  
          
        # 前向传播  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
          
        # 反向传播和优化  
        optimizer.zero_grad()  # 清空梯度缓存  
        loss.backward()  # 计算梯度  
        optimizer.step()  # 更新权重参数  
          
        if (i+1) % 100 == 0:  # 每100个batch打印一次训练信息  
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))  
              
# 保存模型参数  
torch.save(model.state_dict(), './model.pth')

步骤四:测试模型

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

运行结果

在这里插入图片描述

后续模型的优化和改进建议

  1. 数据增强:通过旋转、缩放、平移等方式来增加训练数据,从而让模型拥有更好的泛化能力。
  2. 调整模型参数:可以尝试调整模型的参数,比如学习率、批次大小、迭代次数等,来提高模型的性能。
  3. 更换网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,来提高模型的性能。
  4. 调整优化器:本次代码采用SGD优化器,但仍可以尝试使用不同的优化器,如Adam、RMSprop等,来找到最适合我们模型的优化器。
  5. 添加正则化操作:为了防止过拟合,可以添加一些正则化项,如L1正则化、L2正则化等。
  6. 代码目前只有等训练完全结束后才能进入测试阶段,后续可以在每个epoch结束,甚至是指定的迭代次数完成后便进入测试阶段。因为训练完全结束的模型很可能已经过拟合,在测试集上不能表现较强的泛化能力。

完整代码

import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])

# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)


class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



# 定义超参数和训练参数
batch_size = 64  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.01  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重参数

        if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1297584.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙开发组件之Image

Image组件加载图片方式有三种: 1、网络地址加载 直接Image(xxxx),添加上图片的网络地址就可以了。注意:真机、模拟题调试需要申请"ohos.permission.INTERNET"权限 Image(https://xxxxxxx) 2、PixelMap格式加载像素图 Image(PixelMapObjec…

根据年份和第几周来获取,那一个周的周天日期

在工作中遇到这个问题,仓库有物料录入,告诉了年份和这个年的第几周,要求把时间转换为XXXX-XX-XX的格式。日期为那个周的最后一天(周天) 在Java中想要获取特定年份和周数的周天日期,可以使用LocalDate类 pu…

【SpringBoot】响应

controller方法中的return的结果,使用ResponseBody注解(方法注解或类注解)响应给服务器。 RestController Controller ResponseBody 类上有RestController注解或ResponseBody注解时:表示当前类下所有的方法返回值做为响应数据…

四招打造完美分层自动化测试框架,让测试更高效!

写在前面 我们刚开始做自动化测试,可能写的代码都是基于原生写的代码,看起来特别不美观,而且感觉特别生硬。 来看下面一段代码: 具体表现如下: driver对象在测试类中显示 定位元素的value值在测试类中显示 定位元素…

数据表排序

指针用的有点少了&#xff0c;有点不适应 用的冒泡排序 代码如下&#xff1a; #include<stdio.h> int num[100][100]; int * p[100], jud[100]; int judge(int i, int j, int rank); int m, n, k;int main(void) {scanf("%d%d%d", &m, &n, &k);f…

字符串函数strtok

1.调用格式&#xff1a; 2.调用形式&#xff1a;char*strtok(char*p1,const char*p2),其中第二个是由分隔符组成的字符串&#xff0c;第一个为需要分隔的字符串 3.调用目的&#xff1a;将分隔符之间的字符串取出 4.调用时一般将源字符串拷贝后调用&#xff0c;因为此函数会将…

C++11原子操作atomic

文章目录 原子操作atomic原子操作的相关函数原子操作的特点“平凡的”与“合格的” 原子操作atomic 前面我们介绍了互斥锁等一系列多线程相关操作&#xff0c;这里我们来说下原子操作atomic。 可以理解为原子变量就是将上面的操作进行了整合的一个全新变量&#xff0c;但是实际…

sensitive word 敏感词(脏词) 如何忽略无意义的字符?达到更好的过滤效果?

忽略字符 说明 我们的敏感词一般都是比较连续的&#xff0c;比如 傻帽 那就有大聪明发现&#xff0c;可以在中间加一些字符&#xff0c;比如【傻!#$帽】跳过检测&#xff0c;但是骂人等攻击力不减。 那么&#xff0c;如何应对这些类似的场景呢&#xff1f; 我们可以指定特…

带大家做一个,易上手的家常可乐鸡翅

将鸡翅从冰箱中拿出 泡水解冻 这里 我用的二十个 将葱切段 切一些蒜片 有姜也可以切一些小片下来 这里 家里没姜了 六根干辣椒 一把花椒 等鸡翅化开之后 清洗干净 然后 如下图 中间位置切两刀 方便入味 起锅烧油 然后 下鸡翅 干辣椒 花椒 先翻炒一下 这里不需要放水 鸡翅会…

JavaScript基础知识整理(最全知识点, 精简版,0基础版)

文章目录 一、输入和输出内容 1.1 输出 1.1.1 在浏览器的控制台输出打印 1.1.2 直接在浏览器的页面上输出内容 1.1.3 页面弹出警告对话框 1.2 输入 二、变量 2.1 变量是什么 2.2 变量的声明和赋值 2.3 变量的命名规范和规范 三、变量扩展&#xff08;数组&#xff09; 3.1 数组…

Databend 开源周报第 122 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 支持链式函数调…

windows11 windows 11 (win11 win 11) 怎么安装 Python3 ? numpy? sounddevice? 声音信号处理库?

首先确认要安装的 sounddevice 库&#xff0c;链接&#xff1a;https://python-sounddevice.readthedocs.io/en/0.4.6/ 根据文档&#xff0c;可知最新的 sounddevice 版本是 0.4.6 进入安装页面查看&#xff0c;发现 Newest sounddevice 可以使用 pip 安装&#xff0c;如下图…

mysql中NULL值

mysql中NULL值表示“没有值”&#xff0c;它跟空字符串""是不同的 例如&#xff0c;执行下面两个插入记录的语句&#xff1a; insert into test_table (description) values (null); insert into test_table (description) values ();执行以后&#xff0c;查看表的…

Linux系统编程:高级IO总结

非阻塞IO基本概念 高级IO核心就一个概念&#xff1a;非阻塞IO。 与该概念相对的&#xff0c;就是我们之前学习过的阻塞IO。 非阻塞IO&#xff08;Non-blocking I/O&#xff09;是一种IO模型&#xff0c;用于实现异步IO操作&#xff0c;使应用程序能够在等待IO操作完成的同时…

【Python】 生成二维码

创建了一个使用 python 创建二维码的程序。 下面是生成的程序的图像。 功能描述 输入网址&#xff08;URL&#xff09;。 输入二维码的名称。 当单击 QR 码生成按钮时&#xff0c;将使用 QRname 中输入的字符将 QR 码生成为图像。 程序代码 import qrcode import tkinterd…

【C++】简单工厂模式

2023年12月6日&#xff0c;周三下午 今天又学习了一次简单工厂模式 每多学习一次&#xff0c;都会加深对设计模式的理解 目录 什么是简单工厂模式简单工厂模式的优缺点举例说明 什么是简单工厂模式 简单工厂模式&#xff08;Simple Factory Pattern&#xff09;是一种创建型…

【深度学习】一维数组的 K-Means 聚类算法理解

刚看了这个算法&#xff0c;理解如下&#xff0c;放在这里&#xff0c;备忘&#xff0c;如有错误的地方&#xff0c;请指出&#xff0c;谢谢 需要做聚类的数组我们称之为【源数组】 需要一个分组个数K变量来标记需要分多少个组&#xff0c;这个数组我们称之为【聚类中心数组】…

网络基础---网络层详解(图文清晰易懂!!!)

目录 一、网络层的功能 二、IP数据包的格式 三、ICMP协议 1.ICMP协议的概念和作用 2.ping命令 2.1 ping 的格式 2.2 ping选项 2.3 当我们ping不通时&#xff0c;及服务器出现问题&#xff0c;如何排错 2.4 信息传递时出现的问题类型和具体情况 四、冲突域和广播域 1.…

STM32--GPIO点亮LED灯(手把手,超详细)

写在前面&#xff1a;在前面的学习中&#xff0c;我们学习了STM32的编译环境&#xff08;MDK&#xff09;、时钟树以及GPIO的8种工作模式&#xff1b;这节我们学习正式入门STM32---点亮第一个LED灯&#xff1b;即利用GPIO进行电灯&#xff0c;尽管是一个十分简单的实现&#xf…

2020年第九届数学建模国际赛小美赛A题自由泳解题全过程文档及程序

2020年第九届数学建模国际赛小美赛 A题 自由泳 原题再现&#xff1a; 在所有常见的游泳泳姿中&#xff0c;哪一种最快&#xff1f;哪个冲程推力最大&#xff1f;在自由泳项目中&#xff0c;游泳者可以选择他们的泳姿&#xff0c;他们通常选择前面的爬行。然而&#xff0c;游泳…