【22-23 春学期】人工智能基础--AI作业6-误差反向传播

news2024/11/15 20:34:05

老师发布作业链接:(429条消息) 【22-23 春学期】AI作业6-误差反向传播_HBU_David的博客-CSDN博客

目录

老师发布作业链接:(429条消息) 【22-23 春学期】AI作业6-误差反向传播_HBU_David的博客-CSDN博客

1.梯度下降

 2.反向传播

3.计算图

4.使用Numpy编程实现例题

5.使用PyTorch的Backward()编程实现例题

1.梯度下降

梯度下降是一种最小化目标函数的优化算法,在机器学习中经常使用。其基本思想是通过反复迭代来逐步调整模型参数,使目标函数的值不断减小,从而达到最小化目标函数的目的。

在每一次迭代中,梯度下降算法会计算目标函数关于当前参数的梯度,即目标函数在当前参数点处的斜率,然后朝着梯度下降的方向调整参数,使得目标函数值减小。如果梯度为正,则参数向负方向移动;如果梯度为负,则参数向正方向移动。重复这个过程,直到找到局部或全局最小值,或者达到预定的停止条件。

梯度下降算法有不同的变种,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)、小批量梯度下降(Mini-Batch Gradient Descent)等。这些变种算法主要区别在于如何计算梯度和如何更新参数。

 2.反向传播

反向传播(Backpropagation)是一种用于计算神经网络中每个参数对损失函数的导数的算法。它是训练神经网络中的权重和偏置的关键步骤。

在训练过程中,反向传播从输出层开始向前传播误差,通过链式法则计算每一层的误差贡献,然后再通过链式法则计算每个参数对误差的贡献,最终得到每个参数的梯度。这些梯度可以用于更新网络中的权重和偏置,使得损失函数得到最小化。

具体来说,反向传播的过程可以分为两个阶段:前向传播和反向传播。在前向传播阶段,神经网络将输入数据通过每一层的权重和偏置计算得到输出结果,然后计算与真实值的误差。在反向传播阶段,误差从输出层开始向前传播,通过链式法则计算每一层的误差贡献,最终计算得到每个参数的梯度。

通过反向传播算法,神经网络可以自动计算每个参数对损失函数的影响,并根据这些影响来更新参数,从而使神经网络逐步优化,提高预测精度。

3.计算图

计算图(Computational Graph)是一种图形化表示计算过程的方式。在机器学习中,计算图通常用于表示神经网络中的计算流程,从而方便进行求导和优化。

在计算图中,节点表示变量或操作,边表示数据流。计算图中的每个节点都对应一个数学运算,例如加法、乘法、卷积等。每个节点的输入和输出都是张量,可以是标量、向量、矩阵或高维张量。

计算图可以分为静态计算图和动态计算图两种类型。静态计算图在计算前需要预先定义好网络结构和参数,然后将整个计算流程编译为计算图,进行优化和求导。而动态计算图则是在运行时动态生成,每个计算步骤都可以根据需要重新构造计算图,可以更加灵活。

通过计算图,可以清晰地了解每个节点之间的依赖关系,从而更好地理解神经网络中的计算流程,方便进行求导和优化。同时,计算图还可以通过自动微分技术计算导数,为反向传播等算法提供基础支持。

4.使用Numpy编程实现例题

import numpy as np
import matplotlib.pyplot as plt
 
 
def sigmoid(z):
    a = 1 / (1 + np.exp(-z))
    return a
 
 
def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8): # 正向传播
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)
 
    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)
 
    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
 
    return out_o1, out_o2, out_h1, out_h2, error
 
 
def back_propagate(out_o1, out_o2, out_h1, out_h2):    # 反向传播
    d_o1 = out_o1 - y1
    d_o2 = out_o2 - y2
 
    d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
    d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
    d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
    d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2
 
    d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
    d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
    d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
    d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2
 
    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8
 
 
def update_w(step,w1, w2, w3, w4, w5, w6, w7, w8):    #梯度下降,更新权值
    w1 = w1 - step * d_w1
    w2 = w2 - step * d_w2
    w3 = w3 - step * d_w3
    w4 = w4 - step * d_w4
    w5 = w5 - step * d_w5
    w6 = w6 - step * d_w6
    w7 = w7 - step * d_w7
    w8 = w8 - step * d_w8
    return w1, w2, w3, w4, w5, w6, w7, w8
 
 
if __name__ == "__main__":
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8 # 可以给随机值,为配合PPT,给的指定值
    x1, x2 = 0.5, 0.3   # 输入值
    y1, y2 = 0.23, -0.07 # 正数可以准确收敛;负数不行。why? 因为用sigmoid输出,y1, y2 在 (0,1)范围内。
    N = 10             # 迭代次数
    step = 10           # 步长
 
    print("输入值:x1, x2;",x1, x2, "输出值:y1, y2:", y1, y2)
    eli = []
    lli = []
    for i in range(N):
        print("=====第" + str(i) + "轮=====")
        # 正向传播
        out_o1, out_o2, out_h1, out_h2, error = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        print("正向传播:", round(out_o1, 5), round(out_o2, 5))
        print("损失函数:", round(error, 2))
        # 反向传播
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        # 梯度下降,更新权值
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(step,w1, w2, w3, w4, w5, w6, w7, w8)
        eli.append(i)
        lli.append(error)
 
 
    plt.plot(eli, lli)
    plt.ylabel('Loss')
    plt.xlabel('w')
    plt.show()

 

5.使用PyTorch的Backward()编程实现例题

​
import torch
# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
 
#design model using class
"""
our model class should be inherit from nn.Module, which is base class for all neural network modules.
member methods __init__() and forward() have to be implemented
class nn.linear contain two member Tensors: weight and bias
class nn.Linear has implemented the magic method __call__(),which enable the instance of the class can
be called just like a function.Normally the forward() will be called 
"""
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的
        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias
        self.linear = torch.nn.Linear(1, 1)
 
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
 
model = LinearModel()
 
# construct loss and optimizer
# criterion = torch.nn.MSELoss(size_average = False)
criterion = torch.nn.MSELoss(reduction = 'sum')
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # model.parameters()自动完成参数的初始化操作
 
# training cycle forward, backward, update
for epoch in range(100):
    y_pred = model(x_data) # forward:predict
    loss = criterion(y_pred, y_data) # forward: loss
    print(epoch, loss.item())
 
    optimizer.zero_grad() # the grad computer by .backward() will be accumulated. so before backward, remember set the grad to zero
    loss.backward() # backward: autograd,自动计算梯度
    optimizer.step() # update 参数,即更新w和b的值
 
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
 
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)


​

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/464293.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS第四十四节——路径规划

1.新建launch文件 nav05_path.launch <launch><node pkg"move_base" type"move_base" respawn"false" name"move_base" output"screen" clear_params"true"><rosparam file"$(find nav_dem…

Mybatis 知识总结1(入门、JDBC、数据库连接池、lombok介绍)

Mybatis 知识总结&#xff08;一&#xff09; 3.1 Mybatis 介绍 什么是 Mybatis ? MyBatis 是一款优秀的持久层框架&#xff0c;用户简化 JDBC 的开发。MyBatis 是 Apache 的一个开源项目 iBaits&#xff0c;2010年这个项目由 apache 迁移到了 google code&#xff0c;并且改…

Chapter 6 :CDC Constraints(ug903)

6.1 About CDC Constraints 跨时钟域&#xff08;CDC&#xff09;约束适用于具有不同发射和捕获时钟的时序路径。根据发射和捕获时钟的关系以及在CDC路径上设置的时序异常&#xff0c;有同步CDC和异步CDC。例如&#xff0c;同步时钟之间但被错误路径约束覆盖的CDC路径…

活动目录域服务

域用户能够使用[username]csk.cn进行登录。 创建一个名为“CSK”的OU&#xff0c;并新建以下域用户和组&#xff1a; sa01-sa20&#xff0c;请将该用户添加到sales用户组。 ma01-ma10&#xff0c;请将该用户添加到manager用户组。 除manager 组以外的所有用户隐藏C盘。 除…

深入理解 Linux 内核

文章目录 前言一、内存寻址1、内存地址2、硬件中的分段&#xff08;1&#xff09;段选择符 3、Linux 中的分段&#xff08;1&#xff09;Linux GDT&#xff08;2&#xff09;Linux LDT 4、硬件中的分页5、Linux 中的分页&#xff08;1&#xff09;进程页表&#xff08;2&#x…

CRLF注入漏洞、URL重定向、资源处理拒绝服务详细介绍(附实例)

目录 一、CRLF注入漏洞 漏洞简介 演示介绍 漏洞检测工具&#xff1a;CRLFuzz 二、URL重定向漏洞 漏洞简介 漏洞相关业务 演示介绍 创建重定向虚假钓鱼网站 三、WEB 拒绝服务 简介 漏洞相关业务 演示介绍 一、CRLF注入漏洞 漏洞简介 CRLF 注入漏洞&#xff0c;是因…

centos7 firewall-cmd主机之间端口转发

目录 1. firewalld1.1 firewalld守护进程1.2 控制端口/服务1.3 伪装IP1.4 端口转发 2. 案例2.1 配置ServerA2.2 安装nginx测试 &#xff08;可选&#xff09;2.3 开启端口2.4 伪装IP2.5 端口转发2.6 配置ServerB2.7 修改nginx页面显示内容2.8 访问ServerB2.9 访问ServerA 1. fi…

低代码是开发的未来,还是只能解决边角问题的鸡肋?

随着互联网行业寒冬期的到来&#xff0c;降本增效、开源节流几乎成为了全球互联网厂商共同的应对措施&#xff0c;甚至高薪酬程序员的“35岁危机”一下子似乎变成了现实。程序员的高薪吸引了各行各业的“跨界选手”&#xff0c;是编程门槛降低了吗&#xff1f;不全是&#xff0…

搭建linux邮件服务器

参考&#xff1a;企业级邮件服务器实战_哔哩哔哩_bilibili Linux 平台开源免货的邮件服务器包括: Sendmail、Postix、Omail ; 邮件服务器构成了电子邮件系统的核心&#xff0c;每个收信人都有一个位于某个邮件服务器上的邮箱(mailbox)&#xff0c;一个邮件消息的典型旅程是从…

管道命令(cut、grep、sort、wc、uniq、tee、tr、col、join、paste、expand/unexpand、split、xargs)

文章目录 管道命令(pipe)选取命令&#xff1a;cut、grepcut使用案例cut的优点缺点 grep使用案例 排序命令&#xff1a;sort、wc、uniqsort使用案例 uniq使用案例 wc使用案例 双向重定向&#xff1a;tee使用案例 字符转换命令&#xff1a;tr、col、join、paste、expandtr使用案例…

非量表数据应该如何分析?

问卷中的非量表数据应该怎么分析&#xff1f; 样本特征分析 对于非量表题的描述可以使用频数分析或者可视化图形进行描述&#xff0c;比如单选题也可以使用柱形图等进行展示&#xff0c;通过结果展示了解样本的基本情况&#xff0c;最后结合分析结果提出建议等。差异分析 除此之…

mybatis中进行时间范围查询

一 oracle数据库 数据库时间类型为DATE TO_CHAR 把日期或数字转换为字符串 TO_DATE 把字符串转换为数据库中的日期类型 TO_DATE(char, ‘格式’) TO_NUMBER 将字符串转换为数字 TO_NUMBER(char, ‘格式’) 1、入参是String类型的数据 mybatis 处理时间范围 使用TO_DATE函数…

2023个税验证Excel表

根据北京市工资计算公式制作该表格&#xff0c;用来验证每月发放工资是否有误&#xff0c;统计年度总收入等。 下载链接如下&#xff08;提升等级用&#xff09;&#xff1a; https://download.csdn.net/download/wayright/87732783 不下载&#xff0c;按照上面表格数据自己制作…

下载高清图片素材,就上这6个网站,免费还能商用

图片素材网站我已经推荐过很多了&#xff0c;今天就再给大家推荐6个高清图片素材网&#xff0c;免费下载哦~建议收藏起来。 1、菜鸟图库 https://www.sucai999.com/pic.html?vNTYwNDUx 我推荐过很多次的一个设计素材网站&#xff0c;除了设计类&#xff0c;还有很多自媒体可…

el-input-number 输入框添加单位

需求 使用 element-ui 的 InputNumber 控件,实现金额填写,需要在数字后面添加一个单位:元 实现效果 代码部分 <template><el-dialogclass="morendialog":title="(formData.id ? 修改 : 新增) + title":visi

没有什么比破除束缚更自由的事情了

我发现&#xff0c;我时常处于一种自我消耗、内耗的状态中&#xff0c;令我难以振作起来去改变现状。因此&#xff0c;“拒绝内耗&#xff0c;提升表达力&#xff0c;努力提升自我”成为了我必须完成的小目标。 在高中和大学的的时候&#xff0c;有一段时间&#xff0c;我曾经…

【牛客网】迷宫问题与年终奖

目录 一、编程题 1.迷宫问题 2.年终奖 二、选择题 1、将N条长度均为M的有序链表进行合并&#xff0c;合并以后的链表也保持有序&#xff0c;时间复杂度为()? 2、大小为MAX的循环队列中&#xff0c;f为当前对头元素位置&#xff0c;r为当前队尾元素位置(最后一个元素的位…

Ansys Zemax | 设计抬头显示器时要使用哪些工具 – 第一部分

本文演示了如何使用OpticStudio工具设计分析抬头显示器(HUD)性能&#xff0c;即全视场像差(FFA)和NSC矢高图。(联系我们获取文章附件) 初始结构 HUD简介 以下为HUD的示意图。液晶显示器作为光源发光&#xff0c;光线被HUD的两个反射镜反射&#xff0c;然后通过风挡玻璃反射&am…

零死角玩转stm32中级篇3-SPI总线

一.基础知识 1.什么是SPI SPI&#xff08;Serial Peripheral Interface&#xff0c;串行外设接口&#xff09;是一种同步的串行通信协议&#xff0c;它被用于在微控制器、存储器芯片、传感器和其他外围设备之间传输数据。SPI通常由四个线组成&#xff1a;时钟线&#xff08;SC…

对git的简单总结

Git的基本使用 配置用户名和邮箱常见的操作查看仓库的状态远端仓库整体流程分支本地分支命令远端分支命令 这几天在做毕业设计&#xff0c;需要用到git&#xff0c;所以简单总结一下git的基本使用。 配置用户名和邮箱 git config --global user.name "Your Name" g…