pytorch学习——如何构建一个神经网络——以手写数字识别为例

news2024/11/23 19:09:36

目录

一.概念介绍

1.1神经网络核心组件

1.2神经网络结构示意图

1.3使用pytorch构建神经网络的主要工具

二、实现手写数字识别

2.1环境

2.2主要步骤

2.3神经网络结构

2.4准备数据

2.4.1导入模块

2.4.2定义一些超参数

2.4.3下载数据并对数据进行预处理

2.4.4可视化数据集中部分元素

 2.4.5构建模型和实例化神经网络

2.4.6训练模型

2.4.7可视化损失函数

2.4.7.1 train  loss 

 2.4.7.2 test loss

一.概念介绍

        神经网络是一种计算模型,它模拟了人类神经系统的工作方式,由大量的神经元和它们之间的连接组成。每个神经元接收一些输入信息,并对这些信息进行处理,然后将结果传递给其他神经元。这些神经元之间的连接具有不同的权重,这些权重可以根据神经网络的训练数据进行调整。通过调整权重,神经网络可以对输入数据进行分类、回归、聚类等任务。

        通俗来讲,神经网络就是设置一堆参数,初始化这堆参数,然后通过求导,知道这些参数对结果的影响,然后调整这些参数的大小。直到参数大小可以接近完美地拟合实际结果。神经网络有两个部分:正向传播和反向传播。正向传播是求值,反向传播是求出参数对结果的影响,从而调整参数。所以,神经网络:正向传播->反向传播->正向传播->反向传播……     

        比如我们要预测一个图像是不是猫。如果是猫,它的结果就是1,如果不是猫,它的结果就是0.我们现在有一堆图片,有的是猫,有的不是猫,所以它对应的标签(这个是y)是:0 1 1 0 1。而我们的预测结果可能是对的,也可能是错的,假设我们的预测结果是:0 0 1 1 0.我们有3个预测对了,有2个预测错了。那么我们的损失值是2/5。当然这么搞的话太“粗糙”了,实际上我们会有一个函数来定义损失值是什么。而且我们的预测结果也不是一个确凿的数字,而是一个概率:比如我们预测第3张图片是猫的概率是0.8,那么我们的预测结果是0.8.总之,定义了损失值(这个损失值记为J)以后,我们要让这个损失值尽可能地小。

参考:什么是神经网络? - 绯红之刃的回答 - 知乎 

1.1神经网络核心组件

        神经网络看上去挺复杂,节点多,层多,参数多,但其结构都是类似的,核心部分和组件都是相通的,确定完这些核心组件,这个神经网络也就基本确定了。

核心组件包括:

(1)层:神经网络的基础数据结构是层,层是一个数据处理模块,它接受一个或多个张量作为输入,并输出一个或多个张量,由一组可调整参数描述。

(2)模型:模型是由多个层组成的网络,用于对输入数据进行分类、回归、聚类等任务。

 

(3)损失函数:参数学习的目标函数,通过最小化损失函数来学习各种参数。损失函数是衡量模型输出结果与真实标签之间的差异的函数,目标是最小化损失函数,提高模型性能。

(4)优化器:使损失函数的值最小化。根据损失函数的梯度更新神经网络中的权重和偏置,以使损失函数的值最小化,提高模型性能和稳定性。

1.2神经网络结构示意图

 描述:多个层链接在一起构成一个模型或网络,输入数据通过这个模型转换为预测值,然后损失函数把预测值与真实值进行比较,得到损失值(损失值可以是距离、概率值等),该损失值用于衡量预测值与目标结果的匹配或相似程度,优化器利用损失值更新权重参数,从而使损失值越来越小。这是一个循环过程,损失值达到一个阀值或循环次数到达指定次数,循环结束。

1.3使用pytorch构建神经网络的主要工具

 参考:第3章 Pytorch神经网络工具箱 | Python技术交流与分享

在PyTorch中,构建神经网络主要使用以下工具:

  1. torch.nn模块:提供了构建神经网络所需的各种层和模块,如全连接层、卷积层、池化层、循环神经网络等。

  2. torch.nn.functional模块:提供了一些常用的激活函数和损失函数,如ReLU、Sigmoid、CrossEntropyLoss等。

  3. torch.optim模块:提供了各种优化器,如SGD、Adam、RMSprop等,用于更新神经网络中的权重和偏置。

  4. torch.utils.data模块:提供了处理数据集的工具,如Dataset、DataLoader等,可以方便地处理数据集、进行批量训练等操作。

这些工具之间的相互关系如下:

  1. 使用torch.nn模块构建神经网络的各个层和模块。

  2. 使用torch.nn.functional模块中的激活函数和损失函数对神经网络进行非线性变换和优化。

  3. 使用torch.optim模块中的优化器对神经网络中的权重和偏置进行更新,以最小化损失函数。

  4. 使用torch.utils.data模块中的数据处理工具对数据集进行处理,方便地进行批量训练和数据预处理。

二、实现手写数字识别

2.1环境

        实例环境使用Pytorch1.0+,GPU或CPU,源数据集为MNIST。

2.2主要步骤

(1)利用Pytorch内置函数mnist下载数据
(2)利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器
(3)可视化源数据
(4)利用nn工具箱构建神经网络模型
(5)实例化模型,并定义损失函数及优化器
(6)训练模型
(7)可视化结果

2.3神经网络结构

实验中使用两个隐含层,每层激活函数为Relu,最后使用torch.max(out,1)找出张量out最大值对应索引作为预测值。

2.4准备数据

2.4.1导入模块

import numpy as np
import torch
# 导入 pytorch 内置的 mnist 数据
from torchvision.datasets import mnist 
#导入预处理模块
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#导入nn及优化器
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

2.4.2定义一些超参数

# 定义训练和测试时的批处理大小
train_batch_size = 64
test_batch_size = 128

# 定义学习率和迭代次数
learning_rate = 0.01
num_epoches = 20

# 定义优化器的超参数
lr = 0.01
momentum = 0.5
#动量优化器通过引入动量参数(Momentum),在更新参数时考虑之前的梯度信息,可以使得参数更新方向更加稳定,同时加速梯度下降的收敛速度。动量参数通常设置在0.5到0.9之间,可以根据具体情况进行调整。

2.4.3下载数据并对数据进行预处理

#定义预处理函数,这些预处理依次放在Compose函数中。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
#下载数据,并对数据进行预处理
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform)
#dataloader是一个可迭代对象,可以使用迭代器一样使用。
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

注:

①transforms.Compose可以把一些转换函数组合在一起;
②Normalize([0.5], [0.5])对张量进行归一化,这里两个0.5分别表示对张量进行归一化的全局平均值和方差。因图像是灰色的只有一个通道,如果有多个通道,需要有多个数字,如三个通道,应该是Normalize([m1,m2,m3], [n1,n2,n3])
③download参数控制是否需要下载,如果./data目录下已有MNIST,可选择False。
④用DataLoader得到生成器,这可节省内存。

2.4.4可视化数据集中部分元素

# 导入matplotlib.pyplot库,并设置inline模式
import matplotlib.pyplot as plt
%matplotlib inline

# 枚举数据加载器中的一批数据
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

# 创建一个图像对象
fig = plt.figure()

# 显示前6个图像和对应的标签
for i in range(6):
  plt.subplot(2,3,i+1)           # 将图像分成2行3列,当前位置为第i+1个
  plt.tight_layout()             # 自动调整子图之间的间距
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')  # 显示图像
  plt.title("Ground Truth: {}".format(example_targets[i]))          # 显示标签
  plt.xticks([])                 # 隐藏x轴刻度
  plt.yticks([])                 # 隐藏y轴刻度

注:

  1. 导入matplotlib.pyplot库,并设置inline模式,以在Jupyter Notebook中显示图像。

  2. 枚举数据加载器中的一批数据,其中test_loader是一个测试数据集加载器。

  3. 创建一个图像对象,用于显示图像和标签。

  4. 显示前6个图像和对应的标签,其中plt.subplot()用于将图像分成2行3列,plt.tight_layout()用于自动调整子图之间的间距,plt.imshow()用于显示图像,plt.title()用于显示标签,plt.xticks()和plt.yticks()用于隐藏x轴和y轴的刻度。

 2.4.5构建模型和实例化神经网络

class Net(nn.Module):
    """
    使用sequential构建网络,Sequential()函数的功能是将网络的层组合到一起
    """
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
        
 
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x


#检测是否有可用的GPU,有则使用,否则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#实例化网络
model = Net(28 * 28, 300, 100, 10)
model.to(device)
 
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

2.4.6训练模型

# 开始训练
losses = []
acces = []
eval_losses = []
eval_acces = []
 
 
for epoch in range(num_epoches):
    train_loss = 0
    train_acc = 0
    model.train()
    #动态修改参数学习率
    if epoch%5==0:
        optimizer.param_groups[0]['lr']*=0.1
    for img, label in train_loader:
        img=img.to(device)
        label = label.to(device)
        img = img.view(img.size(0), -1)
        # 前向传播
        out = model(img)
        loss = criterion(out, label)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 记录误差
        train_loss += loss.item()
        # 计算分类的准确率
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc
        
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))
    # 在测试集上检验效果
    eval_loss = 0
    eval_acc = 0
    # 将模型改为预测模式
    model.eval()
    for img, label in test_loader:
        img=img.to(device)
        label = label.to(device)
        img = img.view(img.size(0), -1)
        out = model(img)
        loss = criterion(out, label)
        # 记录误差
        eval_loss += loss.item()
        # 记录准确率
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        eval_acc += acc
        
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
          .format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), 
                     eval_loss / len(test_loader), eval_acc / len(test_loader)))

2.4.7可视化损失函数

2.4.7.1 train  loss 

plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')

 2.4.7.2 test loss

# 绘制测试集损失函数
plt.plot(eval_losses, label='Test Loss')
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/828454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TSINGSEE青犀视频智能视频监控EasyCVR如何将实时监控视频流分享出去?

开源EasyDarwin视频监控平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,TSINGSEE青犀视频安防监控平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流,…

【深度学习】Vision Transformer论文,ViT的一些见解《 一幅图像抵得上16x16个词:用于大规模图像识别的Transformer模型》

必看文章:https://blog.csdn.net/qq_37541097/article/details/118242600 论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale 论文下载:https://arxiv.org/abs/2010.11929 官方代码:https:…

【微信小程序】保存多张图片到本地相册

<template><view class"container"><u-swiper :list"list" circular radius0 indicator indicatorModedot height950rpx></u-swiper><view class"btn btn2" click"saveFun">保存到相册</view><…

【MySQL】当前读和快照读

文章目录 当前读快照读 在学习 MVCC 多版本并发控制之前&#xff0c;必须先了解一下&#xff0c;什么是 MySQL InnoDB 下的 当前读和 快照读? 当前读 读取的是记录的最新版本&#xff0c;读取时还要保证其他并发事务不能修改当前记录&#xff0c;会对读取的记录进行加锁。对…

Rocky(centos) jar 注册成服务,能开机自启动

概述 涉及&#xff1a;1&#xff09;sh 无法直接运行java命令&#xff0c;可以软连&#xff0c;此处是直接路径 2&#xff09;sh脚本报一堆空格换行错误&#xff1a;需将转成unix标准格式&#xff1b; #切换到上传的脚本路径 dos2unix 脚本文件名.sh 2&#xff09;SELINUX …

Ubuntu18.04安装ROS

ROS 安装前的准备 &#xff08;1&#xff09;为了安装顺利使用国内下载源&#xff0c;&#xff08;我个人linux使用了代理&#xff09; 清华大学源 sudo sh -c ‘. /etc/lsb-release && echo “deb http://mirrors.tuna.tsinghua.edu.cn/ros/ubuntu/ $DISTRIB_CODENA…

第四讲:利用ADO方式连接Access数据库

【分享成果&#xff0c;随喜正能量】最值得信赖的&#xff0c;其实是自己从孤独中得来的东西&#xff0c;而不是别人给予自己的东西。每个人都是一座孤岛&#xff0c;有些人一生都在想要逃离这座岛&#xff0c;有些人一生都在创造并丰富自己这座岛。。 《VBA数据库解决方案》教…

水环境地质3D可视化综合管理软件提高运维效率

谈起数字孪生技术&#xff0c;总让人兴奋不已&#xff0c;这种将物理实体的数字化模型与实际物理实体相结合、以虚控实的技术&#xff0c;是数字化转型和第四次工业革命的重要载体&#xff0c;那么在地质行业中&#xff0c;数字孪生有哪些应用场景? 在地质勘探中&#xff0c; …

潜在客户生成最实用指南,你还在等什么?

潜在客户是指对您的产品或服务表现出兴趣的人&#xff0c;它们提供个人识别信息&#xff0c;您的团队可以使用这些信息来跟进他们。随着越来越多的营销渠道涌现&#xff0c;接触新客户和开展有效的潜在客户生成活动变得越来越困难。赢得新的潜在客户听起来很困难&#xff0c;但…

【B/S手术麻醉系统源码】手术麻醉管理系统在临床中的应用

手术麻醉管理系统是临床麻醉工作中一个不容忽视的环节&#xff0c;麻醉医生必须对病人在麻醉手术过程中的情况与体征变化&#xff0c;采取的处理措施及术后随访等全过程作出及时、真实、确切的记录。麻醉记录不仅有助于确保临床麻醉准确&#xff0c;总结经验教训&#xff0c;提…

kcc呼叫中心语音转写功能

呼叫中心是客户和企业之间沟通的一个桥梁&#xff0c;也是客户服务和客户关系的一个重要组成部分。通过呼叫中心&#xff0c;企业可以建立起一个以客户为中心的服务模式&#xff0c;为客户提供高质量、高效率的服务&#xff0c;对于塑造企业形象&#xff0c;提高客户满意度&…

txt替换字符为换行

txt替换字符为换行 txt如何批量将同一个符号替换成换行符 2023-01-26 03:16:03 有时候看到网页有些排列的很整齐的文本想复制使用&#xff0c;可是复制下来放到txt后不分行&#xff0c;要么就是中间隔着一些公用的符号。那么我们怎么才能快速的让文本按原来的形式分行显示在…

一百三十九、Kettle——Linux安装Kettle8.2

一、目的 为了方便海豚调度kettle任务&#xff0c;在Linux上安装kettle 二、kettle版本与前提 版本&#xff1a;kettle8.2 pdi-ce-8.2.0.0-342 前提&#xff1a;Linux已经安装好jdk 三、安装步骤 &#xff08;一&#xff09;打开安装包所在地 [roothurys22 ~]# cd …

让SpringBoot不再需要Controller、Service、DAO、Mapper,卧槽!这款工具绝了

Dataway 是基于 DataQL 服务聚合能力&#xff0c;为应用提供的一个接口配置工具&#xff0c;使得使用者无需开发任何代码就配置一个满足需求的接口。整个接口配置、测试、冒烟、发布&#xff0c;一站式都通过 Dataway 提供的 UI 界面完成。UI 会以 Jar 包方式提供并集成到应用中…

小鹏遭遇“动荡”,自动驾驶副总裁吴新宙离职,现已完成团队过渡

根据最新消息&#xff0c;小鹏汽车的自动驾驶副总裁吴新宙宣布将加入全球GPU芯片巨头英伟达。吴新宙将成为该公司全球副总裁&#xff0c;直接向英伟达全球CEO黄仁勋汇报。小鹏汽车董事长何小鹏和吴新宙本人已在微博上确认该消息&#xff0c;并解释离职原因涉及家庭和多方面因素…

拿捏--->逻辑推断问题(猜凶手+猜名次)

文章目录 猜凶手问题题目描述算法思路代码实现 猜名次问题题目描述算法思路代码实现 猜凶手问题 题目描述 算法思路 代码实现 #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> int main() {char killer 0;for (killer A; killer < D; killer){if ((killer ! …

Rpc异步日志模块

Rpc异步日志模块作用 在一个大型分布式系统中&#xff0c;任何部署的分布式节点都可能发生崩溃&#xff0c;试想如果用普通的办法&#xff0c;即先排查哪个节点down掉了&#xff0c;找到down掉的节点后采取调试工具gdb调试该节点&#xff0c;进而排查宕机的原因。这中排查方法…

Redis的缓存、消息队列、计数器应用

目录 一、redis的应用场景 二、redis如何用于缓存 三、redis如何用于消息队列 四、redis如何用于计数器 一、redis的应用场景 Redis在实际应用中有广泛的应用场景&#xff0c;以下是一些常见的Redis应用场景&#xff1a; 缓存&#xff1a;Redis可以用作缓存层&#xff0c;…

高性能远程通信框架grpc基本使用

文章目录 一、了解grpc二、关于protobuf三、试玩grpc3.1整个工程目录3.2 proto文件编写3.3 使用maven protobuf插件转换.proto文件3.4 Grpc服务端业务实现类3.5 pom参考3.6 grpc client调用 一、了解grpc 谷歌开源远程进程调用框架&#xff0c;支持多语言系统间通信&#xff0…

百分点科技跻身中国智慧应急人工智能解决方案市场前三

近日&#xff0c; 全球领先的IT市场研究和咨询公司IDC发布了《中国智慧应急解决方案市场份额&#xff0c;2022》报告&#xff0c;数据显示&#xff0c;2022年中国智慧应急整体市场为104亿元人民币。其中&#xff0c;智慧应急人工智能解决方案子市场备受关注&#xff0c;百分点科…