基于pytorch实现手写数字识别

news2024/9/24 5:02:15

1,先安装pytorch,在pytorch环境中安装库:

1)进入所安装的pytorch环境,我的是pytorch

所以激活它:

conda activate pytorch

2)使用pip安装numpy,torch,torchvision,matplotlib库 

pip install numpy torch torchvision matplotlib

回车安装4个库

2,再将test.py文件用vscode打开,pycharm也行(主要我不怎么会用),这里用vscode展示。

 注意右下角环境要选好。

这里我已经测试了两次,最高在0.96左右。

献上源码:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


class Net(torch.nn.Module):#定义一个NET类,它就是神经网络的主体

    def __init__(self):
        super().__init__()  #四个全连接层
        self.fc1 = torch.nn.Linear(28*28, 64)#输入为28*28的像素尺寸图像
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)#中间三层放了64个节点
        self.fc4 = torch.nn.Linear(64, 10)#输出为10个数字类别
    def forward(self, x):#forward函数定义了前向传播过程,参数x是图像输入
        x = torch.nn.functional.relu(self.fc1(x))#每层连接中我们先做全连接线性计算
        x = torch.nn.functional.relu(self.fc2(x))#再套上一个激活函数torch.nn.functional.relu
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)#输出层通过sodtmax归一化,这里的log_softmax是为了提高计算的稳定性。
        return x#在softmax之外又套上了torch.nn.functional.log_softmax对数运算


def get_data_loader(is_train):#导入数据
    to_tensor = transforms.Compose([transforms.ToTensor()])#定义一个tensor,是一个多维数组,中文叫张量
    data_set = MNIST("", is_train, transform=to_tensor, download=True)#下载MNIST数据集,""是下载目录,空表示当前目录,is_train用来决定是导入训练集还是测试集
    return DataLoader(data_set, batch_size=15, shuffle=True)#batch_size=15表示一个批次含15张图片,shuffle=True表示数据是随机打乱的,最后返回数据加载器


def evaluate(test_data, net):#用来评估神经网络的正确率,
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))#计算神经网络预测值
            for i, output in enumerate(outputs):#对批次结果进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:#argmax函数计算数据中最大值的序号也就是预测的手写数字结果
                    n_correct += 1
                n_total += 1
    return n_correct / n_total#返回正确率


def main():

    train_data = get_data_loader(is_train=True)#导入训练集
    test_data = get_data_loader(is_train=False)#导入测试集
    net = Net()#初始化神经网络
    
    print("initial accuracy:", evaluate(test_data, net))#打印初始网络的正确率
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#一下几行代码训练神经网络,都是pytorch的固定写法
    for epoch in range(2):#epoch反复训练,提高数据集的利用率,每一个轮次就是一个epoch
        for (x, y) in train_data:
            net.zero_grad()#初始化
            output = net.forward(x.view(-1, 28*28))#正向传播
            loss = torch.nn.functional.nll_loss(output, y)#计算差值,nll_loss是对数损失函数,是为了匹配前面的log_softmax中的对数运算
            loss.backward()#反向误差传播
            optimizer.step()#优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))#每个轮次后打印当前网络的正确率

    for (n, (x, _)) in enumerate(test_data):#训练完成后随机抽取3张图像显示网络预测结果
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()

 1,讲解

1)使用MNIST数据集:手写数字图片7万张(训练6万张,测试1万张)。

2)什么是神经网络?

通过softmax归一化得到了看起来像概率的数值(概率分布),但它还不是真的概率,

调整a,b的值,如梯度下降算法,ADAM算法将神经网络问题转为最优化问题,重复过程几万次。

神经网络的本质是一个数学函数,训练的过程就是调整函数中的参数。

观察公式是线性的,但不是每个都是线性的,所以再套上一个非线性函数(也叫激活函数),f()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1483382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙开发案例:进京赶考(4)

系列文章目录 鸿蒙开发案例:进京赶考(1) 鸿蒙开发案例:进京赶考(2) 鸿蒙开发案例:进京赶考(3) 鸿蒙开发案例:进京赶考(4) 案例介绍…

计算机服务器中了mallox勒索病毒怎么解密,mallox勒索病毒解密流程

科技技术的第一生产力,网络技术的不断发展与应用,让企业逐步走向数字化时代,通过网络的力量可以为企业更好地开展各项业务工作,网络数据安全问题也由此成为众多企业关心的主要话题。近日,云天数据恢复中心接到某化工集…

Python + Appium 自动化操作微信入门看这一篇就够了

Appium 是一个开源的自动化测试工具,支持 Android、iOS 平台上的原生应用,支持 Java、Python、PHP 等多种语言。 Appium 封装了 Selenium,能够为用户提供所有常见的 JSON 格式的 Selenium 命令以及额外的移动设备相关的控制命令,…

5GC SBA架构

协议标准:Directory Listing /ftp/Specs/archive/23_series/23.501/ (3gpp.org) NF描述说明NSSFNetwork Slice Selection Function网络切片选择,根据UE的切片选择辅助信息、签约信息等确定UE允许接入的网络切片实例。NEF Network Exposure Function网络开…

Docker与虚拟机比较

在对比Docker和虚拟机前,先简单了解下虚拟化,明确Docker和虚拟机分别对应的虚拟化级别,然后对Docker和虚拟机进行比较。需要注意的是,Docker和虚拟机并没有什么可比性,而是Docker使用的容器技术和虚拟机使用的虚拟化技…

git介绍4.2

git(版本控制工具) 一、git 介绍 1、git是目前世界上最先进的分布式版本控制系统,可以有效,高速的处理从小到大的项目版本管理。 2、git是linux torvalds 为了帮助管理linux内核开发二开发的一个开放源码的版本控制软件。 3、git作用:更好…

【深度学习笔记】计算机视觉——图像增广

图像增广 sec_alexnet提到过大型数据集是成功应用深度神经网络的先决条件。 图像增广在对训练图像进行一系列的随机变化之后,生成相似但不同的训练样本,从而扩大了训练集的规模。 此外,应用图像增广的原因是,随机改变训练样本可以…

Java宝典-类和对象

目录 1.面向对象1.1 面向过程与面向对象 2. 类的定义和使用2.1 如何定义类3.类的实例化4.this引用5.构造方法5.1 什么是构造方法5.2 构造方法的特点 6.包6.1 导包6.2 自定义包 7.封装8.访问限定符9.static9.1 static修饰的成员变量9.2 static修饰的成员方法 10.代码块10.1 普通…

大数据信用报告如何查询?有哪些需要注意的?

大数据信用对于有资金周转的人来说是比较重要的,主要由于大数据信用无形的被不少机构用于贷前风控,无论是机构要求的还是自查,提前了解大数据信用情况是常规操作,那大数据信用报告如何查询?有哪些需要注意的呢?本文详细为大家讲…

Linux使用基础命令

1.常用系统工作命令 (1).用echo命令查看SHELL变量的值 qiangziqiangzi-virtual-machine:~$ echo $SHELL /bin/bash(2).查看本机主机名 qiangziqiangzi-virtual-machine:~$ echo $HOSTNAME qiangzi-virtual-machine (3).date命令用于显示/设置系统的时间或日期 qiangziqian…

数据结构篇十:红黑树

文章目录 前言1. 红黑树的概念2. 红黑树的性质3. 红黑树节点的定义4. 红黑树的插入4.1 情况一: cur为红,p为红,g为黑,u存在且为红4.2 情况二: cur为红,p为红,g为黑,u不存在/u存在且为黑。4.2.1 …

图论 - Trie树(字符串统计、最大异或对)

文章目录 前言Part 1:Trie字符串统计1.题目描述输入格式输出格式数据范围输入样例输出样例 2.算法 Part 2:最大异或对1.题目描述输入格式输出格式数据范围输入样例输出样例 2.算法 前言 本篇博客将介绍Trie树的常见应用,包括:Trie…

30道python自动化测试面试题(全)

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号【互联网杂货铺】,回复 1 ,免费获取软件测试全套资料,资料在手,涨薪更快 1、什么项目适合做自动化测试? 关键字:…

【Pytorch】论文复现 Vision Transformer (ViT)

文章目录 0. 进行设置1. 获取数据2. 创建Dataset和DataLoader3. 复现 ViT 论文:概述4. Equation 1: 将数据拆分为 patch 并创建类、位置和 patch 嵌入5. Equation 2: Multi-Head Attention (MSA)6. Equation 3: Multilayer Perceptron (MLP)7. 创建 Transformer 编码…

【计算机网络_应用层】协议定制序列化反序列化

文章目录 1. TCP协议的通信流程2. 应用层协议定制3. 通过“网络计算器”的实现来实现应用层协议定制和序列化3.1 protocol3.2 序列化和反序列化3.2.1 手写序列化和反序列化3.2.2 使用Json库 3.3 数据包读取3.4 服务端设计3.5 最后的源代码和运行结果 1. TCP协议的通信流程 在之…

每个人都应该知道的AI大模型:通往智能未来的桥梁

人工智能大模型已成为我们通往智能未来的桥梁。这些模型,如OpenAI的GPT-4,不仅是技术的巅峰,更是人类智慧的结晶。在这篇文章中,我们将深入探讨AI大模型的重要性,它们是如何工作的,以及它们对社会的潜在影响…

算法------(13)KMP

例题:(1)AcWing 831. KMP字符串 。。其实写完也不太理解。。随便写点吧 KMP就是求next数组和运用next的数组的过程。相比传统匹配模式一次更新一单位距离的慢速方法,next数组可以让下表字符串一次更新n - next【n】个距离&#x…

Java项目layui分页中文乱码

【问题描述】这部分没改之前中文乱码。 【解决办法】在layui.js或者layui.all.js文件中替换共、页、条转换成Unicode码格式。 字符Unicode共&#x5171页&#x9875条&#x6761【完美解决】改完之后重新运行项目,浏览器F12缓存清除就好了,右键

从键盘输入5个整数,将这些整数插入到一个链表中,并按从小到大次序排列,最后输出这些整数。

设节点定义如下struct Node {int Element; // 节点中的元素为整数类型struct Node * Next; // 指向下一个节点 }; 从键盘输入5个整数,将这些整数插入到一个链表中,并按从小到大次序排列,最后输出这些整数。注释那段求指出错误,求解…

【QT+QGIS跨平台编译】之六十二:【QGIS_CORE跨平台编译】—【错误处理:未定义类型QgsPolymorphicRelation】

文章目录 一、未定义类型QgsPolymorphicRelation二、解决办法一、未定义类型QgsPolymorphicRelation 报错信息: 错误原因为,使用了未定义类型 QgsPolymorphicRelation 二、解决办法 QgsRelation.h文件中 ①注释第36行: //class QgsPolymorphicRelation;②注释第414行: …