【神经网络手写数字识别-最全源码(pytorch)】

news2024/10/7 14:30:56

Torch安装的方法

在这里插入图片描述

学习方法

  • 1.边用边学,torch只是一个工具,真正用,查的过程才是学习的过程
  • 2.直接就上案例就行,先来跑,遇到什么来解决什么

Mnist分类任务:

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

读取Mnist数据集

  • 会自动进行下载
# 查看自己的torch的版本
import torch
print(torch.__version__)
%matplotlib inline
# 前两步,不用管是在网上下载数据,后续的我们都是在本地的数据进行操作
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)
import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

784是mnist数据集每个样本的像素点个数

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述
全连接神经网络的结构
在这里插入图片描述在这里插入图片描述注意数据需转换成tensor才能参与后续建模训练

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)

print(loss_func(model(xb), yb))

创建一个model来更简化代码

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nn

class Mnist_NN(nn.Module):
    # 构造函数
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)
    #前向传播自己定义,反向传播是自动进行的
    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = self.dropout(x)
        x = F.relu(self.hidden2(x))
        x = self.dropout(x)
        #x = F.relu(self.hidden3(x))
        x = self.out(x)
        return x
        

在这里插入图片描述

net = Mnist_NN()
print(net)

在这里插入图片描述
可以打印我们定义好名字里的权重和偏置项

for name,parameter in net.named_parameters():
    print(name, parameter,parameter.size())

在这里插入图片描述

使用TensorDataset和DataLoader来简化

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()  # 训练的时候需要更新权重参数
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval() # 验证的时候不需要更新权重参数
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))

zip的用法

a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b)
print(list(zipped))
a2,b2 = zip(*zip(a,b))
print(a2)
print(b2)
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

三行搞定!

train_dl,valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(100, model, loss_func, opt, train_dl, valid_dl)

在这里插入图片描述

correct = 0
total = 0
for xb,yb in valid_dl:
    outputs = model(xb)
    _,predicted = torch.max(outputs.data,1)
    total += yb.size(0)
    correct += (predicted == yb).sum().item()
print(f"Accuracy of the network the 10000 test imgaes {100*correct/total}")

![在这里插入图片描述](https://img-blog.csdnimg.cn/89e5e749b680426c9700aac9f93bf76a.png

后期有兴趣的小伙伴们可以比较SGD和Adam两种优化器,哪个效果更好一点

-SGD 20epoch 85%
-Adam 20epoch 85%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/856818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】数据在内存中的存储详解

文章目录 一、什么是数据类型二、类型的基本归类三、 整型在内存中的存储1.原码、反码、补码2.大小端(1)什么是大小端(2)为什么会有大小端 四、浮点型在内存中的存储1. 浮点数存储规则 五、练习1.2.3.4.5.6.7. 一、什么是数据类型 我们可以把数据类型想象为一个矩形盒子&#x…

DCMM数据管理成熟度之数据战略-数据战略规划

需要咨询加 :shuirunjj 标准原文 1概述 数据战略规划是在所有利益相关者之间达成共识的结果。从宏观及微观两个层面确定开展数据管理及应用的动因,并综合反映数据提供方和消费方的需求。 2 过程描述 过程描述如下: a) 识别利益相关者,明确利益相关者的需求; …

人机融合智能可化简为遥控+预先规划+重新规划过程

人机融合智能可以被简单描述为人类的遥控、机器的预先规划以及人-机器共同的动态重新规划的过程。 首先,人类的遥控是指人类通过指令、控制和操作来操纵机器的行为和功能。人类可以利用各种界面和输入设备,如键盘、鼠标、触摸屏等,将自己的意…

Python做一个绘图系统3:从文本文件导入数据并绘图

文章目录 导入数据文件对话框修改绘图逻辑源代码 Python绘图系统系列:将matplotlib嵌入到tkinter 简单的绘图系统 导入数据 单纯从作图的角度来说,更多情况是已经有了一组数据,然后需要将其绘制。这组数据可能是txt格式的,也可能…

HashMap的put方法流程

首先根据key的值计算hash值,找到该元素在数组中存储的下标如果数组是空的,则调用resize进行初始化;如果没有哈希冲突直接放在对应的数组下标里如果冲突了,且key已经存在,就覆盖掉value如果冲突后是链表结构&#xff0c…

Android Studio实现刮刮卡效果

代码和刮刮乐图片参考网络 实现效果 MainActivity import android.app.Activity; import android.os.Bundle;public class MainActivity extends Activity {Overrideprotected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);setContentVi…

汽车控制器底层软件BOOTLOADER开发经历

现在所谓智能汽车必备的OTA技术,在ECU控制器层面就是BOOT的开发,对应autosar体系里面的BSW基础软件。 同学刚开始接触汽车软件开发会有一种思想,要学就学听起来high level的autosar,但是到底autosar是个什么东西也搞不懂&#xf…

基于数据全生命周期的数据资产价值评估方法及应用

基于数据全生命周期的数据资产价值评估方法及应用 李冬青, 刘吟啸, 邓镭, 李铭洋 阿里巴巴集团,上海 200120 摘要:数据资产价值评估是现代数据资产管理和运营以及数据流通的基础。基于数据全生命周期理论,从第一性原则出发,通过评…

2023好用苹果电脑杀毒软件Cleanmymac X

苹果电脑怎么杀毒?这个问题自从苹果电脑变得越来越普及,苹果电脑的安全性问题也逐渐成为我们关注的焦点。虽然苹果电脑的安全性相对较高,但仍然存在着一些潜在的威胁,比如流氓软件窥探隐私和恶意软件等。那么,苹果电脑…

Day 25 C++ stack容器(栈)

文章目录 stack 基本概念定义基本概念栈顶(Top)——指向栈中最上面的元素的位置。入栈(Push)——将元素添加到栈顶。出栈(Pop)——从栈顶移除元素。栈空(Empty)——当栈中没有任何元…

企业权限管理(三)-产品添加

产品添加 从product-list.jsp跳转到product-add.jsp <button type"button" class"btn btn-default" title"新建" onclick"location.href${pageContext.request.contextPath}/pages/product-add.jsp"><iclass"fa fa-file…

后端开发9.商品类型模块

概述 简介 商品类型我设计的复杂了点,设计了多级类型 效果图 数据库设计

ORACLE和MYSQL区别

1&#xff0c;Oracle没有offet,limit&#xff0c;在mysql中我们用它们来控制显示的行数&#xff0c;最多的是分页了。oracle要分页的话&#xff0c;要换成rownum。 2&#xff0c;oracle建表时&#xff0c;没有auto_increment&#xff0c;所有要想让表的一个字段自增&#xff0c…

(JS逆向专栏十三)某信平台网站登入SM2

声明: 本文章中所有内容仅供学习交流&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff0c;若有侵权&#xff0c;请联系我立即删除&#xff01; 名称:电信 目标:登入参数 加密类型:SM2 目标网址:https://login.189.cn/web/login …

推出全新TrenchStop™ 5 WR6系列,IKWH50N65WR6XKSA1、IKWH40N65WR6XKSA1带来更佳的系统可靠性(IGBT)

推出全新分立式封装的650V TRENCHSTOP 5 WR6系列&#xff0c;该系列采用TO-247-3-HCC封装&#xff0c;能够实现额定电流分别为20A、30A、40A、50A、60A和70 A的丰富产品组合&#xff0c;可轻松替换前代技术&#xff0c;如TRENCHSTOP 5 WR5、HighSpeed 3 H3技术。该系列针对家用…

Linux驱动之设备树添加蜂鸣器驱动

目录 一、蜂鸣器简介 二、硬件原理分析 三、蜂鸣器驱动原理 四、开发环境 五、修改设备树文件 1、添加 pinctrl 节点 2、添加 BEEP 设备节点 3、检查 PIN 是否被其他外设使用 六、蜂鸣器驱动程序编写 七、测试程序编写 八、运行验证 在 I.MX6U-ALPHA 开发板上有一个有源…

【揽睿星舟】艺术二维码完全生成攻略

导航栏 一、云端平台 1-1、云端平台的优势&#xff1a; 1-2、选择适合的云端平台需要考虑以下几个方面&#xff1a; 二、账号注册界面如下&#xff1a; 三、生成方法 3-1、图像到图像 3-1-1、二维码生成 3-1-2、选择云端平台来启动Stable Diffusion的Web UI 3-1-3、使用S…

记录--使用 JS 实现基本的截图功能

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 思路分析 在开始动手之前&#xff0c;分析一下整个功能的实现过程&#xff1a; 根据图片大小创建 canvas1 画布&#xff0c;并将原图片直接定位在 canvas1 上&#xff1b; 在画布上添加一个蒙层&…

MySQL插入数据库 insert into 语句 用法总结

目录 步骤 一、建表&#xff1a; 二、插入第一行数据 二、插入第二行数据&#xff08;指定要插入字段&#xff09; 三、插入第三行数据&#xff08;指定要插入的字段&#xff0c;但不是所有字段&#xff0c;除了(stu_id, stu_gender)&#xff09; 四、使用一条insert in…

maven中常见问题

文章目录 一、配置项提示二、父子打包三、打包之后不显示target四、自定义打包之后的jar包名称五、整个项目打包5.1、父项目管理插件和微服务打包 一、配置项提示 SpringBoot中提示错误信息 表示的是SpringBoot中的注释提示没有配置&#xff01;那么可以来使用一下springboot官…