CIFAR-10 数据集图像分类与可视化

news2024/9/21 22:55:15

数据准备

CIFAR-10 and CIFAR-100 datasets (toronto.edu)在上述网站中下载Python版本的CIFAR-10数据集。

下载后的压缩包解压后会得到几个文件如下:

对应的data_batch_1 ~ data_batch_5 是划分好的训练数据,每个文件里包含10000张图片,test_batch 是测试集数据,也包含10000张图片。他们的结构是一样的,需要分别对这些data_bach进行处理。

查阅相关文献可知,对应的data_batch都是使用的pickle库进行处理获得的。所以在处理该文件时,也需要使用pickle库进行读取。

编写一段代码脚本,将原来文件拆解成图片,并将训练集图片与测试集图片分别保存在train和test文件夹中可以得到如下图所示结果。

如上图所示,可知对应的训练集数据为5万张,测试集数据为1万张。

对应代码运行结果如下图所示

TIP:其他可选方案,其实torchvision库中的CIFAR库是可以直接加载的。使用代码torchvision.datasets.CIFAR10就可以直接调用库中的数据集。在此,直接下载完全部图片后再进行处理,会更加方便。

torchvision.datasets.CIFAR10用于加载 CIFAR-10 数据集。参数包括:

root:数据集存放的根目录。

train:True 表示加载训练集,False 表示加载测试集。

download:是否下载数据集,如果设置为 True,数据集将会被自动下载到 root 目录下。

transform:用于对数据进行转换的操作。

对上述数据集中数据进行归一化、图像增强等操作。

import torchvision.transforms as transforms

# 定义图像预处理操作
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 随机裁剪,数据增强
    transforms.RandomHorizontalFlip(),  # 随机水平翻转,数据增强
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

使用随机裁剪、水平翻转技术进行数据增强操作,提高后续模型的特征提取能力。

模型构建

使用 PyTorch 构建卷积神经网络模型。设计合适的网络结构,包括卷积层、池化层、全连接层等。搭建的卷积神经网络结构图如下所示

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten


class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.model1 = Sequential(  # 效果同上
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


if __name__ == '__main__':
    # 验证网络正确性
    model = Module()
    input = torch.ones((64, 3, 32, 32))
    output = model(input)
    print(output.shape)  # torch.Size([64, 10])

该卷积神经网络包含了一个卷积层 (Conv2d),输入通道数为3,输出通道数为32,卷积核大小为5x5,使用零填充(padding=2)。一个最大池化层 (MaxPool2d),池化窗口大小为2x2。另一个卷积层,输入通道数为32,输出通道数为32,卷积核大小为5x5,同样使用零填充。另一个最大池化层,池化窗口大小为2x2。还有一个卷积层,输入通道数为32,输出通道数为64,卷积核大小为5x5,零填充。再接一个最大池化层,池化窗口大小为2x2。

然后是将特征展平的层 (Flatten),用于将卷积层输出的特征张量展平成一维向量。

接着是一个全连接层 (Linear),输入大小为1024,输出大小为64。

最后是另一个全连接层,输入大小为64,输出大小为10。这里的10代表着输出类别的数量。

后面函数解释:

def forward(self, x)是模型的前向传播函数,定义了数据从输入到输出的流程。

x = self.model1(x):这里将输入数据 x 输入到 model1 中,进行前向传播计算。

return x:返回模型的输出结果。

if __name__ == '__main__'::这是Python中的常用写法,用于判断当前脚本是否作为主程序执行。

model = Module():创建了一个模型对象 model,实例化了前面定义的 Module 类。

input = torch.ones((64, 3, 32, 32)):创建了一个大小为64x3x32x32的张量作为输入数据,表示64个样本,每个样本的图像大小为32x32,通道数为3(假设是RGB图像)。

output = model(input):将输入数据输入到模型中进行前向传播计算,得到输出结果。

print(output.shape):打印输出结果的形状,这里输出的形状为 torch.Size([64, 10]),表示有64个样本,每个样本对应一个长度为10的输出向量,其中每个元素表示对应类别的预测分数或概率。

对应构建的卷积神经网络结构图如下图所示:

模型训练

定义损失函数和优化器。将数据集分为训练集和验证集。在训练集上训练模型,通过验证集调整模型参数,避免过拟合。

# 6损失函数
loss_fn = nn.CrossEntropyLoss()

# 7优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 8设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 10  # 训练的轮数

损失函数、优化器如上所示,损失函数使用交叉熵损失函数,优化器中学习率learning rate为0.01,优化器使用SGD优化器。

模型评估

使用测试集评估模型性能,计算准确率等指标。

随着训练次数增加,模型在测试集上面的整体损失LOSS一直在下降,正确率一直在提升。训练准确率在第34轮训练时到达66.7%

可视化展示

通过表格展示准确率等实验结果。绘制准确率和损失函数随训练轮次变化的曲线图。随机选取部分图像,展示模型的预测结果和真实标签。

此处的可视化使用了tensorboard展示板结合日志文件进行展示

tensorboard --logdir=logs

logs代表着日志文件对应的文件夹所在位置

使用上面代码进行读取代码运行产生的日志文件。

TIP:日志文件所在的文件夹路径中不能存在中文路径,否则会报错。

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
import numpy as np

# 加载测试数据集
test_data = CIFAR10(root="data", train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)

# 定义类别名称
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


# 加载模型
model = torch.load(r"C:\Users\Lenovo\Desktop\计算机视觉实验\实验2\CIFAR-10\200轮训练权重\model_34.pth")  # 假设模型保存在 model.pth 中


# 设置模型为评估模式
model.eval()

# 从测试数据集中随机选择一批图像和标签
images, labels = next(iter(test_loader))

# 对图像进行预测
with torch.no_grad():
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

# 将图像、预测结果和真实标签组合在一起并展示
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    image = images[i].permute(1, 2, 0)  # 将图像从 (C, H, W) 转换为 (H, W, C)
    label = labels[i]
    prediction = predicted[i]

    ax.imshow(image)
    ax.axis('off')
    ax.set_title(f'Predicted: {classes[prediction]}, Actual: {classes[label]}',fontsize=10)

plt.show()

识别效果如上图所示.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1972544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot + Vue的前后端分离项目-外包平台

项目名称:外包平台 作者的B站地址:程序员云翼的个人空间-程序员云翼个人主页-哔哩哔哩视频 csdn地址:程序员云翼-CSDN博客 1.项目技术栈: 前后端分离的项目 后端:Springboot MybatisPlus 前端:Vue …

达梦数据库安装(DM8)新版 windows11下安装及超详细使用教程

windows11下达梦数据库安装 1、安装参考链接2、存在问题2.1新建表空间失败,详情错误号: -70142.2创建表、视图等 1、安装参考链接 https://blog.csdn.net/u014096024/article/details/134722013 2、存在问题 2.1新建表空间失败,详情错误号: -7014 解决…

掌握 LINQ:通过示例解释 C# 中强大的 LINQ的集运算

文章目录 集运算符原理实战示例1. Union2. Intersect3. Except4. ExceptWith5. Concat6. Distinct 注意事项总结 在C#中,LINQ(Language Integrated Query)提供了丰富的集合操作功能,使得对集合数据进行查询、过滤、排序等操作变得…

从程序员视角浅入浅出了解计算机硬件——内存

前言 内存(Memory)是计算机的重要部件,用于存储数据和指令的重要组件,是冯诺依曼计算机中是的存储器部分。作为与CPU进行沟通的桥梁,内存用于临时存储计CPU中的运算数据,以及与硬盘、网卡等外部组件数据,以便CPU能够快…

STM32卡死、跑飞如何调试确定问题

目录 前言 一、程序跑飞原因 二、调试工具 2.1Registers工具 2.2 Memory工具 2.3 Disassembly工具 2.4 Call Stack工具 三、找到程序跑飞位置 方式一、 方式二、 前言 我们初学STM32的时候代码难免会出现疏忽,导致程序跑飞,不再正常运行&#…

电脑桌面便签软件哪个好,桌面便签如何显示在桌面?

在繁忙的工作日里,一款优秀的电脑桌面便签软件就像是一位贴心的小秘书,帮助你记录重要事项,提醒你不要错过任何细节。那么,哪个电脑桌面便签软件可以帮助我们更好地记录和管理日常工作和学习中的事项呢?又如何将桌面便…

16.搜索框滑块和简单验证

一、一些简单的验证 邮箱验证 <!-- 邮件验证 --><p>邮箱&#xff1a;<input type"email" name"email"></p>邮箱验证框的type是email&#xff0c;在框内&#xff0c;它会自动检测输入内容的格式 &#xff0c;若格式非邮箱格式&…

从分散到整合,细说比特币发展史

原文标题&#xff1a;《Layered Bitcoin》 撰文&#xff1a;Saurabh Deshpande 编译&#xff1a;Chris&#xff0c;Techub News 古往今来&#xff0c;货币在社会中都具有三个关键的功能&#xff1a;财富的储存手段、交换媒介和计量单位。虽然货币的形式在不断变化&#xff0c…

一文了解一下 MindSpeed,MindSpeed 是专为华为昇腾设备设计的大模型分布式加速套件。

https://gitee.com/ascend/MindSpeed Gitee Ascend/MindSpeed 项目&#xff0c;MindSpeed 是针对华为昇腾设备的大模型加速库。 MindSpeed 是专为华为昇腾设备设计的大模型加速库&#xff0c;旨在解决用户在大模型训练过程中遇到的显存资源不足等挑战。该库借鉴了 Megatron、D…

如何理解分布式光纤测温DTS的“实时在线监测”的概念?

实时在线监测是相对于非实时在线监测而言的一种高要求的监测方式。在非实时监测中&#xff0c;我们可以使用手持红外测温仪等设备&#xff0c;在需要时进行开机测量&#xff0c;而在不需要时则可以关机。然而&#xff0c;实时在线监测的目标是要求连续、全天候、每秒都不间断地…

检索增强生成RAG系列10--RAG的实际案例

讲了很多理论&#xff0c;最后来一篇实践作为结尾。本次案例根据阿里云的博金大模型挑战赛的题目以及数据集做一次实践。 完整代码地址&#xff1a;https://github.com/forever1986/finrag.git 本次实践代码有参考&#xff1a;https://github.com/Tongyi-EconML/FinQwen/ 目录 …

我的cesium for UE 踩坑之旅(一)

我的小小历程 创建过程场景搭建引入cesium for UE插件创建空白关卡并添加SunSky照明和FloatingPawn进行场景设置设置cesium token重设场景初始点位置顶层菜单窗口 —>打开cesium ion Assets &#xff0c;从而加入自己的资产 UI制作前端UI页面制作顶部菜单打开内容浏览器窗口…

第100+19步 ChatGPT学习:R实现朴素贝叶斯分类

基于R 4.2.2版本演示 一、写在前面 有不少大佬问做机器学习分类能不能用R语言&#xff0c;不想学Python咯。 答曰&#xff1a;可&#xff01;用GPT或者Kimi转一下就得了呗。 加上最近也没啥内容写了&#xff0c;就帮各位搬运一下吧。 二、R代码实现朴素贝叶斯分类 &#xf…

人工智能ai聊天都有哪些?分享4款智能软件!

在这个科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经悄然渗透到我们生活的方方面面&#xff0c;其中最令人兴奋的莫过于那些能够与人类进行流畅对话的AI聊天软件。它们不仅让交流跨越了物种的界限&#xff0c;更在娱乐、教育、客服等多个领域展现出无…

苹果电脑可以玩什么小游戏 适合Mac电脑玩的休闲游戏推荐

对于游戏爱好者而言&#xff0c;Mac似乎并不是游戏体验的首选平台。这主要是因为相较于Windows系统&#xff0c;Mac上的游戏资源显得相对有限。不过&#xff0c;这并不意味着Mac用户就与游戏世界绝缘。实际上&#xff0c;Mac平台上有着一系列小巧精致且趣味横生的小游戏&#x…

苍穹外卖项目day12(day09)---- 查询历史订单、查询订单详情、取消订单、再来一单(用户端)

目录 用户端历史订单模块&#xff1a; - 查询历史订单 产品原型 业务规则 接口设计 user/OrderController OrderService OrderServiceImpl OrderMapper OrderMapper.xml OrderDetailMapper 功能测试&#xff1a; - 查询订单详情 产品原型 接口设计 OrderControll…

安卓常用控件(下)

ImageView ImageView是用于在界面上展示图片的一个控件&#xff0c;它可以让我们的程序界面变得更加丰富多彩。 属性名描述id给当前控件定义一个唯一的标识符。layout_width给控件指定一个宽度。match_parent&#xff1a;控件大小与父布局一样&#xff1b;wrap_content&#x…

【两数相加】python刷题记录

R3-链表-链表高精度加法 目录 递归法 迭代 递归法 l1.vall2.valcarry&#xff0c;得到的和&#xff0c;%10为当前位存储的值&#xff0c;除以10为当前的进位值 # Definition for singly-linked list. # class ListNode: # def __init__(self, val0, nextNone): # …

SpringMVC和Spring

1.AOP 1.基础内容 AOP是面向切面的的编程&#xff0c;AOP 是一种编程思想&#xff0c;是面向对象编程&#xff08;OOP&#xff09;的一种补充。 面向切面编程&#xff0c;实现在不修改源代码的情况下给程序动态统一添加额外功能的一种技术&#xff08;增强代码&#xff09;&…

【开源项目】基于RTP协议的H264播放器

基于RTP协议的H264播放器 1. 概述2.工程3.测试4.小结 1. 概述 前面记录了一篇基于RTP协议的H264的推流器、接收器的实现过程&#xff0c;但是没有加上解码播放&#xff0c;这里记录一下如何实现解码和播放&#xff0c;也是在前面的基础之上实现的。前一篇的记录为【开源项目】…