人工智能-深度学习-BP算法

news2024/12/27 9:25:08

BP算法的核心思想是通过计算损失函数对网络参数的梯度,然后使用梯度下降法来更新网络参数,从而最小化损失函数。

误差反向传播算法(BP)的基本步骤:

  1. 前向传播:正向计算得到预测值。

  2. 计算损失:通过损失函数计算预测值和真实值的差距。

  3. 梯度计算:反向传播的核心是计算损失函数对每个权重和偏置的梯度。

  4. 更新参数:一旦得到每层梯度,就可以使用梯度下降算法来更新每层的权重和偏置,使得损失逐渐减小。

  5. 迭代训练:将前向传播、梯度计算、参数更新的步骤重复多次,直到损失函数收敛或达到预定的停止条件。


前向传播

        英文(Foward propagathon),表示将输入的数据逐级向前的每个神经元运算传输,直到到达输出层。

        前向传播的目的是计算网络的预测值,以便后续计算损失函数并进行反向传播。

以下是一个三层神经网络实例,包含输入层(x1,x2,x3,b),隐藏层,输出层

链接输入和隐藏层的就是我们熟知的权重(w),在隐藏层中会完成两个动作

1.加权求和再求平均值,用于预测或计算损失。

2.再就是使用激活函数,将原本的线性公式,运算转为非线性。激活函数的使用视数据的情况而定

通常有sigmoid,Tanh,ReLu,softmax等激活函数。

x1,x2,x3,x4每个特征向量分别对隐藏层的n个神经元做线性运算,

再经过激活函数激活后得到最终预测值。

第一层的输出值等于第二层的输入值,依次经过运算。

        代码逻辑:

def test():
    # 前向传播
    i = torch.tensor([0.05,0.1])
    model1 = torch.nn.Linear(2,2)
    model1.weight.data = torch.tensor([
        [0.15,0.20],[0.25,0.30]
    ])
    model1.bias.data = torch.tensor([0.35,0.35])
    l1_l2 = model1(i)
    h1_h2 = torch.nn.Sigmoid(l1_l2)
    model2 = torch.nn.Linear(2,2)
    model2.weight.data = torch.tensor([
        [0.40,0.45],[0.50,0.55]
    ])
    model2.bias.data = torch.tensor([0.60,0.60])
    l3_l4 = model2(h1_h2)
    o1_02 = torch.nn.Sigmoid(l3_l4)

反向传播

        通过计算损失函数相对于每个参数的梯度来调整权重,使模型在训练数据上的表现逐渐优化。反向传播结合了链式求导法则和梯度下降算法,是神经网络模型训练过程中更新参数的关键步骤。

        前向传播的作用就是为了得到预测值,然后来为反向传播做准备。

        反向传播是为了得到更好的w,也就是我们的权重。


链式法则

        在深度学习中,链式法则是反向传播算法的基础,这样就可以通过分层的计算求得损失函数相对于每个参数的梯度。

复合函数的复杂性

  • 神经网络的输出是输入数据经过多层线性变换和非线性激活函数后的结果。每一层的输出都是下一层的输入,形成了一个复合函数。

  • 例如,假设有一个三层神经网络,输出 yy 可以表示为:

其中,f、g、h 分别是不同层的激活函数。

链式求导法通过将复合函数的导数分解为各个简单函数的导数的乘积,简化了求导过程。

也就是求得我们每一层的导函数,以便于做梯度更新,找到最小损失。

反向传播中的数学运算。

实在看不懂数学公式。。。

直接上API。

import torch
#第一步,创建一个神经网络类,继承官方的nn.module
class mybet(torch.nn.Module):
    #定义网络结构
    def __init__(self,input_size,output_size):
        #初始化父类:语法要求调用super方法生成父类的功能让子类继承父类的功能
        super(mybet,self).__init__()
        #定义网格结构
        self.hide1 = torch.nn.Sequential(torch.nn.Linear(input_size,3),torch.nn.Sigmoid())
        self.hide2 = torch.nn.Sequential(torch.nn.Linear(3,2),torch.nn.Sigmoid())
        self.out = torch.nn.Sequential(torch.nn.Linear(2,output_size),torch.nn.Sigmoid())
    def forward(self,input):
        input.shape[1]
        x = self.hide1(input)
        x = self.hide2(x)
        pred = self.out(x)
        return pred

def train():
    #数据集读取(这里自己编)
    input = torch.tensor([[0.5, 0.1],
                          [0.05, 0.180],
                          [0.05, 0.310]])
    target = torch.tensor([[1, 2],
                           [0, 3],
                           [1, 123]], dtype=torch.float32)
    #创建网格
    net = mybet(2,2)
    #定义损失函数
    loos_func = torch.nn.MSELoss()
    #定义优化器
    optimizer = torch.optim.SGD(net.parameters(),lr=0.01)
    #训练
    for epoch in range(500):
        #前向传播
        y_pred = net(input)
        #计算损失
        loss = loos_func(y_pred,target)
        #梯度清零
        optimizer.zero_grad()
        #反向传播(计算每一层w的偏导数(梯度值))
        loss.backward()
        print(net.hide1[0].weight)
        break
if __name__ == '__main__':
    
    train()

完整的全连接。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251944.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(免费送源码)计算机毕业设计原创定制:Apache+JSP+Ajax+Springboot+MySQL Springboot自习室在线预约系统

摘 要 远程预约是一种全新的网络租用方式,它通过互联网突破了时间和空间限制,实现了便捷快速的预约与管理功能。在对数据信息有效组织并整合了一定使用功能后,远程预约系统可以方便地实现预约与取消,以及信息查询等功能。经过本人…

【51单片机】程序实验910.直流电机-步进电机

主要参考学习资料:B站【普中官方】51单片机手把手教学视频 前置知识:C语言 单片机套装:普中STC51单片机开发板A4标准版套餐7 码字不易,求点赞收藏加关注(•ω•̥) 有问题欢迎评论区讨论~ 目录 程序实验9&10.直流电机-步进电机…

windows 应用 UI 自动化实战

UI 自动化技术架构选型 UI 自动化是软件测试过程中的重要一环,网络上也有很多 UI 自动化相关的知识或资料,具体到 windows 端的 UI 自动化,我们需要从以下几个方面考虑: 开发语言 毋庸置疑,在 UI 自动化测试领域&am…

我不是挂王-用python实现燕双鹰小游戏

一.准备工作 1.前言提要 作为程序员在浩瀚的数字宇宙中,常常感觉现实世界是一台精密运作的虚拟机,其底层的物理逻辑如同铁律般难以撼动。然而我们拥有在虚拟世界中自由驰骋、创造无限可能的独特力量。突发奇我想用Python写出燕双鹰的小游戏,这样想想就很…

会议直击|美格智能亮相2024紫光展锐全球合作伙伴大会,融合5G+AI共拓全球市场

11月26日,2024紫光展锐全球合作伙伴大会在上海举办,作为紫光展锐年度盛会,吸引来自全球的众多合作伙伴和行业专家、学者共同参与。美格智能与紫光展锐竭诚合作多年,共同面向5G、AI和卫星通信为代表的前沿科技,聚焦技术…

3. STM32_串口

数据通信的基础概念 什么是串行/并行通信: 串行通信就是数据逐位按顺序依次传输 并行通信就是数据各位通过多条线同时传输。 什么是单工/半双工/全双工通信: 单工通信:数据只能沿一个方向传输 半双工通信:数据可以沿两个方向…

RPC与HTTP调用模式的架构差异

RPC(Remote Procedure Call,远程过程调用)和 HTTP 调用是两种常见的通信模式,它们在架构上有以下一些主要差异: 协议层面 RPC:通常使用自定义的二进制协议,对数据进行高效的序列化和反序列化&am…

Microsoft Excel如何插入多行

1.打开要编辑的excel表,在指定位置,鼠标右键点击“插入”一行 2.按住shift键,鼠标的光标箭头会变化成如下图所示 3.一直按住shift键和鼠标左键,往下拖动,直至到插入足够的行

【python】图像、音频、视频等文件数据采集

【python】图像、音频、视频等文件数据采集 先安装所需要的工具一、Tesseract-OCRTesseract-OCR环境变量设置验证是否配置成功示例语言包下载失败 二、ffmpeg验证是否安装成功示例 先安装所需要的工具 一、Tesseract-OCR Tesseract是一个 由HP实验室开发 由Google维护的开源的…

虚拟机docker记录

最近看了一个up的这个视频,感觉docker真的挺不错的,遂也想来搞一下: https://www.bilibili.com/video/BV1QC4y1A7Xi/?spm_id_from333.337.search-card.all.click&vd_sourcef5fd730321bc0e9ca497d98869046942 这里我用的是vmware安装ubu…

C++STL之vector(超详细)

CSTL之vector 1.vector基本介绍2.vector重要接口2.1.构造函数2.2.迭代器2.3.空间2.3.1.resize2.3.2.capacity 2.4.增删查找 3.迭代器失效4.迭代器分类 🌟🌟hello,各位读者大大们你们好呀🌟🌟 🚀&#x1f68…

深度学习实验十三 卷积神经网络(4)——使用预训练resnet18实现CIFAR-10分类

目录 一、数据加载 二、数据集类构建 三、模型构建 四、模型训练 五、模型评价及预测 附完整可运行代码: 实验大体步骤: 注: 在自己电脑的CPU跑代码 连接远程服务器跑代码√ 本次实验由于数据量巨大,我的笔记本上还没有…

【Maven Helper】分析依赖冲突案例

目录 Maven Helper实际案例java文件pom.xml文件运行抛出异常分析 参考资料 《咏鹅》骆宾王 鹅,鹅,鹅,曲项向天歌。 白毛浮绿水,红掌拨清波。 骆宾王是在自己7岁的时候就写下了这首杂言 Maven Helper A must have plugin for wor…

Android 桌面窗口新功能推进,聊一聊 Android 桌面化的未来

Android 桌面化支持可以说是 Android 15 里被多次提及的 new features,例如在 Android 15 QPR1 Beta 2 里就提到为 Pixel 平板引入了桌面窗口支持,桌面窗口允许用户在自由窗口同时运行多个应用,同时可以像在传统 PC 平台上一样调整这些窗口的…

【深度学习】四大图像分类网络之VGGNet

2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司一起研发了新的卷积神经网络,并命名为VGGNet。VGGNet是比AlexNet更深的深度卷积神经网络,该模型获得了2014年ILSVRC竞赛的第二名&#xff0c…

Pytest框架学习20--conftest.py

conftest.py作用 正常情况下,如果多个py文件之间需要共享数据,如一个变量,或者调用一个方法 需要先在一个新文件中编写函数等,然后在使用的文件中导入,然后使用 pytest中定义个conftest.py来实现数据,参…

【力扣】389.找不同

问题描述 思路解析 只有小写字母,这种设计参数小的,直接桶排序我最开始的想法是使用两个不同的数组,分别存入他们单个字符转换后的值,然后比较是否相同。也确实通过了 看了题解后,发现可以优化,首先因为t相…

HarmonyOS4+NEXT星河版入门与项目实战(23)------组件转场动画

文章目录 1、控件图解2、案例实现1、代码实现2、代码解释3、实现效果4、总结1、控件图解 这里我们用一张完整的图来汇整 组件转场动画的用法格式、属性和事件,如下所示: 2、案例实现 这里我们对上一节小鱼游戏进行改造,让小鱼在游戏开始的时候增加一个转场动画,让小鱼自…

Wireshark常用功能使用说明

此处用于记录下本人所使用 wireshark 所可能用到的小技巧。Wireshark是一款强大的数据包分析工具,此处仅介绍常用功能。 Wireshark常用功能使用说明 1.相关介绍1.1.工具栏功能介绍1.1.1.时间戳/分组列表概况等设置 1.2.Windows抓包 2.wireshark过滤器规则2.1.wiresh…

Vue3 开源UI 框架推荐 (大全)

一 、前言 💥这篇文章主要推荐了支持 Vue3 的开源 UI 框架,包括 web 端和移动端的多个框架,如 Element-Plus、Ant Design Vue 等 web 端框架,以及 Vant、NutUI 等移动端框架,并分别介绍了它们的特性和资源地址。&#…