初探深度学习-手写字体识别

news2024/11/24 21:06:09

前言

手写数字的神经网络识别通常指的是通过训练有素的神经网络模型来识别和分类手写数字图像的任务。这种类型的任务是机器学习和计算机视觉领域的一个经典问题,经常作为入门级的图像识别问题来展示和测试各种机器学习算法的能力。
在实际应用中,手写数字识别可以用于处理和分析用户书写的数字信息,比如自动识别和输入手写数字文档中的数据,或者用于教育领域的练习评分系统,其中系统可以自动识别学生书写的数学题答案并提供反馈。
MNIST数据集是手写数字识别中最常使用的数据集之一。它包含了大量不同书写风格的手写数字图片,以及与之对应的标签。这个数据集被广泛用于训练和测试各种机器学习模型,包括深度学习模型。
在训练过程中,神经网络会学习如何识别和区分不同的手写数字,这个过程涉及到从大量的样本中提取特征,并调整网络内部参数,以便在看到一个新的手写数字图片时,能够准确地预测它所代表的数字。

训练过程

1-数据处理

本次训练的数据集选自 MNIST 数据集,该数据集保存在torchvision包中,因此我们只需要在训练开始前导入该包并进行数据处理即可:

# 导入训练集相关数据包
import torchvision
from torchvision.transforms import ToTensor
# 导入相关数据,并保存到本地,以便于下次训练
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds = torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)
#对数据进行相关的处理
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=64)

在导入数据集合的时候,transform=Tensor()做的事情是将图片和标签转换为PyTorch张量,这样为我们后续进行相关模型处理提供了方便。
其中DataLoader函数对我们导入的相关数据进行了处理,其中进行了两个重要的步骤:

  1. 对训练集和测试集的数据设置了batch_size=64,这是每个批量包含64个样本的意思是,在每一次训练迭代中,模型会同时处理64个样本。这些样本会一起通过模型的前向传播(forward pass)过程,然后计算损失(如何预测不准确)并更新模型的参数。
  2. 对训练集的数据设置了打乱,这样会增加了模型的泛化能力,模型能更好地从整体上学习数据的分布,而不是记住每个单独的数据点。

2-模型构建

在导入完数据后,我们需要设计一个简单的神经网络,因为我们处理的是图像,因此输入维度是28*28,又因为我们处理的是手写字体识别,这是一个十分类的问题,因此我们的输出层的维度是10。在此基础上,我们对这个三层全连接的神经网络中间的隐藏层进行相关的设置,具体的代码如下:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 第一层输入展平后长度为28X28,创建128个神经元
        self.liner_1 = nn.Linear(28*28,128)
        # 第二层是前一层的输出,创建84个神经元
        self.liner_2 = nn.Linear(128,84)
        # 输出层接收第二层的输入84,输出分类个数为10
        self.liner_3 = nn.Linear(84,10)

在简单设置完这个三层的神经网络之后,我们需要设计这个神经网络的前向传播过程,具体代码如下:

    def forward(self,input):
        x = input.view(-1,28*28) #将输入展平为二维(1,28,28)->(28*28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x

这段代码实现的是将每个批量的多维张量的大小进行改变,以保证保持批次大小不变,但同时得到图像的像素总数,作为第一个线性层的输入,然后将输入通过第一个全连接层(self.liner_1),然后应用ReLU(Rectified Linear Unit)激活函数。
输入完成后再进行传入到第二个全连接层,这是另一个非线性变换,进一步帮助网络提取和组合特征。
x = self.liner_3(x) 这行代码将上一个层的输出通过输出层(self.liner_3)。这一层通常用于生成最终的预测结果。在这个例子中,self.liner_3是一个分类层,它将输出一个10维的向量,表示10个类别的得分。
最后,函数返回通过整个网络前向传播后的输出x。这个输出将用于后续的损失计算和梯度下降步骤,以训练网络的权重。
在设置完相关的神经网络模型后,我们还需要设置相关的损失函数和优化器,相关代码如下:

model = Model().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001)

这段代码创建了一个损失函数CrossEntropyLoss的实例。这是一个用于分类问题的损失函数,它计算预测的概率分布与真实标签之间的交叉熵。
然后创建了一个优化器SGD(Stochastic Gradient Descent)的实例。这个优化器用于更新模型的参数,以最小化损失函数。
通过将模型中模型中所有可训练参数的列表进行返回,并且设置学习率为0.001进行优化器每次更新参数时所使用的步长的设定。

3-训练函数&测试函数

设计完模型后,我们需要通过对模型的使用,来设计我们的训练函数和测试函数,其中训练函数代码如下:

#编写训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)
    num_batchs = len(dataloader)
    train_loss,correct = 0,0
    for X,y in dataloader:
        X,y =X.to(device),y.to(device)
        pred = model(X)
        loss = loss_fn(pred,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= num_batchs
    correct /= size
    return train_loss,correct

这段代码首先获取数据加载器中整个数据集的大小和批量,将其赋值给sizenum_batchs,同时将train_losscorrect置空。
然后遍历数据加载器中的每个批次,将其数据和标签赋值给Xy,并将其移动到GPU中。
在此基础上,进行模型的训练,首先通过将X传入模型,得到前向传播的结果pred,然后计算预测pred和真实标签y之间的损失。
其中以下三行代码极其关键:

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

这三段代码实现的功能是:首先清除优化器中先前计算的梯度,为新的梯度更新做准备。然后计算损失关于模型参数的梯度。最后使用计算出的梯度更新模型的参数。
最后通过以下代码更新损失和准确度,并将其返回:

        with torch.no_grad():
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= num_batchs
    correct /= size
    return train_loss,correct

测试函数与训练函数大体相似,不同的是没有使用相关的传播代码,因为它只需要测试,不需要训练:

def test(dataloader,model):
    size = len(dataloader.dataset)
    num_batches =len(dataloader)
    test_loss,correct = 0,0
    with torch.no_grad():
        for X,y in dataloader:
            X,y =X.to(device),y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred,y).item()
            correct +=(pred.argmax(1)==y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return test_loss,correct

4-模型训练

模型训练的过程就是通过持续调用模型的训练函数和测试函数,进行多轮训练,此时我们设置训练轮数为 50 轮,相关代码如下:

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss,epoch_acc=train(train_dl,model,loss_fn,optimizer)
    epoch_test_loss,epoch_test_acc=test(test_dl,model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_loss)
    template=("epoch:{:2d},train_loss:{:.5f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")
    print(template.format(epoch,epoch_loss,epoch_acc*100,epoch_test_loss,epoch_test_acc*100))
print("Done!")  

5-模型使用

在训练完模型后,我们需要对模型进行保存,以便后续将模型赋值给相关函数和功能,保存模型的代码如下,该代码会将训练的模型以model_complete_pth方式进行保存:

torch.save(model,'model_complete.pth')

接下来我们调用该模型,对传入的图形数据进行预测,整体代码如下:

import torch
import torchvision
from torchvision.transforms import ToTensor
# 对训练数据进行处理
from random import shuffle
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 第一层输入展平后长度为28X28,创建128个神经元
        self.liner_1 = nn.Linear(28*28,128)
        # 第二层是前一层的输出,创建84个神经元
        self.liner_2 = nn.Linear(128,84)
        # 输出层接收第二层的输入84,输出分类个数为10
        self.liner_3 = nn.Linear(84,10)
    def forward(self,input):
        x = input.view(-1,28*28) #将输入展平为二维(1,28,28)->(28*28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x
device = 'cuda' if torch.cuda.is_available() else "cpu"
model = Model().to(device)
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds = torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=64)
def process_image(image):
    image = image.unsqueeze(0)  # 在第一维增加一个维度,因为模型期望的输入形状是 (batch_size, channels, height, width)
    image = image.to(device)
    return image
def predict(image):
    image = process_image(image)
    with torch.no_grad():
        prediction = model(image)
        _,predicted_digit = torch.max(prediction, 1)
    return predicted_digit.item()
device = 'cuda' if torch.cuda.is_available() else "cpu"
model = torch.load('model_complete.pth')
for i in range(5):
    test_image, _ = test_ds[i]
    plt.imshow(test_image.squeeze(), cmap='gray')
    plt.show()
    predicted_digit = predict(test_image)
    print(f"The predicted digit is: {predicted_digit}")

程序运行后,会根据输入的图片进行相关的分类,准确度较高:image.png

总结

本篇文章我们记录了对手写字体识别的训练任务,包括了数据处理、模型构建、训练函数构建、测试函数构建、模型训练及使用五个部分,作为初级的深度学习训练任务,这是一个十分类的任务,通过这个任务,我们可以具体的去感知神经网络任务的设计、使用、损失函数的使用、优化器的使用,这背后涉及到很多的数学知识需要我们去学习和调优。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1493580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

比肩Gen-2,全新开源文生视频模型

著名开源平台Stability.ai在官网宣布,推出全新文生视频的扩散模型Stable Video Diffusion,已开源了该项目并公布了论文。 据悉,用户通过文本或图像就能生成高精准,14帧和25帧的短视频。目前,Stable Video Diffusion处…

为 OpenBMC 添加一个新的系统

1. 前言 在上一篇文章中向大家介绍了 OpenBMC 的是什么以及它的作用和应用场景,并且以一个自带的示例平台 romulus 展示了从下载源码包开始到启动系统并访问 Web 控制页面的整体构建流程。 通过前文已经了解到如何为已有的平台构建系统镜像,下面我们来…

如何使用LEAKEY轻松检测和验证目标服务泄露的敏感凭证

关于LEAKEY LEAKEY是一款功能强大的Bash脚本,该脚本能够检测和验证目标服务中意外泄露的敏感凭证,以帮助广大研究人员检测目标服务的数据安全状况。值得一提的是,LEAKEY支持高度自定义开发,能够轻松添加要检测的新服务。 LEAKEY主…

自动灭火贴有用吗?搞清楚自动灭火贴的使用范围很关键!

近年来,伴随着新能源车辆的自燃案例频频发生,许多新型自动灭火产品走红网络。灭火球、灭火宝、灭火手雷......层出不穷的新型产品褒贬各异,哪怕是目前占领市场份额较多的自动灭火贴也不乏有人心里嘀咕:自动灭火贴好用吗&#xff1…

什么是Docker容器?

Docker是一种轻量级的虚拟化技术,同时是一个开源的应用容器运行环境搭建平台,可以让开发者以便捷方式打包应用到一个可移植的容器中,然后安装至任何运行Linux或Windows等系统的服务器上。相较于传统虚拟机,Docker容器提供轻量化的…

Linux编程3.3 进程-进程的终止

1、正常终止 从main函数返回调用exit(标准C库函数)调用_exti或_Exit(系统调用)最后一个线程从其启动例程返回最后一个线程调用 pthread exit 2、异常终止 调用abort接受到一个信号并终止最后一个线程对取消请求做处理响应 3、进程返回 通常程序运行…

谷歌Gemini批量多线程写原创文章API软件-支持双标题违禁词过滤

谷歌Gemini批量多线程写原创文章软件介绍: 1、Gemini 是谷歌筹备了一年之久的GPT4真正竞品,也是目前谷歌能拿出手的功能最为强悍、适配最为灵活的大模型。 2、谷歌Gemini目前免费申请key,key没有额度限制,可以一直写文章。 3、谷…

Mint_21.3 drawing-area和goocanvas的FB笔记(四)

Cairo图形输出 cairo的surface可以是pixbuf, 可以是screen, 可以是png图,也可以是pdf文件 、svg文件、ps文件,定义了surface就可以用cairo_create(surface)产生cairo context, 操作cairo context就可以方便地在surface上画图,如果surface是p…

LeetCode-第137题-只出现一次的数||

1.题目描述 给你一个整数数组 nums ,除某个元素仅出现 一次 外,其余每个元素都恰出现 三次 。请你找出并返回那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法且使用常数级空间来解决此问题。 2.样例描述 3.思路描述 先把数组排序&am…

戏说c第二十六篇: 测试完备性衡量(代码覆盖率)

前言 师弟:“师兄,我又被鄙视了。说我的系统太差,测试不过关。” 我:“怎么说?” 师弟:“每次发布版本给程夏,都被她发现一些bug,太丢人了。师兄,有什么方法来衡量测试的…

vue2源码分析-vue入口文件global-api分析

文章背景 vue项目开发过程中,首先会有一个初始化的流程,以及我们会使用到很多全局的api,如 this.$set this.$delete this.$nextTick,以及初始化方法extend,initUse, initMixin , initExtend, initAssetRegisters 等等那它们是怎么实现,让我们一起来探究下吧 源码目录 global-…

2024年展望Android原生开发的现状,2024网易Android高级面试题及答案

没有稳定的工作,只有稳定的能力。 又到了万物复苏的季节,在程序猿这个行当里,作为 Android 开发出生的,在经历了八年的脱发生涯后,有了越来越多的想法和感触 趋势 随着各类移动跨平台的兴起,在 ReactNati…

基于SpringBoot+Vue实现的人力资源管理系统

系统介绍: 基于SpringBootVue实现的人力资源管理系统是为了提高企业人力资源管理水平而开发的。主要目标是通过对员工及人力资源活动信息(考勤、工资)等的编制来提高企业效率。 系统一共分为五大菜单项,分别是首页、薪资管理、权限管理、系统…

魔行观察-每日品牌监测-沪上阿姨-开店趋势

今日监测对象:沪上阿姨,监测时间段:2014年9月至2023年12月,发布时间:2024-03-05 数据获取地址:魔查查https://www.moxingdata.com/品牌基础信息 现有门店人均消费覆盖省份经营模式投资金额814416.531特许/…

工时管理软件:为什么企业需要工时跟踪?

工时跟踪对于企业经营来说,可能不是首要事项。工时跟踪有什么用? 管理学大师彼得德鲁克曾说过:If you can’t measure it, you can’t improve it(如果无法衡量,就无法改进)。企业经营也是同样道理&#x…

分布式系统中常用的缓存方案

1. 引言 随着互联网应用的发展和规模的不断扩大,分布式系统中的缓存成为了提升性能和扩展性的重要手段之一。本文将介绍几种在分布式系统中常用的缓存方案,包括分布式内存缓存、分布式键值存储、分布式对象存储和缓存网关等。 1.1 缓存在分布式系统中的…

FEP容量瓶多应用于制药光电光伏行业

常用规格:25ml、50ml、100ml、250mlFEP容量瓶也叫特氟龙容量瓶,容量瓶是为配制一定物质的量浓度的溶液用的精确定容器皿,常和移液管配合使用。广泛用于ICP-MS、ICP-OES等痕量分析以及同位素分析等高端实验。地质、电子化学品、半导体分析测试…

挑战杯 基于深度学习的动物识别 - 卷积神经网络 机器视觉 图像识别

文章目录 0 前言1 背景2 算法原理2.1 动物识别方法概况2.2 常用的网络模型2.2.1 B-CNN2.2.2 SSD 3 SSD动物目标检测流程4 实现效果5 部分相关代码5.1 数据预处理5.2 构建卷积神经网络5.3 tensorflow计算图可视化5.4 网络模型训练5.5 对猫狗图像进行2分类 6 最后 0 前言 &#…

安防摄像头(IPC)的步进马达及IR-CUT驱动国产芯片——D6212

应用领域 安防摄像头(IPC)的步进马达及IR-CUT驱动。 02 功能介绍 D6212内置8路带有续流二极管的达林顿驱动管阵列和一个H桥驱动,单芯片即可实现2个步进电机和一个IR-CUT的直接驱动,使得电路应用非常简单。单个达林顿管在输入电…

android开发教程百度网盘,高并发系统基础篇

展望未来 操作系统 移动操作系统的演变过程,从按键交互的塞班功能机到触摸屏交互的Android/IOS智能机,从小屏幕手机到全面屏、刘海屏、水滴屏。任何系统无非干两件事:输入和输出,接收到外部输入信号后经过操作系统处理后输出信息…