梯度和反向传播

news2024/9/19 11:09:01

一.梯度

在机器学习的时候都了解过了,梯度是一个向量,导数+变化最快的方向

损失函数:

通过梯度使损失降到最

用y=wx+b举例也就是使用梯度来更新w的值,w=w-学习率*梯度。大于零就减小,反之增大

二.反向传播

就比如搭积木,反向传播就是将一个个积木搭建成一个形状的过程

在函数求导中我们都知道对于复合函数求导要用到的链式法则。

比如J(a,b,c)=3(a+bc),令v=bc,u=a+v,J(a,b,c)=3u,要对这个函数求导

将这个过程画成图就是计算图了:

而从b,c到v一直到结果的过程就是向前的,而反向传播就是反过来的

1.向前计算

对于tensor创建的张量,里面有一个属性requires_grad,如果设置为True那么它将记录这个张量的所有的操作保存在grad_fn当中

1.1计算过程:

import torch
# 创建一个张量
x=torch.ones(2,2,requires_grad=True)
# 对这个张量计算
y=x+5
print(y)
# 输出里面就会有一个属性grad_fn来保存这个张量x的计算过程

那么通过这个grad_fn保存的操作过程就可以用来组成上面的那种计算图

补充(有时我们不需要这个记录操作的时候我们可以将不需要的那个操作封装到with torch.no_grad()中。)

with torch.no_grad():
    z=x*3+2    # 此时创建的张量z中的requires_grad=False

反向传播

比如我们要计算y=(x+1)^2

import torch

x=torch.ones(1,1,requires_grad=True)

z=x+1

y=z*z

out=y

对于正向计算后得到的out而言,就可以使用backward来进行反向传播,计算梯度

out.backward()
# 计算结果就是这个函数的导数2

注意(在输出为一个标量的情况下可以直接使用backward方法,但是在输出不为标量的情况下要向backward中传入参数)

那么对于损失来说,loss.backward()计算出导数也就是梯度,并且保存到grad中来更新梯度(比如梯度下降的时候不可能只计算一次梯度,肯定是要更新的)

做一个简单的例子:

比如我们现在的模型是y=wx+b,其中w和b均为参数。然后我们使用y=3x+1来构造x和y的数据,然后通过模型来确定w和b的值,在和y=3x+1来比较误差

import torch
import numpy as np
from matplotlib import pyplot as plt

# 1准备数据
x=torch.rand([100])
y=3*x+1

# 先随机生成一个w和b
w=torch.rand(1,requires_grad=True)
b=torch.rand(1,requires_grad=True)

# 定义一个对损失反向传播得到梯度的函数
def loss_fn(y,y_predict):
    loss=(y_predict-y).pow(2).mean()
    for i in [w,b]:
        # 每次反向传播前把梯度设置为0
        if i.grad is not None:
            i.grad.data.zero_()
    loss.backward()
    return loss.data

# 定义一个函数计算w和b的下降后的值,learning_rate为学习率
def optimize(learning_rate):
    w.data-=learning_rate*w.grad.data
    b.data-=learning_rate*b.grad.data

for i in range(3000):  #下降3000次,次数越多越精确
    # 使用当前的w和b值来计算出预测值
    y_predict=x*w+b
    # 计算误差
    loss=loss_fn(y,y_predict)
    
    # 更新w和b
    optimize(0.01)

# 计算最终的预测值
predict=x*w+b

# 计算最后的误差
loss=predict-y
    
    











本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1983409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【源码+文档+调试讲解】古风生活体验交流网站

摘 要 二十一世纪我们的社会进入了信息时代,信息管理系统的建立,大大提高了人们信息化水平。传统的管理方式对时间、地点的限制太多,而在线管理系统刚好能满足这些需求,在线管理系统突破了传统管理方式的局限性。于是本文针对这一…

24/8/5算法笔记 逻辑回归sigmoid

今日是代码对sigmoid函数的实现和运用 #linear_model线性回归 #名字虽然叫逻辑回归,作用于分类 #分类:类别 #回归:预测 from sklearn.linear_model import LogisticRegression 实现函数 import numpy as np import matplotlib.pyplot as pl…

Linux笔记-3()

目录 一、Linuⅸ实操篇-定时任务调度 二、Linuⅸ实操篇-Linuⅸ磁盘分区、挂载 三、Linux实操篇-网络配置 一、Linuⅸ实操篇-定时任务调度 1 crond任务调度---crontab进行定时任务的设置1.1 概述任务调度:是指系统在某个时间执行的特定的命令或程序。任务调度分类…

【python】OpenCV—Image Colorization

文章目录 1、CIELAB 色彩空间2、作色问题定义3、Caffe 模型4、代码实现——Image5、代码实现——Video6、参考 1、CIELAB 色彩空间 Lab颜色空间,也称为Lab色彩空间或CIELAB色彩空间,是一种基于人类视觉感知特性的颜色模型。它是在1931年国际照明委员会&…

渗透SQL注入

首先打开php: Less-1: 打开浏览器输入网址,进入靶场: 输入?id1查询: 使用order by查询数据表的列数: http://127.0.0.1/sqllab/less-1/?id1 order by 4 -- ​ http://127.0.0.1/sqllab/less-1/?id1 order by 3 -- 由此可得表…

基于paddleocr实现验证码识别——训练数据

一、项目介绍 验证码(CAPTCHA)用于区分用户是人类还是计算机程序(如机器人)。这是为了防止各种形式的自动化攻击和滥用。以下是需要验证码识别的几个主要原因: 1. 防止恶意破解密码 攻击者可能会使用自动化程序进行…

数据结构----------贪心算法

什么是贪心算法? 贪心算法(Greedy Algorithm)是一种在问题求解过程中,每一步都采取当前状态下最优(即最有利)的选择,从而希望导致最终的全局最优解的算法策略。 贪心算法的核心思想是做选择时&…

【深度学习】DeepSpeed,ZeRO 数据并行的三个阶段是什么?

文章目录 ZeRO实验实验设置DeepSpeed ZeRO Stage-2 实验性能比较进一步优化DeepSpeed ZeRO Stage-3 和 CPU 卸载结论ZeRO ZeRO(Zero Redundancy Optimizer)是一种用于分布式训练的大规模深度学习模型的优化技术。它通过分片模型状态(参数、梯度和优化器状态)来消除数据并行…

Flink异步IO 调用算法总是超时

记录一次使用Flink 异步调用IO 总是超时的bug 注&#xff1a;博主使用的版本就是&#xff1a;<flink.version>1.16.1</flink.version> 起因&#xff1a; 因公司业务需要&#xff0c;使用Flink对数据进行流式处理&#xff0c;具体处理流程就是&#xff0c;从kafka…

PageRank算法与TextRank算法

PageRank PageRank 是一种用于计算网页重要性的算法&#xff0c;其核心思想源自随机浏览模型。这个模型假设一个网络中的用户通过随机点击链接在网页之间跳转&#xff0c;并根据网页的链接结构计算每个网页的重要性。 假设三个网页按以下方式连接&#xff0c;计算每个网页的PR值…

【零基础实战】基于物联网的人工淡水湖养殖系统设计

文章目录 一、前言1.1 项目介绍1.1.1 开发背景1.1.2 项目实现的功能1.1.3 项目硬件模块组成1.1.4 ESP8266工作模式配置 1.2 系统设计方案1.2.1 关键技术与创新点1.2.2 功能需求分析1.2.3 现有技术与市场分析1.2.4 硬件架构设计1.2.5 软件架构设计1.2.6 上位机开发思路 1.3 系统…

Robot Operating System——深度解析单线程执行器(SingleThreadedExecutor)执行逻辑

大纲 创建SingleThreadedExecutor新增Nodeadd_nodetrigger_entity_recollectcollect_entities 自旋等待get_next_executablewait_for_workget_next_ready_executableTimerSubscriptionServiceClientWaitableAnyExecutable execute_any_executable 参考资料 在ROS2中&#xff0c…

ARM知识点二

一、指令 指令的生成过程 指令执行过程示例 if (a 0) {x 0; } else {x x 3; } //翻译为 cmp r0,#0 MOVEQ R1,#0 ADDGT R1,R1,#3指令获取&#xff1a;从Flash中读取 CMP R0, #0&#xff0c;控制器开始执行。 指令解码&#xff1a;解码器解析 CMP 指令&#xff0c;ALU比较R…

DAMA学习笔记(十)-数据仓库与商务智能

1.引言 数据仓库&#xff08;Data Warehouse&#xff0c;DW&#xff09;的概念始于20世纪80年代。该技术赋能组织将不同来源的数据整合到公共的数据模型中去&#xff0c;整合后的数据能为业务运营提供洞察&#xff0c;为企业决策支持和创造组织价值开辟新的可能性。与商务智能&…

浅谈线程组插件之jp@gc - Ultimate Thread Group

浅谈线程组插件之jpgc - Ultimate Thread Group jpgc - Ultimate Thread Group是JMeter的一个强大且灵活的扩展插件&#xff0c;由JMeter Plugins Project提供。它为性能测试提供了超越JMeter原生线程组的更精细的控制能力&#xff0c;允许用户根据复杂的场景设计自定义负载模…

【TFT电容屏】

TFT电容屏基础知识补课 前言一、入门知识1.1 引脚介绍1.1.1 显示部分片选指令选择写指令读操作复位并行数据接口 1.1.2 背光电源背光电源 1.1.3 触摸IIC接口外部中断接口复位NC 1.2 驱动介绍1.3 FSMC介绍 总结 前言 跟着阳桃电子的学习⇨逐个细讲触摸屏接口定义–STM32单片机…

科普文:JUC系列之ForkJoinPool源码解读ForkJoinWorkerThread

科普文&#xff1a;JUC系列之ForkJoinPool基本使用及原理解读-CSDN博客 科普文&#xff1a;JUC系列之ForkJoinPool源码解读概叙-CSDN博客 科普文&#xff1a;JUC系列之ForkJoinPool源码解读WorkQueue-CSDN博客 科普文&#xff1a;JUC系列之ForkJoinPool源码解读ForkJoinTask…

复现sql注入漏洞

Less-1 字符型注入 页面如下&#xff1a; 我们先输入“?id1”看看结果&#xff1a; 页面显示错误信息中显示提交到sql中的“1”在通过sql语句构造后形成“1" LIMIT 0, 1”&#xff0c;其中多了一个“”&#xff0c;那么&#xff0c;我们的任务就是——逃脱出单引号的控制…

petalinux安装成功后登录Linux出现密码账号不正确

安装完Linux系统后发现登陆开发板上的Linux系统登陆一直错误&#xff0c;但你输入的账号和密码确确实实是“root”&#xff0c;但仍然一直在重复登陆。 这个时候就会怀疑自己是不是把密码改了&#xff0c;导致错误&#xff0c;然后又重新创建petalinux工程。 其实这个时候不需…

2024年第二季度HDD出货量和容量分析

概述 根据Trendfocus, Inc.发布的《SDAS: HDD Information Service CQ2 24 Quarterly Update – Executive Summary》报告&#xff0c;2024年第二季度硬盘驱动器(HDD)出货量和容量均出现了显著增长。总体来看&#xff0c;HDD出货量较上一季度增长2%&#xff0c;达到3028万块&a…