神经网络识别数字图像案例

news2024/9/9 5:15:53

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn

# 定义神经网络Network
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # 线性层1,输入层和隐藏层之间的线性层
        self.layer1 = nn.Linear(784, 258)
        # 线性层2,隐藏层和输出层之间的线性层
        self.layer2 = nn.Linear(256, 10)
    # 在前向传播,forward函数中,输入为图像x
    def forward(self, x):
        x = x.view(-1, 28 * 28) # 使用view函数,将x展平
        x = self.layer1(x) # 将x输入到layer1
        x = torch.relu(x) # 使用relu激活
        return self.layer2(x) # 输入至layer2计算结果

    # 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数
    # 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'
    # 实现图像的预处理pipeline
    transform = trnasforms.Compose([
        # 转换成单通道灰度图
        transforms.Grayscale(num_output_channels=1),
        # 转换为张量
        transforms.ToTensor()
    ])

    # 使用ImageFolder函数,读取数据文件夹,构建数据集dataset
    # 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据
    train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)
    test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)

    # 打印他们的长度
    print("train_dataset length: ", len(train_dataset))
    print("test_dataset length: ", len(test_dataset))

    # 使用train_loader, 实现小批量的数据读取
    # 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据
    train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)
    # 打印train_loader的长度
    print("train_loader length: ", len(train_loader))
    # 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组
    # 938*64=60032,说明最后一组不够64个数据

    # 循环遍历train_loader
    # 每一次循环,都会取出64个图像数据,作为一个小批量batch
    for batch_idx, (data, label) in enumerate(train_loader)
        if batch_idx == 3:
            break
        print("batch_idx: ", batch_idx)
        print("data.shape: ", data.shape) # 数据的尺寸
        print("label: ", label.shape) # 图像中的数字
        print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

if __name__ == '__main__'
    # 图像的预处理
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])

    # 读入并构造数据集
    train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)
    print("train_dataset length: ", len(train_dataset))

    # 小批量的数据读入
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    print("train_loader length: ", len(train_loader))

    # 在使用Pytorch训练模型时,需要创建三个对象:
    model = Network() # 1.模型本身,就是我们设计的神经网络
    optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数
    criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差

    # 进入模型的循环迭代
    # 外层循环,代表了整个训练数据集的遍历次数
    for epoch in range(10):
        # 内层循环使用train_loader, 进行小批量的数据读取
        for batch_idx, (data, label) in enumerate(train_loader):
            # 内层每循环一次,就会进行一次梯度下降算法
            # 包括了5个步骤
            # 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可
            # 1. 计算神经网络的前向传播结果
            output = model(data)
            # 2. 计算output和标签label之间的损失loss
            loss = criterion(output, label)
            # 3. 使用backward计算梯度
            loss.backward()
            # 4. 使用optimizer.step更新参数
            optimizer.step()
            # 5.将梯度清零
            optimizer.zero_grad()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch + 1}/10"
                      f"| Batch {batch_idx}/{len(train_loader)}"
                      f"| Loss: {loss.item():.4f}"
                      )
    torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torch

if __name__ == '__main__'
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])
    # 读取测试数据集
    test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)
    print("test_dataset length: ", len(test_dataset))

    model = Network() # 定义神经网络模型
    model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件

    rigth = 0 # 保存正确识别的数量
    for i, (x, y) in enumerate(test_dataset):
        output = model(x) # 将其中的数据x输入到模型
        predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果
        # 对比预测值predict和真实标签y
        if predict == y:
            right += 1
        else:
            # 将识别错误的样例打印出来
            img_path = test_dataset.samples[i][0]
            print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")
    # 计算出测试效果
    sample_num = len(test_dataset)
    acc = right * 1.0 / sample_num
    print("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1918516.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt常用基础控件总结—带边框的部件(QFrame和QLabel)

带边框的部件 框架控件QFrame类 QFrame类介绍 QFrame 类是带有边框的部件的基类,带边框部件的特点是有一个明显的边框,QFrame类就是用来实现边框的不同效果的(把这种效果称为边框样式),所有继承自 QFrame 的子类都可以使用 QFrame 类实现的效果。 部件通常是矩形的(其他…

图纸文档管理新篇章:陕西航沣与三品软件合作 优化研发流程

近日,陕西航沣新材料有限公司与三品软件正式达成合作协议,共同打造高效、智能的图纸文档管理平台。此次合作旨在赋能陕西航沣在高性能碳纤维增强纸基摩擦材料领域的创新与发展,提升企业的核心竞争力。 客户简介 陕西航沣新材料有限公司&…

脚本批量修改文件名 格式xx.bat

批量修改文件名适用于windows系统 分为4步 1.新建一个 批量修改文件名.txt文件 2.复制下面代码,保存 echo off chcp 65001 >nul set a0 setlocal EnableDelayedExpansion for %%n in (*.png) do (set /A a1ren "%%n" "影魅!a!.jpg" )3.修…

C语言-顺序表

🎯引言 欢迎来到HanLop博客的C语言数据结构初阶系列。在这个系列中,我们将深入探讨各种基本的数据结构和算法,帮助您打下坚实的编程基础。本次我将为你讲解。顺序表(也称为数组)是一种线性表,因其简单易用…

Base64文件流查看下载PDF方法-CSDN

问题描述 数票通等接口返回的PDF类型发票是以Base64文件流的方式返回的&#xff0c;无法直接查看预览PDF发票&#xff0c; 处理方法 使用第三方在线工具&#xff1a;https://www.jyshare.com/front-end/61/ 在Html代码框中粘贴如下代码 <embed type"application/pd…

LeetCode LCR024.反转链表 经典题目 C写法

LeetCode LCR024.反转链表 经典题目C写法 第一种思路&#x1f9d0;&#xff1a; ​ 使用三个指针&#xff0c;n1,n2,n3&#xff0c;n1为空&#xff0c;n2为头结点&#xff0c;n3为头结点的next。开始反转后&#xff0c;n1赋值给n2的next&#xff0c;n2赋值给n1&#xff0c;n3赋…

VBA 批量发送邮件

1. 布局 2. 代码 前期绑定的话&#xff0c;需要勾选 Microsoft Outlook 16.0 Object Library Option ExplicitConst SEND_Y As String "Yes" Const SEND_N As String "No" Const SEND_SELECT_ALL As String "Select All" Const SEND_CANCEL…

ASP.NET Web应用中的 Razor Pages/MVC/Web API/Blazor

如果希望使用ASP.NET Core创建新的 Web 应用程序&#xff0c;应该选择哪种方法&#xff1f;Razor Pages还是 MVC&#xff08;模型-视图-控制器&#xff09;&#xff0c;又或者使用Web API Vue/React/......。 每种方法都有各自的优点和缺点。 什么是 MVC&#xff1f; 大多数服…

Windows桌面上透明的记事本怎么设置

作为一名经常需要记录灵感的作家&#xff0c;我的Windows桌面总是布满了各种文件和窗口。在这样的环境下&#xff0c;一个传统的记事本应用往往会显得突兀&#xff0c;遮挡住我急需查看的资料。于是&#xff0c;我开始寻找一种既能满足记录需求&#xff0c;又能保持桌面整洁美观…

ozon商家版本APP下载,ozon商家版本是怎么样的

在数字化时代&#xff0c;电子商务平台正以前所未有的速度扩张其市场份额&#xff0c;其中俄罗斯的Ozon平台便是典型代表。作为Ozon平台的商家&#xff0c;了解和掌握Ozon商家版本APP的使用对于提升经营效率、把握销售机会至关重要。本篇文章将为您解析Ozon商家版本APP的下载途…

搭建邮局服务器的配置步骤?如何管理协议?

搭建邮局服务器需要考虑的安全措施&#xff1f;怎么搭建服务器&#xff1f; 在现代互联网环境中&#xff0c;电子邮件是重要的沟通工具。为了保证信息传递的稳定性和安全性&#xff0c;许多企业选择自行搭建邮局服务器。AokSend将详细介绍搭建邮局服务器的配置步骤&#xff0c…

JeeSite与TopIAM整合实现单点登录(SSO)的技术探讨

一、引言 在现今的企业级应用系统中&#xff0c;随着业务的发展和系统的复杂化&#xff0c;单点登录&#xff08;Single Sign-On&#xff0c;简称SSO&#xff09;已成为提升用户体验、增强系统安全性的重要手段。JeeSite作为一个高效、高性能、强安全性的Java EE快速开发平台&…

zookeeper加入开机启动项

Windows的任务计划程序&#xff08;Task Scheduler&#xff09;是一个强大的工具&#xff0c;允许你安排程序在特定时间自动运行&#xff0c;包括开机时。 打开任务计划程序&#xff1a; 按下Win R键&#xff0c;打开“运行”对话框。输入taskschd.msc并回车&#xff0c;打开…

使用Docker制作python项目镜像

各docker桌面版本集合&#xff1a;如果提示新版本系统不支持&#xff0c;可下载旧版本 我也分享在下面。 链接: https://pan.baidu.com/s/1HvaO2wOIE3pNE0bM7Qm3sA?pwdg7ky 提取码: g7ky –来自百度网盘超级会员v2的分享 来源参考&#xff1a;https://zhuanlan.zhihu.com/p/65…

前端 js 单引号,双引号、斜杠, 表格 tr input、checkbox、、、、

直接上代码 var target (leftOrRight LEFT ? $("#left") : $("#right"));target.empty();// let tbody $("resultRight tbody");// tbody.empty();for (var i 0; i < items.length; i) {debugger// target.append("<option valu…

超纯水除硼 ,芯片专用超纯水硼的去除方法

硼在元素周期表里面是五号元素&#xff0c;是IIIA族中唯一 一个非金属元素。它是制造P型半导体的主要掺杂剂&#xff0c;基材中硼的含量直接影响半导体的极限电压&#xff0c;因此要严格控制基材中硼的含量。在半导体制造的过程中&#xff0c;水、气、化直接跟产品接触&#xf…

「51媒体」能否提供一份成功邀约媒体的技巧?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 媒体宣传加速季&#xff0c;100万补贴享不停&#xff0c;一手媒体资源&#xff0c;全国100城线下落地执行。详情请联系胡老师。 成功邀约媒体的技巧涉及多个方面&#xff0c;包括了解媒体…

MongoDB教程(二):mongoDB引用shell

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; 文章目录 引言一、MongoD…

MessageBox与HubSpot:企业沟通与客户管理的双重利器

今天咱们来聊聊两个超实用的工具——MessageBox和HubSpot。它们就像是你的超级助手&#xff0c;让你和客户沟通起来更顺畅&#xff0c;管理起来也更轻松。 先说说MessageBox吧 想象一下&#xff0c;你正在忙着工作&#xff0c;突然客户发来个消息&#xff0c;你嗖的一下就收到…

拉卡拉支付 Go SDK

最近有一个需求&#xff0c;需要用到拉卡拉的支付&#xff0c;然后秉着开源精神去网上找到了 github.com/go-pay/gopay 一个支付的库&#xff0c;等到我使用的时候却发现拉卡拉的实现是 国外的接口&#xff0c;&#x1f602;&#x1f602;&#x1f602;。 无奈之下&#xff0c…