【PyTorch】使用PyTorch创建卷积神经网络并在CIFAR-10数据集上进行分类

news2025/1/13 10:15:10

前言

在深度学习的世界中,图像分类任务是一个经典的问题,它涉及到识别给定图像中的对象类别。CIFAR-10数据集是一个常用的基准数据集,包含了10个类别的60000张32x32彩色图像。在本博客中,我们将探讨如何使用PyTorch框架创建一个简单的卷积神经网络(CNN)来对CIFAR-10数据集中的图像进行分类。

在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

引用

关于卷积神经网络的原理,感兴趣的请参阅我的另一篇博客,里面只使用numpy和基础函数组建了一个卷积神经网络模型,并完成训练和测试
【手搓深度学习算法】从头创建卷积神经网络

背景

卷积神经网络是深度学习中用于图像识别和分类的一种强大工具。它们能够自动从图像中提取特征,并通过一系列卷积层、池化层和全连接层来学习图像的复杂模式。

CIFAR-10数据集包含了飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车等10个类别的图像。每个类别有6000张图像,其中50000张用于训练,10000张用于测试。
请添加图片描述

代码解析

我们的目标是构建一个能够处理CIFAR-10数据集的CNN模型。以下是我们的模型结构和数据处理流程的简要概述:

数据预处理

我们首先定义了unpickle函数来加载CIFAR-10数据集的批次文件。read_data函数用于读取数据,将其转换为适合卷积网络输入的格式,并进行归一化处理。我们还提供了一个选项来将图像转换为灰度。

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def read_data(file_path, gray = False, percent = 0, normalize = True):
    data_src = unpickle(file_path)
    np_data = np.array(data_src["data".encode()]).astype("float32")
    np_labels = np.array(data_src["labels".encode()]).astype("float32").reshape(-1,1)
    single_data_length = 32*32 
    image_ret = None
    if (gray):
        np_data = (np_data[:, :single_data_length] + np_data[:, single_data_length:(2*single_data_length)] + np_data[:, 2*single_data_length : 3*single_data_length])/3
        image_ret = np_data.reshape(len(np_data),32,32)
    else:
        image_ret = np_data.reshape(len(np_data),32,32,3)
    
    if(normalize):
        mean = np.mean(np_data)
        std = np.std(np_data)
        np_data = (np_data - mean) / std
    
    if (percent != 0):
        np_data = np_data[:int(len(np_data)*percent)]
        np_labels = np_labels[:int(len(np_labels)*percent)]
        image_ret = image_ret[:int(len(image_ret)*percent)]
    num_classes = len(np.unique(np_labels))
    np_data, np_labels = convert_to_conv_input(np_data, np_labels)
    return np_data, np_labels, num_classes, image_ret 

网络结构

Conv类定义了我们的CNN模型,它包含一个卷积层、一个最大池化层、一个ReLU激活函数和一个全连接层。在forward方法中,我们指定了数据通过网络的流程。

class Conv(th.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(Conv, self).__init__()
        self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.pool = th.nn.MaxPool2d(kernel_size=2,stride=2)
        self.relu = th.nn.ReLU()
        self.linear = th.nn.Linear(16*15*15, 10)
        self.softmax = th.nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv(x) #32,16,30,30
        x = self.pool(x) #32,16,15,15
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
    
    # 在predict函数中,额外调用了softmax,将线性层的10个特征值转化为概率,在前向传播中不用是因为pytorch中交叉熵函数自带了softmax
    def predict(self,x):
        x = self.forward(x)
        x = self.softmax(x)
        return x
卷积层、池化层、线性层的输入特征数量的计算方法

线性层的输入特征个数取决于前面层的输出。
具体来说,线性层的输入特征个数是卷积层和池化层处理后的输出特征图的总元素数量。

卷积层定义如下:

self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

这里,in_channels=3 表示输入图像有3个颜色通道(RGB),out_channels=16 表示卷积层将输出16个特征图。

接下来是池化层:

self.pool = th.nn.MaxPool2d(kernel_size=2, stride=2)

kernel_size=2,表示池化窗口的大小是2x2。stride=2 表示池化操作的步长是2。

为了计算线性层的输入特征个数,我们需要知道卷积层和池化层之后的输出特征图的大小。这可以通过计算公式得到,或者通过在实际数据上运行网络的前向传播来确定。

计算公式如下:

对于卷积层,输出特征图的大小可以通过以下公式计算:

H_out = (H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
W_out = (W_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1

对于池化层,输出特征图的大小也可以通过类似的公式计算。

由于没有指定paddingdilation,查看函数定义可知它们的默认值分别是0和1。因此,如果输入图像的大小是32x32,卷积层之后的大小将是:

H_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30
W_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30

因此,卷积层的输出将有16个30x30的特征图。

然后,池化层将这些特征图的大小减半(因为kernel_size=2stride=2),所以输出将是16个15x15的特征图。

最后,线性层的输入特征个数将是这些特征图的总元素数量:

num_features = out_channels * H_out_pool * W_out_pool = 16 * 15 * 15 = 3600

因此,线性层的正确定义应该是:

self.linear = th.nn.Linear(3600, num_classes)

训练过程

main函数中,我们初始化了模型、损失函数和优化器。我们使用随机梯度下降(SGD)作为优化算法,并设置了学习率。接着,我们进入了训练循环,其中包括前向传播、损失计算、反向传播和权重更新。

loss_function = th.nn.CrossEntropyLoss()
optimizer = th.optim.SGD(conv_model.parameters(), lr = lr)

测试和评估

训练完成后,我们使用训练好的模型对测试数据进行评估,并计算准确率。我们还提供了一个predict方法,它在给定输入数据后返回模型的预测概率。

def predict(self,x):
        x = self.forward(x)
        x = self.softmax(x)
        return x
softmax激活函数

Softmax 激活函数是一种广泛使用的函数,它将一个实数向量转换为概率分布。在深度学习中,它常常用于多类别分类问题的输出层。

Softmax 函数的定义如下:

softmax ( z ) i = e z i ∑ j e z j \text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}} softmax(z)i=jezjezi

其中 z z z 是输入向量, z i z_i zi z z z 的第 i i i 个元素, softmax ( z ) i \text{softmax}(z)_i softmax(z)i 是输出向量的第 i i i 个元素。

Softmax 函数的主要特性是它的输出是一个概率分布,即所有输出元素的值都在 ( 0 , 1 ) (0, 1) (0,1) 区间内,且所有输出元素的值之和为 1。这使得 Softmax 函数非常适合用于表示概率。

Softmax 函数的一个重要性质是它是连续的,且其导数容易计算。这使得 Softmax 函数在深度学习中的反向传播过程中非常有用。

Softmax 函数的导数如下:

∂ ∂ z i softmax ( z ) i = softmax ( z ) i ( 1 − softmax ( z ) i ) \frac{\partial}{\partial z_i}\text{softmax}(z)_i = \text{softmax}(z)_i(1 - \text{softmax}(z)_i) zisoftmax(z)i=softmax(z)i(1softmax(z)i)

这个导数表达式表明,对于 Softmax 函数的输出 y i y_i yi,其对输入 z i z_i zi 的导数等于 y i ( 1 − y i ) y_i(1 - y_i) yi(1yi)。这个导数表达式在反向传播过程中非常有用,因为它可以直接用于计算梯度。

训练过程中没有使用softmax层,是应为torch的交叉熵损失函数已经包含了softmax的操作,如果叠加使用,可能得到错误的结果。

运行结果

作为一个简单的卷积模型,在测试集上得到了60%的准确率
请添加图片描述

完整代码

本文不提供完整代码,因为随着我的微调优化过程,已经没有这个版本的基线代码了,想要最终代码的欢迎阅读下一篇博客 “记一次卷积网络调优的过程”
在这里插入图片描述

注意点

  • 数据预处理:确保数据被正确地加载和归一化,这对模型的训练效果至关重要。
  • 模型结构:模型的层数和参数需要根据任务的复杂性来调整。过于简单的模型可能无法捕捉到数据中的复杂特征,而过于复杂的模型可能会导致过拟合。
  • 损失函数:我们使用交叉熵损失函数,它适用于多类别分类问题。
  • 优化器:在每次迭代前,记得清除累积的梯度,以避免错误的梯度更新。

可能的优化点

  • 学习率调整:可以尝试使用学习率调度器来在训练过程中调整学习率,以改善模型的收敛速度和性能。
  • 权重初始化:尝试不同的权重初始化方法,以帮助模型更快地收敛。
  • 正则化技术:使用如Dropout、L2正则化等技术来减少过拟合。
  • 数据增强:通过对训练图像进行随机变换(如旋转、缩放、裁剪等),可以增加模型的泛化能力。
  • 更深的网络:考虑增加更多的卷积层和池化层来提取更复杂的特征。
  • 批量归一化:在卷积层之后添加批量归一化层,以稳定训练过程并加速收敛。

结论

通过本博客,我们展示了如何使用PyTorch框架构建一个简单的CNN模型,并在CIFAR-10数据集上进行训练和测试。虽然我们的模型结构相对简单,但它为理解深度学习和图像分类提供了一个很好的起点。在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1414301.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TensorFlow2实战-系列教程2:神经网络分类任务

🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、Mnist数据集 下载mnist数据集: %matplotlib inline from pathlib imp…

vs2019报错MSB4019 找不到导入的项目“BuildCustomizations\CUDA 9.2.props”

在VS中执行生成,报错如下:严重性 代码 说明 项目 文件 行 禁止显示状态 错误 MSB4019 找不到导入的项目“D:\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\CUDA 9.2.props”。请确认 Import 声明“D:\Microso…

在autodl训练yolov8时卡在下载字体

1.问题 在autodl训练yolov8到这一步之后会卡住很久 2. 解决办法 Ctric中断后发现是下载Arial字体卡住了,这个字体需要从外网中下载 先手动从链接中下载https://ultralytics.com/assets/Arial.ttf ,然后上传到autodl。然后将这个文件移动到/root/.config/…

机电制造ERP软件有哪些品牌?哪家的机电制造ERP系统比较好

机电制造过程比较复杂,涵盖零配件、采购、图纸设计、工艺派工、生产计划、物料需求计划、委外加工等诸多环节。而供应链涉及供应商的选择、材料采购价格波动分析、材料交货、品质检验等过程,其中某个环节出现问题都可能会影响产品交期和经营效益。 近些…

一文速通Python添加、修改和删除字典元素

添加、修改和删除字典元素是 Python 中使用字典时常见的操作。字典是一种无序、可变的数据结构,用于存储键值对。在 Python 中,对字典元素进行添加、修改和删除操作可以帮助我们动态地管理数据,方便地根据需求对字典进行更新和维护。 一、添…

详讲api网关之kong的基本概念及安装和使用(一)

什么是api网关 前面我们聊过sentinel,用来限流熔断和降级,如果你只有一个服务,用sentinel自然没有问题,但是如果是有多个服务,特别是微服务的兴起,那么每个服务都使用sentinel就给系统维护带来麻烦。那么网…

Making Large Language Models Perform Better in Knowledge Graph Completion论文阅读

文章目录 摘要1.问题的提出引出当前研究的不足与问题KGC方法LLM幻觉现象解决方案 2.数据集和模型构建数据集模型方法基线方法任务模型方法基于LLM的KGC的知识前缀适配器知识前缀适配器 与其他结构信息引入方法对比 3.实验结果与分析结果分析:可移植性实验&#xff1…

那些年与指针的爱恨情仇(一)---- 指针本质及其相关性质用法

关注小庄 顿顿解馋 (≧∇≦) 引言: 小伙伴们在学习c语言过程中是否因为指针而困扰,指针简直就像是小说女主,它逃咱追,我们插翅难飞…本篇文章让博主为你打理打理指针这个傲娇鬼吧~ 本节我们将认识到指针本质,何为指针和…

k8s 版本发布与回滚

一、实验环境准备: kubectl get pods -o wide kubectl get nodes -o wide kubectl get svc 准备两个nginx镜像,版本号一个是V3,一个是V4 二、准备一个nginx.yaml文件 apiVersion: apps/v1 kind: Deployment metadata:name: nginx-deploylab…

解释性人工智能(XAI)—— AI 决策的透明之道

在当今数字化时代,人工智能(AI)已经成为我们生活中不可或缺的一部分。AI 系统的决策和行为对我们的生活产生了深远的影响,从医疗保健到金融服务再到自动驾驶汽车。 然而,有时候 AI 的决策似乎像黑盒子一样难以理解&am…

linux服务器ssh连接慢问题处理

一、 可能导致慢的几个原因 1、网络问题:网络延迟、带宽限制和包丢失等网络问题都有可能导致SSH连接变慢。 2、客户端设置:错误的客户端设置,如使用过高的加密算法或不适当的密钥设置,可能导致SSH连接变慢。 3、服务器负载过高…

element-ui 树形控件 实现点击某个节点获取本身节点和底下所有的子节点数据

1、需求&#xff1a;点击树形控件中的某个节点&#xff0c;需要拿到它本身和底下所有的子节点的id 1、树形控件代码 <el-tree:data"deptOptions"node-click"getVisitCheckedNodes"ref"target_tree_Speech"node-key"id":default-ex…

elasticsearch8的整体总结

es概述 elasticsearch简介 官网: https://www.elastic.co/ ElasticSearch是一个基于Lucene&#xff08;Apache开源全文检索工具包&#xff09;的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎&#xff0c;基于RESTful web接口。Elasticsearch是用Java开发的&…

MySQL:数据库索引详解

1、什么是索引&#xff1a; 索引是一种用于快速查询和检索数据的数据结构。常见的索引结构有: B 树&#xff0c; B树和 Hash。 索引的作用就相当于目录的作用。打个比方: 我们在查字典的时候&#xff0c;如果没有目录&#xff0c;那我们就只能一页一页的去找我们需要查的那个字…

基于comsol热黏性声学模块仿真声学超材料的声学特性

研究内容&#xff1a; 传统的声学吸收器被用于具有与工作波长相当的厚度的结构&#xff0c;这在低频范围的实际应用中造成了主要障碍。我们提出了一种基于超表面的完美吸收体&#xff0c;能够在极低频区域实现声波的完全吸收。具有深亚波长厚度至特征尺寸k&#xff1d;223的超…

基于Matlab/Simulink直驱式风电储能制氢仿真模型

接着还是以直驱式风电为DG中的研究对象&#xff0c;上篇博客考虑的风电并网惯性的问题&#xff0c;这边博客主要讨论功率消纳的问题。 考虑到风速是随机变化的&#xff0c;导致风电输出功率的波动性和间歇性问题突出&#xff1b;随着其应用规模的不断扩大以及风电在电网中渗透率…

【洛谷 P7072】[CSP-J2020] 直播获奖 题解(优先队列+对顶堆)

[CSP-J2020] 直播获奖 题目描述 NOI2130 即将举行。为了增加观赏性&#xff0c;CCF 决定逐一评出每个选手的成绩&#xff0c;并直播即时的获奖分数线。本次竞赛的获奖率为 w % w\% w%&#xff0c;即当前排名前 w % w\% w% 的选手的最低成绩就是即时的分数线。 更具体地&am…

Typora 无法导出 pdf 问题的解决

目录 问题描述 解决困难 解决方法 问题描述 Windows 下&#xff0c;以前&#xff08;Windows 11&#xff09; Typora 可以顺利较快地由 .md 导出 .pdf 文件&#xff0c;此功能当然非常实用与重要。 然而&#xff0c;有一次电脑因故重装了系统&#xff08;刷机&#xff09;…

【代码随想录15】110.平衡二叉树 257. 二叉树的所有路径 404.左叶子之和

目录 110. 平衡二叉树题目描述参考代码 257. 二叉树的所有路径题目描述参考代码 404.左叶子之和题目描述参考代码 110. 平衡二叉树 题目描述 给定一个二叉树&#xff0c;判断它是否是高度平衡的二叉树。 本题中&#xff0c;一棵高度平衡二叉树定义为&#xff1a; 一个二叉树…

亚马逊测评:卖家如何操作测评,安全高效(自养号测评)

亚马逊测评的作用在于让用户更真实、清晰、快捷地了解产品以及产品的使用方法和体验。通过买家对产品的测评&#xff0c;也可以帮助厂商和卖家优化产品缺陷&#xff0c;提高用户的使用体验。这进而帮助他们获得更好的销量&#xff0c;并更深入地了解市场需求。亚马逊测评在满足…