Pytorch梯度下降算法(Gradient Descent)

news2024/11/18 17:44:28

intro

其实对于我们将要学的梯度最小函数,目的就是先得到loss损失最小的值,然后根据这个最小的值去得到w。

初始点在initial guess这个位置,我们希望找到最小的权重点global cost minimum,我们到底是让这个点左移寻找还是右移寻找呢?

此时我们就需要使用到梯度定义。 

在加上一个\Deltax后,如果这个导数值变为负的,说明我接下来函数图像呈现下降的趋势,那根据我们上述所说的寻找一个阶梯最小函数,就是要使函数往小的方向进行。所以我们希望函数图像下降的话,我们就取导数为负的方向。

因此我们可以每下降一部分可以更新一下权重。用来保留。公式中的a为学习率,学习率乘上这个导数表示我们每次大概在函数上走多远,通常来说学习率要小一点,但是也不能太小。

以上我们的梯度下降的算法,其实就是我们的贪心算法,它可能找不到全局最优解,但是可以找到局部最优解,大家可以思考一下为什么会出现这样的情况。大家可以看看下面的这个图,结合上面的介绍。

因为到局部最优点后,没有办法找到这个导数为负数的这样的情况了。

还存在一种点,鞍点。这个也可能无法到达全局最优点。鞍点就是出现了一个导数为0的线。

因为公式中 w=w-ag  g为导数,此时导数为0了,w一直不变。

左边的式子是上一节中预测值y^减去真实值y(此时我们预测的y使用x*w)。此时求出的就是w权重值,对w进行求导,最终结合我们此次的内容,得到最后跟新出的w局部最优权重。

代码实现

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
w=1.0

def forward(x):
    return x*w

def cost(xs,ys):
    cost=0
    for x,y in zip(xs,ys):
        y_pred=forward(x)
        cost+=(y_pred-y)**2
    return cost/len(xs)

def gradient(xs,ys):
    grad=0
    for x,y in zip(xs,ys):
        grad+=2*x*(x*w-y)
    return grad/len(xs)

print('Predict (before training)',4,forward(4))
for epoch in range(100):
    cost_val=cost(x_data,y_data)
    grad_val=gradient(x_data,y_data)
    w-=0.1*grad_val
    print('Epoch:',epoch,'w=',w,'loss',cost_val)
print('Predict (after train)',4,forward(4))

 随机梯度下降

在深度学习中,使用梯度下降还是比较少的,通常我们使用的是随机梯度下降(Stochsstic gradient descent)。

我们可以看出,我们梯度下降使用的是使用整个损失的平均损失作为梯度下降的依据,但是随机梯度下降变成了单个样本的损失函数来进行更新。使用这个随机梯度时,由于每一个点都会存在噪声,那即使我们陷入了鞍点,噪声也会推动我们向前运动。

在更新的过程中会跨出鞍点,往后面进行运动。

代码修改

#原先的cost 现在变成loss 现在不用求均值了
def loss(x,y):
    y_pred=forward(x)
    return (y_pred-y)*2

#原先的gradient也不需要求均值了
def gradient(x,y):
    return 2*x*(x*w-y)

for epoch in range(100):
    for x,y in zip (x_data,y_data):
        grad=gradient(x,y)
        w=w-0.1*grad_val #此时的w也不用去累积计算了
        print('\tgrad:',x,y,grad)
        l=loss(x,y)
    print('progress:',epoch,'w=',w,'loss:',l)

大家可以看上述代码和梯度下降算法的差距,几乎都是将在梯度下降算法中的每一个求均值的部分都修改成求一次。

但是出现一个问题:

对于一个梯度下降算法,其实我们在使用模型计算时不管是对x1还是(x+1)求解f(x)时,两者是没有依赖关系的,这些运算可以并行。但是使用随机梯度下降时,每次都会更新,两者存在依赖关系。所以说梯度下降算法的时间复杂度是优于随机梯度下降算法的。

因此我们会对其进行折中考虑,Batch(批量的随机梯度下降)。就是将随机梯度下降分的很散的数据,给他一个批量处理(若干个一组)。在深度学习默认使用的随机梯度下降就是sdd这个算法,就是采用batch这个方法。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1687441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux第三十九章

🐶博主主页:ᰔᩚ. 一怀明月ꦿ ❤️‍🔥专栏系列:线性代数,C初学者入门训练,题解C,C的使用文章,「初学」C,linux 🔥座右铭:“不要等到什么都没有了…

【全开源】JAVA同城搬家系统源码小程序APP源码

JAVA同城搬家系统源码 特色功能: 强大的数据处理能力:JAVA提供了丰富的数据结构和算法,以及强大的并发处理能力,使得系统能够快速地处理大量的货物信息、司机信息、订单信息等,满足大规模物流的需求。智能路径规划&a…

【Redis】String的介绍与应用详解

大家好,我是白晨,一个不是很能熬夜,但是也想日更的人。如果喜欢这篇文章,点个赞👍,关注一下👀白晨吧!你的支持就是我最大的动力!💪💪&#x1f4aa…

设置 sticky 不生效?会不会是你还是没懂 sticky?

官方描述 基本上可以看懂的就会知道。sticky 是相对于存在滚动条的内容的,啥意思? 就是不论你被谁包着,你只会往上找有 overflow 属性的盒子进行定位,包括:overflow:hidden; overflow:scroll; overflow:auto; overflo…

一键批量提取TXT文档前N行,高效处理海量文本数据,省时省力新方案!

大量的文本信息充斥着我们的工作与生活。无论是研究资料、项目文档还是市场报告,TXT文本文档都是我们获取和整理信息的重要来源。然而,面对成百上千个TXT文档,如何快速提取所需的关键信息,提高工作效率,成为了许多人头…

EI稳定检索--人文社科类会议(ICBAR 2024)

【ACM独立出版】第四届大数据、人工智能与风险管理国际学术会议 (ICBAR 2024) 2024 4th International Conference on Big Data, Artificial Intelligence and Risk Management 【高录用•快检索,ACM独立出版-稳定快速EI检索 | 往届均已完成EI, Scopus检索】 【见…

运行vue2项目基本过程

目录 步骤1 步骤2 步骤3 补充: 解决方法: node-scss安装失败解决办法 步骤1 安装npm 步骤2 切换淘宝镜像 #最新地址 淘宝 NPM 镜像站喊你切换新域名啦! npm config set registry https://registry.npmmirror.com 步骤3 安装vue-cli npm install…

分布式中traceId链接服务间的日志

使用技术: 网关:SpringCloudGateway RPC调用:Feign 一:在网关入口处设置header:key-traceId,value-UUID import com.kw.framework.common.croe.constant.CommonConstant; import com.kw.framework.gateway…

机器学习高斯贝叶斯算法实战:判断肿瘤是良性还是恶性

概述 我们使用威斯康星乳腺肿瘤数据集,来构建一个机器学习模型,用来判断患者的肿瘤是良性还是恶性。 数据分析 威斯康星乳腺肿瘤数据集,包括569个病例的数据样本,每个样本具有30个特征值。 样本分为两类:恶性Malig…

SHA1获取

这里写目录标题 JDK获取uniapp开发Dcould获取 JDK获取 一、下载jdk 链接: http://www.oracle.com/ 二、安装直接下一步下一步 三、配置环境变量 先新增变量JAVA_HOME变量值为C:\devUtils\jdk (jdk安装路径位置)再配置Path(%JAVA_HOME%\bin) 四、创建SHA1安全证书 win r输入cmd…

常见应用流量特征分析

目录 1.sqlmap 1.常规GET请求 2.通过--os-shell写入shell 3.post请求 2.蚁剑 编码加密后 3.冰蝎 冰蝎_v4.1 冰蝎3.2.1 4.菜刀 5.哥斯拉 1.sqlmap 1.常规GET请求 使用的是sqli-labs的less7 (1)User-Agent由很明显的sqlmap的标志,展…

如何快速增加外链?

要快速增加外链并不难,相信各位都知道,难的是快速增加外链且没有风险,所以这时候GNB外链的重要性就出现了,这是一种自然的外链,何谓自然的外链,在谷歌的体系当中,自然外链指的就是其他网站资源给…

[Spring Boot]baomidou 多数据源

文章目录 简述本文涉及代码已开源 项目配置pom引入baomidouyml增加dynamic配置启动类增加注解配置结束 业务调用注解DS()TransactionalDSTransactional自定义数据源注解MySQL2 测试调用查询接口单数据源事务测试多数据源事务如果依然使用Transactional会怎样?测试正…

不同类型的区块链钱包有什么特点和适用场景?

区块链钱包是用于存储和管理加密货币的重要工具,市面上有许多不同类型的区块链钱包可供选择。以下是几种主要类型的区块链钱包及其特点和适用场景。 1.软件钱包: 特点:软件钱包是最常见的一种区块链钱包,通常作为软件应用程序提供…

docker不删除容器更改其挂载目录

场景:docker搭建的jenkins通常需要配置很多开发环境,当要更换挂载目录,每次都需要删除容器重新运行,不在挂载目录的环境通常不会保留。 先给一个参考博客docker不删除容器,修改容器挂载或其他_jenkins 修改容器挂载do…

第17讲:C语言内存函数

目录 1.memcpy使用和模拟实现2.memmove使用和模拟实现3.memset函数的使用4.memcmp函数的使用 1.memcpy使用和模拟实现 void * memcpy (void * destination, const void * source, size_t num);• 函数memcpy从source的位置开始向后复制num个字节的数据到destination指向的内存…

分析电脑上处理器的性能报告

这张图片给出了一份详细的第11代Intel(R) Core(TM) i7-1165G7 2.80GHz处理器的性能报告。 CPU型号:11th Gen Intel(R) Core(TM) i7-1165G7(这是一个低功耗的移动处理器,常用于轻薄型笔记本电脑) 基准速度:2.80 GHz&…

C语言-信号

信号 一、信号是什么东西 信号是事件发生时通知进程的一种机制,有时也称之为软件中断。 信号的到来会打断了程序执行的正常流程。 大多数情况下,无法预测信号到达的精确时间。 一个(具有合适权限的)进程能够向另一进程发送信…

python查找内容在文件中的第几行(利用了滑动窗口)

def find_multiline_content(file_path, multiline_content):with open(file_path, r) as file:# 文件内容file_lines file.readlines()# 待检测内容multiline_lines multiline_content.strip().split(\n)# 待检测内容总行数num_multiline_lines len(multiline_lines)matchi…

Postgresql源码(130)ExecInterpExpr转换为IR的流程

相关 《Postgresql源码(127)投影ExecProject的表达式执行分析》 《Postgresql源码(128)深入分析JIT中的函数内联llvm_inline》 《Postgresql源码(129)JIT函数中如何使用PG的类型llvmjit_types》 表达式计算…