【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)

news2025/1/20 14:51:59

文章目录

      • 0. 前言
      • 1. Cifar10数据集
        • 1.1 Cifar10数据集下载
        • 1.2 Cifar10数据集解析
      • 2. LeNet5网络
        • 2.1 LeNet5的网络结构
        • 2.2 基于PyTorch的LeNet5网络编码
      • 3. LeNet5网络训练及输出验证
        • 3.1 LeNet5网络训练
        • 3.2 LeNet5网络验证
      • 4. 完整代码
        • 4.1 训练代码
        • 4.1 验证代码

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文是基于PyTorch框架使用LeNet5网络实现图像分类的实战演练,训练的数据集采用Cifar10,旨在通过实操强化对深度学习尤其是卷积神经元网络的理解。

本文是一个完整的保姆级学习指引,只要具备最基础的深度学习知识就可以通过本文的指引:使用PyTorch库从零搭建LeNet5网络,然后对其进行训练,最后能够识别实拍图像中的实物。

1. Cifar10数据集

Cifar10数据集由计算机科学家Geoffrey Hinton的学生Alex Krizhevsky、Ilya Sutskever 在1990年代创建。Cifar10是一个包含10个类别的图像分类数据集,每个类别包含6000张32x32像素的彩色图像,总计60000张图像,其中50000个图像用于训练网络模型(训练组),10000个图像用于验证网络模型(验证组)。

其名字Cifar10代表Canadian Institute for Advanced Research(加拿大高级研究所)做的10种分类的图像集,后面的Cifar100则是100种分类的图像集。

1.1 Cifar10数据集下载

使用torchvision直接下载Cifar10:

from torchvision import datasets
from torchvision import transforms

data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=True,transform=transforms.ToTensor())   #首次下载时download设为true

datasets.CIFAR10中的参数:

  • root:下载文件的路径
  • train:如果为True,则是下载训练组数据,总计50000张图像;如果为False,则是下载验证组数据,总计10000张图像
  • download:新下载时需要设定为True,如果已经下载好数据可以设定为False
  • transform:对图像数据进行变形,这里指定为transforms.ToTensor()图像数据会被转换为Tensor,数据范围调整到0~1,省得我们再写一行归一化代码了
1.2 Cifar10数据集解析

下载之后可以看一下Cifar10数据集的具体内容:

print(type(cifar10))
print(cifar10[0])
------------------------输出------------------------------------
<class 'torchvision.datasets.cifar.CIFAR10'>
(tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],
         ...,
         [0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],
         [0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],
         [0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)

Process finished with exit code 0

可以见到Cifar10有其单独的数据类型torchvision.datasets.cifar.CIFAR10,其结构类似list。

如果输出其中某一元素,例如第一个cifar10[0],其中包含:

  • 一个维度为[3,32,32]的tensor(因为上面Transform已经指定了ToTensor),这个就是RGB三通道的图像数据
  • 一个标量数据label,这里是6,这个数据代表图像的真实分类,其对应关系如下表:
    在这里插入图片描述

这里我们也可以用matplotlib把图像的tensor数据转回图像,看看这个label为6的图像究竟是什么样的:

from torchvision import datasets
import matplotlib.pyplot as plt
from torchvision import transforms

data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=False,transform=transforms.ToTensor())   #首次下载时download设为true

# print(type(cifar10))
# print(cifar10[0])

img,label = cifar10[0]
plt.imshow(img.permute(1,2,0))
plt.show()

输出为:
在这里插入图片描述
没错,这是一个label为6的Frog,32×32像素的图像就只能做到这个程度了。

这里使用了.permute()是因为原始数据的维度是[channel3, H32, W32],而.imshow()要求的输入维度应该是[H, W, channel],需要调整下原始数据的维度顺序。

2. LeNet5网络

LeNet5是由Yann LeCun在20世纪90年代初提出,是一个经典的卷积神经网络。LeNet5由7层神经网络组成,包括2个卷积层、2个池化层和3个全连接层。其(在当时的时代背景下)创造性地使用了卷积层和池化层对输入进行特征提取,减少了参数数量,同时增强了网络对输入图像的平移和旋转不变性。

LeNet5被广泛应用于手写数字识别,也可用于其他图像分类任务。虽然现在的深度卷积神经网络比LeNet5有更好的性能,但LeNet5对于学习卷积神经网络的基本原理和方法具有重要的教育意义

2.1 LeNet5的网络结构

LeNet5的网络结构如下图:
在这里插入图片描述
LeNet5的输入为32x32的图像:

  • 第一层为一个卷积层,包含6个5x5的卷积核,输出的特征图为28x28
  • 第二层为一个2x2的最大池化层,将特征图大小缩小一半14×14
  • 第三层为另一个卷积层,包含16个5x5的卷积核,输出的特征图为10x10
  • 第四层同第二层,将特征图大小缩小一半5×5
  • 第五层为一个全连接层,含有120个神经元
  • 第六层为另一个全连接层,含有84个神经元
  • 最后一层为输出层,包含10个神经元,每个神经元对应一个label
2.2 基于PyTorch的LeNet5网络编码

根据上文LeNet5的网络结构,编写代码如下:

import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # 由于图片为RGB彩图,channel_in = 3
            #输出张量为 Batch(1)*Channel(6)*H(28)*W(28)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 输出张量为 Batch(1)*Channel(6)*H(14)*W(14)
            nn.Conv2d(in_channels=6,out_channels= 16,kernel_size= 5),
            # 输出张量为 Batch(1)*Channel(16)*H(10)*W(10)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 输出张量为 Batch(1)*Channel(16)*H(5)*W(5)
            nn.Conv2d(in_channels=16, out_channels=120,kernel_size=5),
            # 输出张量为 Batch(1)*Channel(120)*H(1)*W(1)
            nn.Flatten(),
            # 将输出一维化,用于后面的全连接网络输入
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        return self.net(x)

3. LeNet5网络训练及输出验证

3.1 LeNet5网络训练

碍于我的电脑没有GPU,使用CPU版PyTorch数据训练非常慢,我只取了Cifar10的前2000个数据进行训练 (T_T)

small_cifar10 = []
for i in range(2000):
    small_cifar10.append(cifar10[i])

训练相关设置如下:

  • 损失函数:交叉熵损失函数nn.CrossEntropyLoss()
  • 优化方式:随机梯度下降torch.optim.SGD()
  • epoch与learning rate:这是比较头疼的地方,目前我没有探索出太好的方式能在初期就把epoch和lr设定的比较好,只能进行逐步尝试。为了不浪费每次训练,我们可以把每次训练的权重保存下来,下次训练基于上次的结果进行。保存和加载权重的方式可以参考往期博客:通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()。下图展示了我的探索过程:lr的取值大约从1e-5逐步降低到2e-7,epoch总计大概有3000左右,loss值由初始的10000左右下降到100内。

这一块的训练过程忘记完整记录每一步的详细参数(epoch和lr)了,如果你有需要可以留下邮箱,我把训练好的权重发给你。读者也可以探索更好的训练参数。

在这里插入图片描述

3.2 LeNet5网络验证

激动人心的时刻来了!现在来验证我们训练好的网络能否准确识别目标图像!

我选用的图像是小鹏汽车在2023年上市的G6车型进行验证,图像如下:
在这里插入图片描述
加载我们训练好的权重文件,把图像输入到模型中:

def img_totensor(img_file):
    img = Image.open(img_file)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
    img_tensor = transform(img).unsqueeze(0)  #这里要升维,对应增加batch维度

    return img_tensor

test_model = LeNet()
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))

img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')

print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

最终输出如下:

tensor([[ 8.4051, 12.0952, -7.9274,  0.3868, -3.0866, -4.7883, -1.6089, -3.6484,
         -1.1387,  4.7348]], grad_fn=<AddmmBackward0>)
tensor([[-1.1992, 17.4531, -2.7929, -6.0410, -1.7589, -2.6942, -3.6753, -2.6800,
          3.6378,  2.4267]], grad_fn=<AddmmBackward0>)
tensor([[ 1.7580, 10.6321, -5.3922, -0.4557, -2.0147, -0.5974, -0.5785, -4.7977,
         -1.2916,  5.4786]], grad_fn=<AddmmBackward0>)
tensor([[10.5689,  6.2413, -0.9554, -4.4162,  1.0807, -7.9541, -5.3185, -6.0609,
          5.1129,  4.2243]], grad_fn=<AddmmBackward0>)

我们来解读一下这个输出:

  • 第1、2、3个图像对应输出tensor最大值在第[1]个元素(从0开始计数),即对应label值为1,真实分类为Car,预测正确。
  • 第4个图像的输出预测错误,最大值在第[0]个元素,LeNet5认为这个图像是Airplane。

这个准确率虽然不算高,但是别忘了我仅仅使用了Cifar10的前2000个数据进行训练;而且LeNet5网络输入为32×32大小的图像,例如上面的青蛙,即使让人来分辨也是挺困难的任务。

4. 完整代码

4.1 训练代码
#文件命名为 CIFAR10_main.py 后面验证时需要调用
from torchvision import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm


data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.ToTensor())   #首次下载时download设为true


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # 由于图片为RGB彩图,channel_in = 3
            #输出张量为 Batch(1)*Channel(6)*H(28)*W(28)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 输出张量为 Batch(1)*Channel(6)*H(14)*W(14)
            nn.Conv2d(in_channels=6,out_channels= 16,kernel_size= 5),
            # 输出张量为 Batch(1)*Channel(16)*H(10)*W(10)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 输出张量为 Batch(1)*Channel(16)*H(5)*W(5)
            nn.Conv2d(in_channels=16, out_channels=120,kernel_size=5),
            # 输出张量为 Batch(1)*Channel(120)*H(1)*W(1)
            nn.Flatten(),
            # 将输出一维化,用于后面的全连接网络输入
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        return self.net(x)

if __name__ == '__main__':
    model = LeNet()
    model.load_state_dict(torch.load('CIFAR10/small2000_7.pth'))

    loss = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(),lr=2e-7)


    small_cifar10 = []
    for i in range(2000):
        small_cifar10.append(cifar10[i])

    for epoch in range(1000):
        opt.zero_grad()
        total_loss = torch.tensor([0])
        for img,label in tqdm(small_cifar10):
            output = model(img.unsqueeze(0))
            label = torch.tensor([label])
            LeNet_loss = loss(output, label)
            total_loss = total_loss + LeNet_loss
            LeNet_loss.backward()
            opt.step()

        total_loss_numpy = total_loss.detach().numpy()
        plt.scatter(epoch,total_loss_numpy,c='b')
        print(total_loss)
        print("epoch=",epoch)


    torch.save(model.state_dict(),'CIFAR10/small2000_8.pth')
    plt.show()

4.1 验证代码
import torch
from torchvision import transforms
from PIL import Image
from CIFAR10_main import LeNet

def img_totensor(img_file):
    img = Image.open(img_file)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
    img_tensor = transform(img).unsqueeze(0)  #这里要升维,对应增加batch维度

    return img_tensor

test_model = LeNet()
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))

img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')

print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1051553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot 常用注解详解:全面指南

Spring Boot 中有许多常用的注解&#xff0c;这些注解用于配置、管理和定义 Spring Boot 应用程序的各个方面。以下是这些注解按大类和小类的方式分类&#xff0c;并附有解释和示例。 一、Spring Boot 核心注解 SpringBootApplication 解释&#xff1a;这是一个组合注解&a…

反射学习笔记

反射学习笔记 一、反射入门案例 在反射中&#xff0c;万物皆对象&#xff0c;方法也是对象。反射可以在不修改源码的情况下&#xff0c;只需修改配置文件&#xff0c;就能实现功能的改变。 实体类 /*** 动物猫类*/ public class Cat {private String name;public void hi()…

openGauss学习笔记-84 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT部署服务器优化:x86

文章目录 openGauss学习笔记-84 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT部署服务器优化&#xff1a;x8684.1 BIOS84.2 操作系统环境设置84.3 网络 openGauss学习笔记-84 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT部署服务器优化&#xff1a;x86 …

大学各个专业介绍

计算机类 五米高考-计算机类 注&#xff1a;此处平均薪酬为毕业五年平均薪酬&#xff0c;薪酬数据仅供参考 来源&#xff1a; 掌上高考 电气类 五米高考-电气类 机械类 五米高考-机械类 电子信息类 五米高考-电子信息类 土木类 五米高考-土木类

Cloudflare进阶技巧:缓存利用最大化

1. 引言 cloudflare我想你应该知道是什么&#xff0c;一家真正意义上免费无限量的CDN&#xff0c;至今未曾有哥们喷它的。当然&#xff0c;在国内的速度确实比较一般&#xff0c;不过这也不能怪它。 CDN最大的特色&#xff0c;我想就是它的缓存功能&#xff0c;达到防攻击&am…

【数据结构】归并排序、基数排序算法的学习知识点总结

目录 1、归并排序 1.1 算法思想 1.2 代码实现 1.3 例题分析 2、基数排序 2.1 算法思想 2.2 代码实现 2.3 例题分析 1、归并排序 1.1 算法思想 归并排序是一种采用分治思想的经典排序算法&#xff0c;通过将待排序数组分成若干个子序列&#xff0c;将每个子序列排序&#xff…

安卓玩机-----给app加注册码 app加弹窗 云注入弹窗

在对接很多工作室业务中有些客户需要在他们自带的有些app中加注册码或者验证码的需求。其实操作起来也很简单。很多反编译软件有自带的注入功能。例如注入弹窗。这个是需要对应的注册码来启动应用。而且是随机id。重新安装app后需要重新注册才可以继续使用&#xff0c;原则上可…

Ubuntu 20.04部署Promethues

sudo lsb_release -r可以看到操作系统版本是20.04&#xff0c;sudo uname -r可以看到内核版本是5.5.19。 sudo wget -c https://github.com/prometheus/prometheus/releases/download/v2.37.1/prometheus-2.37.1.linux-amd64.tar.gz下载必要的组件。 tar -zxf prometheus-2.…

python对RabbitMQ的简单使用

原文链接&#xff1a;https://blog.csdn.net/weixin_43810267/article/details/123914324 RabbitMq 是实现了高级消息队列协议&#xff08;AMQP&#xff09;的开源消息代理中间件。消息队列是一种应用程序对应用程序的通行方式&#xff0c;应用程序通过写消息&#xff0c;将消…

日常学习之:如何基于 OpenAI 构建自己的向量数据库

文章目录 原理前期准备依赖安装Pinecone 数据库注册Index 创建&#xff08;相当于传统数据库中的创建 table&#xff09; 基于 pinecone 数据库的代码实现尝试用 OpenAI 的 API 构建 embedding将示例的数据 embedding 后写入你的 pinecode &#xff08;构建向量数据库&#xff…

【三次握手、四次挥手】TCP建立连接和断开连接的过程、为什么需要三次握手,为什么需要四次挥手、TCP的可靠传输如何保证、为什么需要等待2MSL等重点知识汇总

目录 三次握手 为什么握手需要三次 四次挥手 为什么挥手需要四次 TCP的可靠传输如何保证 TIME_WAIT等待的时间是2MSL 三次握手 三次握手其实就是指建立一个TCP连接。进行三次握手的主要作用就是为了确认双方的接收能力和发送能力是否正常、指定自己的初始化序列号为后面的…

【【萌新的RISCV学习之流水线通路的控制-8】】

萌新的RISCV学习之流水线通路的控制-8 我们在之前学习了整个单周期的模块工作流程 我们按照整体的思路分段 将数据通路划分为5个阶段 IF &#xff1a; 取地址 ID &#xff1a;指令译码和读存储器堆 EX :执行或计算地址 MEM : 数据存储器访问 WB : 写回 单周期数据通路&…

Three.js加载360全景图片/视频

Three.js加载360全景图片/视频 效果 原理 将全景图片/视频作为texture引入到three.js场景中将贴图与球形网格模型融合&#xff0c;将球模型当做成环境容器使用处理视频时需要以dom为载体&#xff0c;加载与控制视频动作每次渲染时更新当前texture&#xff0c;以达到视频播放效…

强化学习到底是什么?它是怎么运维的

https://mp.weixin.qq.com/s/LL3HfU2iNlmSqaTX_3J7fQ 强化学习是一种行为学习模型,由算法提供数据分析反馈,引导用户逐步获取最佳结果。 来源丨Towards Data Science 作者丨Jair Ribeiro 编译丨科技行者 强化学习属于机器学习中的一个子集,它使代理能够理解在特定环境中…

TensorFlow入门(四、数据流向机制)

session与"图"工作过程中存在的两种数据的流向机制,即:注入机制和取回机制 注入机制(feed):即通过占位符向模式中传入数据 取回机制(fetch):指在执行计算图时&#xff0c;可以同时获取多个操作节点的计算结果 实例代码如下: import tensorflow.compat.v1 as tftf…

【Java】建筑工地智慧管理系统源码

智慧工地系统运用物联网信息技术&#xff0c;致力于推动建筑工程行业的建设发展&#xff0c;做到全自动、信息化&#xff0c;智能化的全方位智慧工地&#xff0c;实现工程施工可视化智能管理以提高工程管理信息化水平。 智慧工地平台拥有一整套完善的智慧工地解决方案&#xff…

C语言入门Day_27 开发环境

前言&#xff1a; 在线编译环境涉及到联网&#xff0c;如果在没有网的情况下&#xff0c;我们就不能写代码了&#xff0c;这一章节&#xff0c;我们将会给大家介绍一下如何搭建一个本地的C语言编译环境。 如果想要设置 C 语言环境&#xff0c;需要确保电脑上有以下两款可用的…

Hive【Hive(三)查询语句】

前言 今天是中秋节&#xff0c;早上七点就醒了&#xff0c;干啥呢&#xff0c;大一开学后空教室紧缺&#xff0c;还不趁着假期来学校等啥呢。顺便偷偷许个愿吧&#xff0c;希望在明年的这个时候&#xff0c;秋招不知道赶不赶得上&#xff0c;我希望拿几个国奖&#xff0c;蓝桥杯…

基于微信小程序的宠物寄养平台小程序设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言系统主要功能&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计…

Spark SQL案例【电商购买数据分析】

数据说明 Spark 数据分析 &#xff08;Scala&#xff09; import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.{SparkConf, SparkContext}import java.io.{File, PrintWriter}object Taobao {case class Info(u…