深度学习系列1——Pytorch 图像分类(LeNet)

news2025/1/17 18:11:13

1. 概述

本文主要是参照 B 站 UP 主 霹雳吧啦Wz 的视频学习笔记,参考的相关资料在文末参照栏给出,包括实现代码和文中用的一些图片。

整个工程已经上传个人的 github https://github.com/lovewinds13/QYQXDeepLearning ,下载即可直接测试,数据集文件因为比较大,已经删除了,按照下文教程下载即可。


2. LeNet

LeNet 可以说是第一个卷积神经网络,LeNet-5。LeNet-5 由Y. LeCun 在 1998 年发表的文章《Gradient-Based Learning Applied to Document Recognition 》中正式提出,应用于数字识别问题。LeNet 包含了卷积网络的基本组件,如下图:可以看到卷积层,池化层,全连接层。

2.1 网络结构

在这里插入图片描述
LeNet-5 由 2 个卷积层,2 个池化层(下采样层),3 个全连接层组成。

说明
输入层(INPUT)32 X 32 X 1 的图片(长、宽、色彩)
卷积层(C1)输入 32 X 32 X 1,卷积核 5 X 5 X 6,步长(stride)为 1, 输出 28 X 28 X 6 的特征图
池化层(S2)输入 28 X 28 X 6, 过滤器为 2 X 2,输出 14 X 14 X 6
卷积层(C3)输入 14 X 14 X 6,卷积核 5 X 5 X 16,步长(stride)为 1, 输出 10 X 10 X 16 的特征图
池化层(S4)输入 10 X 10 X 16, 过滤器为 2 X 2,输出 5 X 5 X 16
全连接层(C5)输入 5 X 5 X 16,卷积核 5 X 5 X 120,步长(stride)为 1,输出 1 X 1 X 120 的特征图
全连接层(F6)输入 120 个节点,输出 84 个节点
全连接层(OUTPUT)输入 84 个节点,输出 10 个节点

模型框图:
在这里插入图片描述

3. demo 实现

针对 CIFAR10 数据集,进行图像识别。

整个过程实现流程:

在这里插入图片描述

3.1 demo 结构:

demo 包含 model.py ,train.py,predict.py 三个文件。
在这里插入图片描述

3.2 model.py

"""
模型
"""


import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module): # 集成nn.Module父类
    def __init__(self):
        super(LeNet, self).__init__()

        # 看一下具体的参数
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=16,
                               kernel_size=5,
                               stride=1,
                               padding=0,
                               bias=True
                               )
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        # self.relu = nn.ReLU(inplace=True)

    # 正向传播
    def forward(self, x):
        x = F.relu(self.conv1(x))   # 输入: (3, 32, 32), 输出: (16, 28, 28)
        x = self.pool1(x)   # 输出: (16, 14, 14)
        x = F.relu(self.conv2(x))   # 输出: (32, 10, 10)
        x = self.pool2(x)   # 输出: (32, 5, 5)
        x = x.view(-1, 32*5*5)  # 输出: (32*5*5)
        x = F.relu(self.fc1(x)) # 输出: (120)
        x = F.relu(self.fc2(x)) # 输出: (84)
        x = self.fc3(x) # 输出(10)

        return x

"""
调试信息, 查看模型参数传递
"""
# import torch
# input1 = torch.rand([32, 3, 32, 32])
# modelx = LeNet()
# print(modelx)
# output = modelx(input1)

3.2.1 卷积后的图像尺寸

(1)正方形图像:输入大小 W X W,卷积核大小 F X F,步长 S,Padding 为 P,卷积之后输出大小为 N XN ,N 的计算如下:
在这里插入图片描述

 x = F.relu(self.conv1(x))   # 输入: (3, 32, 32), 输出: (16, 28, 28)

(2)矩形图像:输入大小 H X W,卷积核大小 F X F,步长 S,Padding 为 P,卷积之后输出大小计算如下:
在这里插入图片描述

3.2.2 池化后的图像尺寸

输入大小 H X W,卷积核尺寸 F X F,步长 S,池化之后输出大小计算如下:

在这里插入图片描述

x = self.pool1(x)   # 输入: (16, 28, 28), 输出: (16, 14, 14)

3.2.3 Tensor 展平

经过前面一层处理,数据输出为三维 Tensor (32, 5, 5),使用 view() 方法来展平数据。

x = x.view(-1, 32*5*5)  # 输出: (32*5*5)

3.3 train.py

3.3.1 导入包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

3.3.2 数据集预处理

    transform = transforms.Compose([
        transforms.ToTensor(),  # 数据转为张量
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化处理
    ])

将数据转成张量,并做标准化处理。

3.3.3 导入数据集

数据集包括训练集和测试集,设置 download=True,自动从 Pytorch 网站下载数据集。下图为开始下载数据集。

在这里插入图片描述

3.3.4 数据集测试

可通过下面的代码,查看数据集图片。

 # 定义的分类标签
    class_labels = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
   # 查看数据集的图片
    def img_show(img):
        img = img / 2 + 0.5
        np_img = img.numpy()
        plt.imshow(np.transpose(np_img, (1, 2, 0)))
        plt.show()

    # 查看数据集中的5张图像
    print(''.join(" %5s " % class_labels[val_label[j]] for j in range(5)))
    img_show(torchvision.utils.make_grid(val_image))

因为 Pytorch Tensor 读入数据时,维度参数顺序发生了变化。
Pytorch Tensor 对应 [深度,高度,宽度],而原始数据是[高度,宽度,深度]。故通过下面的代码调整,才能正常显示图片。

 plt.imshow(np.transpose(np_img, (1, 2, 0)))

导入数据集:

# 导入训练集数据(50000张图片)
    train_set = torchvision.datasets.CIFAR10(root='./data', # root: 数据集存储路径
                                             train=True,    # 数据集为训练集
                                             download=False,  # download: True时下载数据集(下载完成修改为False)
                                             transform=transform    # 数据预处理
                                             )
    #   加载训练集
    train_loader = torch.utils.data.DataLoader(train_set,   # 加载训练集
                                               batch_size=50,   # batch 大小
                                               shuffle=True,    # 是否随机打乱训练集
                                               num_workers=0    # 使用的线程数量
                                               )
    # 导入测试集(10000张图片)
    val_set = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,     # 数据集为测试集
                                           download=False,
                                           transform=transform
                                           )
    # 加载测试集数据
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=10000,   # 测试集batch大小
                                             shuffle=False,
                                             num_workers=0
                                             )
    # 获取测试集中的图片和标签
    val_data_iter = iter(val_loader)
    # val_image, val_label = val_data_iter.next()
    val_image, val_label = next(val_data_iter)  #python 3

3.3.4 训练过程

(1)CPU 训练代码:

 net = LeNet()   # 用于训练的网络模型
    # 指定GPU or CPU 进行训练
    net.to("cpu")
    loss_function = nn.CrossEntropyLoss()   # 损失函数(交叉熵函数)
    optimizer = optim.Adam(net.parameters(), lr=0.001)  # 优化器(训练参数, 学习率)

    # 训练的轮数
    for epoch in range(5):
        start_time = time.perf_counter()
        running_loss = 0.0
        # 遍历训练集, 从0开始
        for step, data in enumerate(train_loader, start=0):
            inputs, labels = data   # 得到训练集图片和标签
            optimizer.zero_grad()   # 清除历史梯度
            outputs = net(inputs)   # 正向传播
            loss = loss_function(outputs, labels)   # 损失计算
            loss.backward() # 反向传播
            optimizer.step()    #优化器更新参数

            # 用于打印精确率等评估参数
            running_loss += loss.item()
            if step % 500 == 499:   # 500步打印一次
                with torch.no_grad():
                    outputs = net(val_image)    # 传入测试集数据
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)

                    # 打印训练轮数、精确率等
                    print("[%d, %5d] train_loss: %.3f   test_accuracy: %.3f" %
                          (epoch + 1 , step + 1, running_loss / 500, accuracy)
                          )
                    running_loss = 0.0
        end_time = time.perf_counter()
        print("cost time = ", end_time - start_time)

    print("Finished trainning")

    save_path = "./LeNet.pth"
    torch.save(net.state_dict(), save_path) # 保存训练输出的模型文件

训练打印信息:
在这里插入图片描述(2)GPU 训练代码:
需要将训练设备指定为 GPU,且需要修改对应数据和标签。

net = LeNet()   # 用于训练的网络模型
    # 指定GPU or CPU 进行训练
    net.to("cuda")
    loss_function = nn.CrossEntropyLoss()   # 损失函数(交叉熵函数)
    optimizer = optim.Adam(net.parameters(), lr=0.001)  # 优化器(训练参数, 学习率)

    # 训练的轮数
    for epoch in range(2):
        running_loss = 0.0
        # 遍历训练集, 从0开始
        for step, data in enumerate(train_loader, start=0):
            inputs, labels = data   # 得到训练集图片和标签
            optimizer.zero_grad()   # 清除历史梯度
            outputs = net(inputs.to(device))   # 正向传播
            loss = loss_function(outputs, labels.to(device))   # 损失计算
            loss.backward() # 反向传播
            optimizer.step()    #优化器更新参数

            # 用于打印精确率等评估参数
            running_loss += loss.item()
            if step % 500 == 499:   # 500步打印一次
                with torch.no_grad():
                    outputs = net(val_image.to(device))    # 传入测试集数据
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label.to(device)).sum().item() / val_label.size(0)

                    # 打印训练轮数、精确率等
                    print("[%d, %5d] train_loss: %.3f   test_accuracy: %.3f" %
                          (epoch + 1 , step + 1, running_loss / 500, accuracy)
                          )
                    running_loss = 0.0

    print("Finished trainning")

    save_path = "./LeNet.pth"
    torch.save(net.state_dict(), save_path) # 保存训练输出的模型文件

通过对比可发现,GPU 的速度快于 CPU。

注:
本文采用 pycharm 开发,需要安装对应 CUDA,具体的版本需要查看自己电脑对应的 GPU 型号,然后下载 CUDA 安装。本文的信息如下:

在这里插入图片描述

3.4. predict.py

此文件为模型测试代码。

""""
测试
"""
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet


def main():
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    data_class = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('LeNet.pth'))
    # net.load_state_dict(torch.load('LeNet.pth', map_location=torch.device("cpu")))

    test_image = Image.open('cat_test2.jpg')
    test_image = transform(test_image)  # [C H W]
    test_image = torch.unsqueeze(test_image, dim=0)  # [N C H W]

    with torch.no_grad():
        outputs = net(test_image)
        predict = torch.max(outputs, dim=1)[1].numpy()
    print(f"It is {data_class[int(predict)]}")

测试图片为:cat_test2.jpg
在这里插入图片描述
测试结果:
在这里插入图片描述


欢迎关注【千艺千寻】,共同成长
在这里插入图片描述


参考:

  1. pytorch图像分类篇:2.pytorch官方demo实现一个分类器(LeNet)
  2. B站——2.1 pytorch官方demo(Lenet)
  3. Pytorch中nn.Conv1d、Conv2D与BatchNorm1d、BatchNorm2d函数
  4. pytorch官方demo实现图像分类(LeNet)
  5. UP主代码——Test1_official_demo
  6. Pytorch中文
  7. pytorch中的卷积操作详解
  8. LeNet 论文地址
  9. LeNet:第一个卷积神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/12194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VLAN trunk扩展 GVRP

目录 一、GVRP产生背景 VTP协议 GARP GVRP 二、GVRP的实现和基本概念 GVRP的应用 GVRP的单向注册 GVRP单向注销 GVRP的注册模式 VTP和GVRP的使用风险 一、GVRP产生背景 VTP协议 如何解决园区网中大批量的VLAN的配置问题? 早期可以使用excel表格配置VL…

小程序开发 - 基本组件

目录 小程序启动过程 页面渲染过程 新建文件夹 组件 view scroll-view swiper和swiper-item button image 小程序启动过程 将小程序代码包下载本地解析app.json全局配置文件执行app.js小程序入口文件,调用App()创建渲染小程序首页小程序启动完成 页面渲染过…

【MQTT基础篇(一)】MQTT介绍

文章目录MQTT介绍1 MQTT历史2 MQTT版本MQTT介绍 MQTT是一个客户端服务端架构的发布/订阅模式的消息传输协议。它的设计思想是轻巧、开放、简单、规范,易于实现。这些特点使得它对很多场景来说都是很好的选择,特别是对于受限的环境如机器与机器的通信&…

第七章:项目成本管理

一、规划成本管理 确定如何估算、预算、管理、监督和控制项目成本的过程。主要作用是在整个项目期间为如何管理项目成本提供指南和反向。 输入工具与技术输出 1.项目章程 2.项目管理文件 进度管理计划风险管理计划3.事业环境因素 4.组织过程资产 1.专家判断 2.数据分析 3.会…

SpringMVC学习篇(九)

SpringMVC拦截器例子 1 界面登录验证 1.1 准备工作 1.1.1 导入servlet-api依赖 <dependency><groupId>org.apache.tomcat</groupId><artifactId>servlet-api</artifactId><version>6.0.53</version> </dependency>1.1.2 创…

基于机器视觉的移动消防机器人(三)--软件设计

本文素材来源于北方民族大学 机电工程学院 作者&#xff1a;牟义达、黄瑞翔、李涛 指导老师&#xff1a;田国禾、张春涛 1. 总系统软件流程图 为了实现消防功能&#xff0c;对软件进行系统设计。根据机器人要实现的功能进行逐一设计&#xff0c;设计完之后再将其整合到一起&a…

流体力学基础——粘性

1、粘性&#xff1a;流体的属性 粘性就是流体阻碍自身流动的特性&#xff1b; 专业定义&#xff1a;粘性是流体持续剪切变形时内部产生剪切力的性质&#xff1b; 流体内部的粘性力&#xff0c;类似于固体的摩擦力&#xff0c;但是只有动粘性力&#xff1b; 表面张力不是粘性…

远程桌面一直被人爆破的解决思路

目录前言初步解决方法题外话预防措施获取日志Get-EventLog例子防火墙操作编写软件自动提取IP和添加黑名单调用powershell命令调用cmd命令前言 某天远程自己的电脑发现登不上了&#xff0c;错误信息如下&#xff1a; 开始也没在意&#xff0c;后面出现了好几次才反应过来。查看…

程序员职场生态:近8成本科毕业生起薪过万,跳槽首选智能汽车行业

中国互联网行业经历了超过20年的高速发展&#xff0c;逐渐融入到各行各业&#xff0c;程序员在其中发挥着举足轻重的作用&#xff0c;从业人员数量与日俱增。GitHub数据显示&#xff0c;2021年中国开发者规模达到755万。 近日&#xff0c;拉勾招聘数据研究院对程序员群体开展深…

SpringBoot笔记

文章目录1️⃣ 简介一. 什么是 IoC 容器&#xff1f;二. AOP面向切面编程三. SSM整合四. HttpServletRequest五. HttpServletResponse六. Cookie 与 Session七. Cookie八. Session九. 转发与重定向十. Spring项目转SpringBoot十一. Spring生命周期十二. 什么是 pom十三. 为什么…

知识整理说明:1799962-26-7,(4E)-TCO-NH2,(4E)-反式环辛烯-氨基

(4E)-TCO-amine物理数据&#xff1a; CAS&#xff1a;1799962-26-7| 中文名&#xff1a;(4E)-反式环辛烯-氨基 | 英文名&#xff1a;(4E)-TCO-amine&#xff0c;(4E)-TCO-NH2 结构式&#xff1a; 英文别名&#xff1a; (4E)-TCO-amine 中文别名&#xff1a; (4E)-反式环辛烯…

耗时半月,把牛客网最火Java面试题总结成PDF,涵盖所有面试高频题

最近感慨面试难的人越来越多了&#xff0c;一方面是市场环境&#xff0c;更重要的一方面是企业对Java的人才要求越来越高了。 基本上这样感慨的分为两类人&#xff0c;第一&#xff0c;虽然挂着3、5年经验&#xff0c;但肚子里货少&#xff0c;也没啥拿得出手的项目&#xff0c…

Vite 入门篇:学会它,一起提升开发幸福感。

相信大部分兄弟都体验过 Vite 了&#xff0c;知道它很快。但你知道它为什么快&#xff0c;相比 Webpack 有哪些不同吗&#xff1f;今天咱们就来全面了解一下 Vite &#xff0c;尤其适合新手兄弟。话不多说&#xff0c;开整&#xff01; 什么是构建工具 很多人对构建工具没有什…

RE转NFA转DFA

https://github.com/Nightmare4214/re_nfa_dfa 前置知识 ϵ\epsilonϵ代表空串 语言 某个给定字母表上一个任意的可数的串集合 正则语言/正则表达式 正则语言&#xff08;regular language&#xff09;/正则表达式&#xff08;regular expression&#xff09; 每个正则表达…

2022 SpeechHome 语音技术研讨会-回顾

2022年11月13日&#xff0c;第二届SpeechHome语音技术研讨会和第七届Kaldi技术交流会圆满落幕。本届SpeechHome语音技术研讨会由中国计算机学会、深圳市人工智能学会、小米集团、腾讯天籁实验室、语音之家主办&#xff0c;CCF语音对话与听觉专委会作为指导单位&#xff0c;由内…

【Java开发】 Spring 03:云服务器 Docker 环境下安装 MongoDB 并连接 Spring 项目实现简单 CRUD

接下来介绍一下 NoSQL &#xff0c;相比于 Mysql 等关系型的数据库&#xff0c;NoSQL &#xff08;文档型数据库&#xff09;由于存储的数据之间无关系&#xff0c;因此具备大数据量&#xff0c;高性能等特点&#xff0c;用于解决大规模数据集合多重数据种类带来的挑战&#xf…

点击化学试剂1609736-43-7,TCO-NH2 hydrochloride,反式环辛烯-氨基HCL盐

TCO-amine hydrochloride物理数据&#xff1a; CAS&#xff1a;1609736-43-7| 中文名&#xff1a; 反式环辛烯-氨基盐酸盐&#xff0c;反式环辛烯-氨基HCL盐 | 英文名&#xff1a;TCO-amine hydrochloride 结构式&#xff1a; 中文别名&#xff1a; 环辛-4-烯-1-基 (3-氨基丙…

Mvvm中的Lifecycle

lifecycle&#xff1a;一个持有activity/fragment生命周期信息的类&#xff0c;允许其他对象观察此对状态 Event:从框架和lifecycle类派发的生命周期事件&#xff0c;也就是activity和fragment的各个状态会发Event state:这个就好理解了&#xff0c;就是activity和fragment当…

工业互联与MQTT

、工业互联网 新一代信息通信技术与工业经济深度融合的新型基础设施、应用模式和工业生态&#xff0c;通过对人、机、物、系统等的全面连接&#xff0c;构建起覆盖全产业链、全价值链的全新制造和服务体系&#xff0c;为工业乃至产业数字化、网络化、智能化发展提供了实现途径&…

论文阅读-Dr.Deep_基于医疗特征上下文学习的患者健康状态可解释评估

论文地址&#xff1a;Dr.Deep&#xff1a;基于医疗特征上下文学习的患者健康状态可解释评估 (ict.ac.cn) 代码地址&#xff1a;GitHub - Accountable-Machine-Intelligence/Dr.Deep 简介&#xff1a; 深度学习是当前医疗多变量时序数据分析的主流方法。临床辅助决策关乎病人生…