pytorch实现 --- 手写数字识别

news2024/11/29 2:47:49

        本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在Pytorch

       Pytorch(1)---pytorch实现 --- 手写数字识别》

pytorch实现 --- 手写数字识别

目录

1.项目介绍

2.实现方法

3.程序代码

4.运行结果


1.项目介绍

        使用pytorch实现手写数字识别,十分简单的小项目,环境搭建好,一跑就通。


2.实现方法

2.1方式1        

 安装库:

pip install numpy torch torchvision matplotlib

 运行:

python test.py

首次运行会下载MNIST数据集,请保持网络畅通

2.2方式2

        如果使用pycharm,已经安装好了pytorch环境,那么直接在pytorch环境中运行下面这份代码就好。


3.程序代码

"""手写数字识别项目
    时间:2023.11.6
    环境:pytorch
    作者:Rainbook
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

class Net(torch.nn.Module):  # 定义一个Net类,神经网络的主体
    def __init__(self):  # 全连接层,四个
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)  # 输入层输入28*28,输出64
        self.fc2 = torch.nn.Linear(64, 64)  # 中间层,输入64,输出64
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)  # 中间层(隐藏层)的最后一层,输出10个特征值
    
    def forward(self, x):  # 前向传播过程
        # self.fc1(x)全连接线性计算,再套上一个激活函数torch.nn.functional.relu()
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # 最后一层进行softmax归一化,log_softmax是为了提高计算稳定性,在softmax后面套上了一个对数运算
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换类型tensor,多维数组(张量)
    """下载MNIST数据集,
        "":当前位置
        is_train:判断是训练集还是测试集;
        batch_size:一个批次包含15张图片;
        shuffle:数据随机打乱的
    """
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 数据加载器


def evaluate(test_data, net):  # 用来评估神经网络
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))  # 计算神经网络的预测值
            for i, output in enumerate(outputs):  # 对每个批次的预测值进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total  # 返回正确率


def main():
    # 导入训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()  # 初始化神经网络

    # 打印初始网络的正确率,应当是10%附近。手写数字有十种结果,随机猜的正确率就是1/10
    print("initial accuracy:", evaluate(test_data, net))
    """训练神经网络
    pytorch的固定写法
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(5):  # 需要在一个数据集上反复训练神经网络,epoch网络轮次,提高数据集的利用率
        for (x, y) in train_data:
            net.zero_grad()  # 初始化
            output = net.forward(x.view(-1, 28*28))  # 正向传播
            # 计算差值,nll_loss对数损失函数,为了匹配log_softmax的log运算
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()  # 反向误差传播
            optimizer.step()  # 优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印当前网络的正确率

    """测试神经网络
        训练完成后,随机抽取3张图片进行测试
    """
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))  # 测试结果
        plt.figure(n)  # 画出图像
        plt.imshow(x[0].view(28, 28))  # 像素大小28*28
        plt.title("prediction: " + str(int(predict)))  # figure的标题
    plt.show()


if __name__ == "__main__":
    main()

4.运行结果

4.1正确率

4.2测试结果

        参考资料来源:B站

        文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1176805.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

天津热力管网监测系统丨高效、稳定的供热解决方案

热力管网监测系统是一种用于监测和控制系统热力管道的智能技术。热力管网监测系统通过将传感仪器建设在热力管道上,实现对管道内温度、压力、流量等参数的实时监测,同时将数据传输到控制中心进行数据分析,以保障热力管道的安全稳定运行。 中央…

不学51直接学stm32可以吗?学stm32需要哪些基础?

不学51直接学stm32可以吗?学stm32需要哪些基础? 不管那些大佬技术多么牛逼,大多数入门都是从51单片机开始。 最近有一些入门的小伙伴问我说看到同学都从直接从STM32开始干了。最近很多小伙伴找我,说想要一些stm32的资料&#xff…

C++ reference

cppreference.com 《现代C语言核心特性解析》 这是一本 C 进阶图书,全书分为 42 章,深入探讨了从 C11 到 C20 引入的核心特性。 本书不仅通过大量的实例代码讲解特性的概念和语法,还从编译器的角度分析特性的实现原理,让读者…

王道p18 第11题 现在有两个等长升序序列 A和 B,试设计一个在时间和空间两方面都尽可能高效的算法,找出两个序列 A和B的中位数。

视频讲解&#x1f447;&#xff1a; p18 第10题 c语言代码实现王道数据结构课后代码题_哔哩哔哩_bilibili 本题代码如下 int search(int a[], int b[], int c[]) {int i 0;int j 0;int k 0;while (i < 5 && j < 5){if (a[i] < b[j])c[k] a[i];elsec[k…

NOIP2005提高组第二轮T3:传纸条

题目链接 NOIP2005提高组第二轮T3&#xff1a;传纸条 题目描述 小渊和小轩是好朋友也是同班同学&#xff0c;他们在一起总有谈不完的话题。一次素质拓展活动中&#xff0c;班上同学安排坐成一个 m m m 行 n n n 列的矩阵&#xff0c;而小渊和小轩被安排在矩阵对角线的两端…

GraphCodeBert:基于数据流的代码表征预训练模型

GraphCodeBert https://arxiv.org/abs/2009.08366 1 模型结构 使用多层双向 Transformer 变量定义 &#xff1a; C&#xff1a;源码集合W&#xff1a;文本集合V&#xff1a;变量集合E&#xff1a;变量间的边的集合 输入&#xff1a;把注释&#xff0c;源代码和变量集连接为…

【MySQL基本功系列】第一篇 先熟悉MySQL的运行逻辑

​ 我将推出一系列关于MySQL的博客文章&#xff0c;涵盖了从入门到深入底层的原理。这些文章将包括MySQL的运行逻辑、InnoDB存储引擎、SQL优化、undo log、bin log等多个方面的知识。希望这些文章能为你提供宝贵的信息和洞见&#xff0c;并帮助你更好地理解和应用MySQL。同时&a…

国产系列 | Atlas 300I Pro 推理卡性能、应用场景、技术规格介绍

Atlas 300I Pro 推理卡是基于昇腾AI处理器的新一代高性能推理卡&#xff0c;融合“通用处理器、AI Core、编解码”于一体&#xff0c;提供超强AI推理、目标检索等功能&#xff0c;具有超强算力、超高能效、高性能特征检索、安全启动等优势&#xff0c;可广泛应用于OCR识别、语音…

UMS攸信技术与欣奕华复合机器人携手共进,领跑智能制造未来!

近年来&#xff0c;全球机器人领域的相关创新机构与科技企业不断探索人工智能、人机协作、多技术融合等领域&#xff0c;推动机器人在仓储运输、智能工厂、医疗康复等领域的深入应用。 2023年10月19日&#xff0c;攸信技术与浙江欣奕华达成战略合作&#xff0c;成为其产品特约经…

Mysql Cluster (NDB - Network Database) - 分布式

Mysql高可用架构 复制&#xff08;Replication&#xff09; 是本文中所有 MySQL 技术的基础。包括&#xff1a;异步复制、半同步复制&#xff0c;增强半同步复制。InnoDB 副本集&#xff08;MySQL InnoDB ReplicaSet&#xff09; 无缝衔接其他 MySQL 官方提供的应用程序&#…

Java智慧工地管理平台可视化大数据建造工地APP源码

建筑行业是国民经济的重要物质生产部门和支柱产业之一&#xff0c;同时&#xff0c;建筑业也是一个安全事故多发的高危行业。如何加强施工现场安全管理、降低事故发生频率、杜绝各种违规操作和不文明施工、提高建筑工程质量&#xff0c;是摆在各级政府部门、施工企业面前的一道…

多机位直播案例

目录 1、案例简述 2、设备准备&#xff1a; &#xff08;1&#xff09;笔记本电脑 &#xff08;2&#xff09;手机 &#xff08;3&#xff09;触控一体机 &#xff08;4&#xff09;教室前端监控摄像机 &#xff08;5&#xff09;教室后端监控摄像机 &#xff08;6&…

R语言piecewiseSEM结构方程模型在生态环境领域实践技术应用

结构方程模型&#xff08;Sructural Equation Modeling&#xff0c;SEM&#xff09;可分析系统内变量间的相互关系&#xff0c;并通过图形化方式清晰展示系统中多变量因果关系网&#xff0c;具有强大的数据分析功能和广泛的适用性&#xff0c;是近年来生态、进化、环境、地学、…

java拉取股票数据进行分析

1.背景 2.数据获取分析 3.代码获取数据 代码: package com.life.gupiao;import cn.hutool.core.date.DateTime; import cn.hutool.core.date.DateUtil; import cn.hutool.http.HttpRequest; import cn.hutool.http.HttpUtil; import cn.hutool.poi.excel.ExcelUtil; import…

Scala中使用Selenium进行网页内容摘录的详解

前言 公众号成为获取信息的重要途径之一。而对于公众号运营者来说&#xff0c;了解公众号的数据情况非常重要。比如&#xff0c;你可能想要获取公众号的文章内容&#xff0c;进行数据分析或者生成摘要。或者你可能想要监控竞争对手的公众号&#xff0c;了解他们的最新动态动态…

【漏洞复现】Django _2.0.8_任意URL跳转漏洞(CVE-2018-14574)

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描3、漏洞验证 1.5、修复建议 说明内容漏洞编号CVE-2018-14574漏洞名称Django任意URL跳转漏洞漏洞…

Jmeter全流程性能测试实战

项目背景&#xff1a; 我们的平台为全国某行业监控平台&#xff0c;经过3轮功能测试、接口测试后&#xff0c;98%的问题已经关闭&#xff0c;决定对省平台向全国平台上传数据的接口进行性能测试。 01、测试步骤 1、编写性能测试方案 由于我是刚进入此项目组不久&#xff0c…

Chart 2 OpenCL简介

文章目录 前言OpenCL简介OpenCL 标准API 函数OpenCL C OpenCL Profiles总结 前言 记录本人学习OpenCL的历程&#xff0c;总结一些重要的知识点&#xff0c;作为个人学习笔记&#xff0c;参考书籍 Qualcomm Snapdragon™ Mobile Platform OpenCL General Programming and Optim…

【广州华锐互动】智慧安防应急可视化系统定制开发

随着科技的飞速发展&#xff0c;我们的生活方式正以前所未有的速度发生变化。在这个变革的时代&#xff0c;智慧安防应急可视化系统作为一种新兴的技术&#xff0c;正在为现代安全领域带来革命性的突破。本文将探讨智慧安防应急可视化系统的概念、应用和前景&#xff0c;以期为…

制造业出海如何乘风破浪?制胜绝招在这里!

目录 问题1: 企业为什么要出海&#xff1f; 问题2: 中国制造业出海企业应具备那些能力&#xff1f; 问题3: 出海应注意哪些事项以保证数据安全&#xff1f; 问题4: 出海企业应怎样做好人才管理&#xff1f; 问题5: 企业如何高质量出海&#xff1f; 国内制造领域各行各业纷…