《深度学习》—— 神经网络模型对手写数字的识别

news2024/9/19 23:15:39

神经网络模型对手写数字的识别

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor  # 数据转换,张量,将其他类型的数据转换为tensor张量

"""
MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。
图像是灰度的,28x28像素的,并且居中的,以减少预处理和加快运行。
"""
""" 下载训练数据集 (包含训练数据+标签)"""
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()  # 张量,图片是不能直接传入神经网络模型
)  # 对于pytorch库能够识别的数据一般是tensor张量.
# NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。

""" 下载测试数据集(包含训练图片+标签)"""
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)
print(len(training_data))

""" 展示手写字图片 """
# tensor --> numpy 矩阵类型的数据
from matplotlib import pyplot as plt

figure = plt.figure()
for i in range(9):
    img, label = training_data[i + 59000]  # 提取第59000张图片

    figure.add_subplot(3, 3, i + 1)  # 图像窗口中创建多个小窗口,小窗口用于显示图片
    plt.title(label)
    plt.axis("off")  # 关闭坐标
    plt.imshow(img.squeeze(), cmap="gray")
    a = img.squeeze()  # img.squeeze()从张量img中去掉维度为1的(降维)
plt.show()

training_dataloader = DataLoader(training_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
for X, y in test_dataloader:  # X 表示打包好的每一个数据包
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")


class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型->nn.module
    def __init__(self):
        super().__init__()  # 继承的父类初始化
        self.flatten = nn.Flatten()  # 展开,创建一个展开对象flatten
        self.hidden1 = nn.Linear(28 * 28, 256)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去前一层神经元的个数,当前本层神经元个数
        self.hidden2 = nn.Linear(256, 128)  # 输出必需和标签的类别相同,输入必须是上一层的神经元个数
        self.hidden3 = nn.Linear(128, 256)
        self.hidden4 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)

    #
    def forward(self, x):  # 前向传播,你得告诉它,数据的流向。是神经网络层连接起来,函数名称不能改。当你调用forward函数的时候,传入进来的图像数据
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.sigmoid(x)  # 激活函数
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        x = self.hidden3(x)
        x = torch.sigmoid(x)
        x = self.hidden4(x)
        x = torch.sigmoid(x)
        x = self.out(x)
        return x


model = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到gpu或cpu
print(model)


# 定义训练模型的函数
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 告诉模型,开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
    # pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
    # 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.eval()。
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPU
        pred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值
        loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值loss

        optimizer.zero_grad()  # 梯度值清零
        loss.backward()  # 反向传播计算得到每个参数的梯度值w
        optimizer.step()  # 根据梯度更新网络w参数

        loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num % 200 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1


# 定义测试模型的函数
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # 测试,w就不能再更新。
    test_loss, correct = 0, 0
    with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches  # 衡量模型测试的好坏。
    correct /= size  # 平均的正确率
    print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")


loss_fn = nn.CrossEntropyLoss()  # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 创建一个优化器

# 设置训练轮数
epochs = 10
for e in range(epochs):
    print(f"Epoch {e + 1}\n")
    train(training_dataloader, model, loss_fn, optimizer)
print("Done!")
# 测试模型
test(test_dataloader, model, loss_fn)

  • 展示的手写数字图片如下:
    在这里插入图片描述
  • 模型结构如下:
    在这里插入图片描述
  • 训练结果如下:
  • 共有10轮训练
    在这里插入图片描述
  • 测试结果如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2147244.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

codetop字符串刷题,刷穿地心!!不再畏惧!!暴打面试官!!

主要供自己回顾与复习,题源codetop标签字符串近半年,会不断更新 1.有效的括号字符串2.括号生成3.最长单词4.字符串转换整数(atoi)5.整数转罗马数字6.罗马数字转整数7.比较版本号8.最长公共前缀9.面试题17.15.最长单词10.验证IP地址11.面试题01.06.字符串…

介绍一下常用的激活函数?

常用的激活函数 Sigmoid函数Tanh函数ReLU函数Leaky ReLU函数Softmax函数 Sigmoid函数 特点: 将任意实数映射到(0,1)区间内,输出值可以作为概率来解释。 函数平滑且易于求导,但其导数在两端趋近于0,即存在梯度消失问题。 输出值不…

CWFED:自然灾害检测数据集(猫脸码客 第192期)

Cyclone Wildfire Flood Earthquake Database 在自然灾害频发的今天,准确、及时地获取并分析相关数据对于灾害预防、预警及响应至关重要。为此,Cyclone Wildfire Flood Earthquake Database(以下简称CWFE Database)应运而生&…

计算机毕业设计 农场投入品运营管理系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

gcc升级(含命令行升级、手动升级两种方式)

gcc升级 1.yum源替换1.1 备份原始repo配置文件1.2 重新配置CentOS-Base.reporepo文件1.3 清除缓存并重新创建 2. gcc安装3.命令行升级gcc4.手动升级4.1 安装包下载4.2 解压4.3 gcc升级4.3.1 依赖拉取4.3.2 gmp安装4.3.3 mpfr安装4.3.4 mpc安装4.3.5 gcc编译、安装 4.4 gcc命令配…

Linux环境变量进程地址空间

目录 一、初步认识环境变量 1.1常见的环境变量 1.2环境变量的基本概念 二、命令行参数 2.1通过命令行参数获取环境变量 2.2本地变量和内建命令 2.3环境变量的获取 三、进程地址空间 3.1进程(虚拟)地址空间的引入 3.2进程地址空间的布局和理解 …

简易CPU设计入门:本CPU项目的指令格式

在这一节里面,主要是理论知识,基本上不讲代码。不过,本项目的代码包,大家还是需要下载的。 本项目的代码包的下载方法,参考下面的链接所指示的文章。 下载本项目代码 本节,其实是要讲本项目CPU的指令集。…

大模型蒸馏技术

一篇题为《The Mamba in the Llama: Distilling and Accelerating Hybrid Models》的论文证明:通过重用注意力层的权重,大型 transformer 可以被蒸馏成大型混合线性 RNN,只需最少的额外计算,同时可保留其大部分生成质量。 先来说…

Python学习——【2.1】if语句相关语法

文章目录 【2.1】if语句相关一、布尔类型和比较运算符(一)布尔类型(二)比较运算符 二、if语句的基本格式※、练习 三、if-else组合判断语句※、练习 四、if-elif-else多条件判断语句※、练习 五、判断语句的嵌套※、实战案例 【2.…

AlexNet项目图片分类通用模型代码

目录 一:建立AlexNet模型(在model文件中写) 1.构造5层卷积层 2.构造3层神经网络层 3.forward函数 4.模型最终代码 二:训练数据(在train中写) 1.读出数据 2.训练 3. 测试模型更新参数 4.完整的训练…

Datawhile 组队学习Tiny-universe Task01

Task01:LLama3模型讲解 仓库链接:GitHub - datawhalechina/tiny-universe: 《大模型白盒子构建指南》:一个全手搓的Tiny-Universe 参考博客:LLaMA的解读与其微调(含LLaMA 2):Alpaca-LoRA/Vicuna/BELLE/中文LLaMA/姜子…

新的突破,如何让AI与人类对话变得“顺滑”:Moshi背后的黑科技

你有没有想过,当我们跟智能音箱、客服机器人或者语音助手对话时,它们是怎么“听懂”我们说的话,又是怎么迅速给出回应的?就好像你对着Siri、Alexa说一句:“给我订个披萨”,它立刻明白你想要干嘛,然后帮你下单。背后的技术其实比我们想象的要复杂得多,但现在,有了Moshi…

Qt_布局管理器

目录 1、QVBoxLayout垂直布局 1.1 QVBoxLayout的使用 1.2 多个布局管理器 2、QHBoxLayout水平布局 2.1 QHBoxLayout的使用 2.2 嵌套的Layout 3、QGridLayout网格布局 3.1 QGridLayout的使用 3.2 设置控件大小比例 4、QFormLayout 4.1 QFormLayout的使用 5、…

【2024】前端学习笔记8-内外边距-边框-背景

学习笔记 外边距:Margin内边距:Padding边框:Border背景:Background 外边距:Margin 用于控制元素周围的空间,它在元素边框之外创建空白区域,可用于调整元素与相邻元素(包括父元素和兄…

AI预测福彩3D采取888=3策略+和值012路或胆码测试9月19日新模型预测第92弹

经过90多期的测试,当然有很多彩友也一直在观察我每天发的预测结果,得到了一个非常有价值的信息,那就是9码定位的命中率非常高,90多期一共只错了10次,这给喜欢打私房菜的朋友提供了极高价值的预测结果~当然了&#xff0…

教育政策与智能技术:构建新时代教师队伍

据最新统计,我国目前拥有各级各类教师共计1891.8万人,这一庞大的教师群体不仅支撑起了全球规模最大的教育体系,更成为了推动教育创新与变革的主力军。面对教育数字化的不断发展,育人内容、目标要求、方式方法的全面升级&#xff0…

【测向定位】差频MUSIC算法DOA估计【附MATLAB代码】

​微信公众号:EW Frontier QQ交流群:554073254 摘要 利用多频处理方法,在不产生空间混叠的情况下,估计出高频区域平面波的波达方向。该方法利用了差频(DF),即两个高频之间的差。这使得能够在可…

鹏鼎控股社招校招入职SHL综合能力测评:高分攻略及真题题库解析答疑

鹏鼎控股(深圳)股份有限公司,成立于1999年4月29日,是一家专注于印制电路板(PCB)的设计、研发、制造与销售的高新技术企业。公司总部位于中国广东省深圳市,并在全球多个地区设有生产基地和服务中…

【软考】数据字典(DD)

目录 1. 说明2. 数据字典的内容2.1 说明2.2 数据流条目2.3 数据存储条目2.4 数据项条目2.5 基本加工条目 3. 数据词典管理4. 加工逻辑的描述4.1 说明4.2 结构化语言4.3 判定表4.3 判定树 5. 例题5.1 例题1 1. 说明 1.数据流图描述了系统的分解,但没有对图中各成分进…

软件自动定时启动器-添加可执行文件软件,设置启动的时间,也可以设置关闭的时间-供大家学习研究参考

点击添加软件,可以添加可执行文件软件,设置启动的时间,也可以设置关闭的时间 注意,时间为00:00:00 等于没设置,这个时间不在设置范围,其他任何时间都可以。 下载地址: h…