pytorch(二)梯度下降算法

news2024/12/24 2:42:56

文章目录

    • 优化问题
    • 梯度下降
    • 随机梯度下降

在线性模型训练的时候,一开始并不知道w的最优值是什么,可以使用一个随机值来作为w的初始值,使用一定的算法来对w进行更新

优化问题

寻找使得目标函数最优的权重组合的问题就是优化问题

梯度下降

通俗的讲,梯度下降就是使得梯度往下降的方向,也就是负方向走。一般来说,梯度往正方向走,表示梯度大于0,,表示函数是往递增方向走,而这里需要的是找最低点,最低点一定是在往下走,所以这里的梯度要取负号。梯度下降更新权重的公式如下(注意是减),α表示学习率:
在这里插入图片描述
梯度下降算法属于贪心算法的一种,它的权重的更新,每一次都是朝着梯度下降最快的方向进行更新,当梯度为0的时候,算法收敛,权重不再更新。梯度下降可能得到的是一个局部最优解(非凸函数)。

在深度学习中,尽管梯度下降算法会陷入局部最优,但是在深度学习中梯度下降算法依旧广泛应用:在之前大家认为,深度学习的目标函数会出现很多的局部最优解,但是实际上,其损失函数并没有很多的局部最优解。但是深度学习的损失函数会存在很多的鞍点(也就某一点上梯度为0,从一个切面上看是最小值,从另一个切面看是最大值的点,如下图 ),导致权重无法继续迭代,可以使用动量法来解决鞍点问题。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现:

  • 要求:模拟梯度下降算法,计算在x_data、y_data数据集下,y=w*x模型找到合适的w的值。

  • 和第二课不同的是,第一课的w是我们认为设定的,通过一个for循环使得w迭代,这一次需要的是通过模型找到适合的w

import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

w=1.0
# 求预测值
def forward(x):
    return x*w

# 损失函数
def cost(xs,ys):
    costs=0
    # 用zip打包成元祖,并返回元祖组成的列表
    for x,y in zip(xs,ys):
        y_pred=forward(x)
        costs+=(y_pred-y)**2       
    return costs/len(xs)

# 计算梯度
def gradient(xs,ys):
    grad=0
    for x,y in zip(xs,ys):
        grad+=2*x*(x*w-y)
    return grad/len(xs)

cost_list=[]
epoch_list=[]
print('predict before training',forward(4))

for epoch in range(200):
    cost_val=cost(x_data,y_data)
    grad_val=gradient(x_data,y_data)
    w-=0.01*grad_val
    
    epoch_list.append(epoch)
    cost_list.append(cost_val)
    
    print('epoch:',epoch,'w=',w,'loss=',cost_val)

print('predict after training:',forward(4))

# 画图
plt.plot(epoch_list,cost_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

部分结果截图
在这里插入图片描述

随机梯度下降

在上面的梯度下降中,求损失的时候用的是全部数据的平均损失作为更新的依据。而随机梯度下降是在全部的数据中随机选择一个作为更新的依据。使用随机梯度下降可以有效的避开鞍点

在这里插入图片描述

import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

w=1.0
# 求预测值
def forward(x):
    return x*w

# 损失函数
def cost(x,y):
    y_pred=forward(x)
    return (y_pred-y)**2

# 计算梯度
def gradient(x,y):
    return 2*x*(x*w-y)

cost_list=[]
epoch_list=[]
print('predict before training',forward(4))

for epoch in range(200):
    for x,y in zip(x_data,y_data):
        grad_val=gradient(x,y)
        print('\tgrad:',x,y,w)
        w-=0.01*grad_val
        print('\tgrad:',x,y,w,'\n')
        cost_val=cost(x,y)
    
    epoch_list.append(epoch)
    cost_list.append(cost_val)
    print('epoch:',epoch,'w=',w,'loss=',cost_val)

print('predict after training:',forward(4))

# 画图
plt.plot(epoch_list,cost_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

结果截图:
在这里插入图片描述

在第一个梯度下降中,样本x与样本x+1之间没有时序关系,我们计算的是他们的总的损失,这些运行时可以并行运行的。但是在第二个随机梯度下降中,我们是先计算了x再计算的x+1,数据之间存在先后的关系,有依赖关系,不能用并行计算。所以梯度下降可以有效提高运算的效率,而随机梯度下降可以获得一个优异的w。把以上两种方法折中,就产生了小批量随机梯度下降,取了一种性能与时间复杂度之间的折中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1407610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nginx问题分析

问题再现 分析问题: 就是通过http://182.44.16.68:8077/web-ui/static/js/chunk-libs.82635094.js 地址访问,找不到对应的js文件 首先确认文件在服务器的位置 发现这个目录下确实有这个js文件,那问题就在于http://182.44.16.68:8077/web-ui…

225.用队列实现栈(附带源码)

目录 一、思路 二、源码 一、思路 所以,创建两个队列 入栈,那个不空入那个 出栈,移动不空的队列的前n-1个到空队列,出队列第n个 很简单 总的来说,就是 下面直接手撕代码: 二、源码 typedef int QDa…

Unity_使用Image和脚本生成虚线段

生成如图样式的虚线段 原理:使用Image做一条线段,这个方法的原理就是给固定的片元长度,对Image进行分割,把片元添加到一个列表中,然后循环对列表中的偶数位进行隐藏,也可以调整线段的宽度 缺陷&#xff1…

Pandas.Series.idxmin() 最小值索引 详解 含代码 含测试数据集 随Pandas版本持续更新

关于Pandas版本: 本文基于 pandas2.2.0 编写。 关于本文内容更新: 随着pandas的stable版本更迭,本文持续更新,不断完善补充。 传送门: Pandas API参考目录 传送门: Pandas 版本更新及新特性 传送门&…

Java Web(二)--HTML

基本介绍 官网文档地址: HTML 教程 HTML(HyperText Mark-up Language)即超文本标签语言;HTML 文本是由 HTML 标签组成的文本,可以包括文字、图形、动画、声音、表格、链接等;HTML 的结构包括头部(Head&…

如何通过内网穿透+代理共享网络

去年写了一篇博客:如何通过代理共享网络,在这篇文章探索了怎么在同一个局域网内共享代理服务。不过,它的实用性也比较缺乏,要求必须处于同一个局域网之下,大多数时候,我们可能很难有这样的环境。所以&#…

【GitHub项目推荐--12 年历史的 PDF 工具开源了】【转载】

最近在整理 PDF 的时候,有一些需求普通的 PDF 编辑器没办法满足,比如 PDF 批量合并、编辑等。 于是,我就去 GitHub 上看一看有没有现成的轮子,发现了这个 PDF 神器「PDF 补丁丁」,让人惊讶的是这个 PDF 神器有 12 年的…

基于SpringBoot Vue美食网站系统

大家好✌!我是Dwzun。很高兴你能来阅读我,我会陆续更新Java后端、前端、数据库、项目案例等相关知识点总结,还为大家分享优质的实战项目,本人在Java项目开发领域有多年的经验,陆续会更新更多优质的Java实战项目&#x…

mysql生成最近24小时整点/最近30天/最近12个月时间临时表

文章目录 生成最近24小时整点生成最近30天生成最近12个月 生成最近24小时整点 SELECT-- 每向下推1行, i比上次减去1b.*, i.*,DATE_FORMAT( DATE_SUB( NOW(), INTERVAL ( -( i : i - 1 ) ) HOUR ), %Y-%m-%d %H:00 ) AS time FROM-- 目的是生成12行数据( SELECTa FROM( SELECT…

可直接将视频转文字的工具,速到快到离谱!

如何将视频转换成文字,推荐大家使用视频提取文案小助手,三秒一键搞定,真的快到离谱​! 不少草根博主在做短视频的时候,就有很多人给大家支招让大家先模仿后超越的模式,激起一众爱好短视频的草根博主成为短…

Scrapy爬虫在新闻数据提取中的应用

Scrapy是一个强大的爬虫框架,广泛用于从网站上提取结构化数据。下面这段代码是Scrapy爬虫的一个例子,用于从新闻网站上提取和分组新闻数据。 使用场景 在新闻分析和内容聚合的场景中,收集和组织新闻数据是常见需求。例如,如果我…

【小黑嵌入式系统第十六课】PSoC 5LP第三个实验——μC/OS-III 综合实验

上一课: 【小黑嵌入式系统第十五课】μC/OS-III程序设计基础(四)——消息队列(工作方式&数据通信&生产者消费者模型)、动态内存管理、定时器管理 前些天发现了一个巨牛的人工智能学习网站,通俗易懂…

纯注解开发bean

注解开发定义bean:Controller:用于表现层bean定义;Service:用于业务层bean定义;Repository:用于数据层bean定义。 我们先来完成数据层和业务逻辑层的注解 数据层: package org.example.dao.impl;import or…

智慧博物馆信息化系统建设(3)

博物馆智能电子导览系统 IPAD智能化定制服务 系统采用的IPAD。使用者通过智能IPAD终端上的三维立体导图,可以在参观的同时,随时读取展馆平面地图以及展品相关信息,然后选择相关服务。简单操作便可获得文字、图片、声音以及视频资料展现给使用者。 游客通过该智能IPAD终端…

构建中国人自己的私人GPT—与文档对话

先看效果 他可以从上传的文件中提取内容作为答案。上传文件摄取速度 摄取速度取决于您正在摄取的文档数量以及每个文档的大小。为了加快摄取速度,您可以在配置中更改摄取模式。 存在以下摄取模式: simple:历史行为,一次按顺序摄…

03 SpringBoot实战 -微头条之首页门户模块(跳转某页面自动展示所有信息+根据hid查询文章全文并用乐观锁修改阅读量)

1.1 自动展示所有信息 需求描述: 进入新闻首页portal/findAllType, 自动返回所有栏目名称和id 接口描述 url地址:portal/findAllTypes 请求方式:get 请求参数:无 响应数据: 成功 {"code":"200","mes…

RubbleDB: CPU-Efficient Replication with NVMe-oF——论文泛读

ATC 2023 Paper 论文阅读笔记整理 问题 由于需要执行昂贵的后台压缩操作,CPU 往往是持久键值存储的性能瓶颈。在日志结构合并树(LSM树),标准的基于磁盘的键值存储设计[2,4,8,22,41],压缩可以在生产工作负载中消耗高达…

基于FPGA的OFDM基带发射机的设计与实现

文章目录 前言一、OFDM描述二、本系统的实现参照 1.IEEE 802.11a协议主要参数2.不同调制方式与速率 3. IFFT映射关系4. IEEE 802.11a物理层规范5. PPDU帧格式三、设计与实现 1.扰码2.卷积编码与删余3.数据交织4.符号调制5.导频插入6.IFFT变换 7.循环前缀&加窗8.训练序列生成…

快速上手的AI工具-文心一言绘本创作

前言 大家好晚上好,现在AI技术的发展,它已经渗透到我们生活的各个层面。对于普通人来说,理解并有效利用AI技术不仅能增强个人竞争力,还能在日常生活中带来便利。无论是提高工作效率,还是优化日常任务,AI工具…

RKE快速搭建离线k8s集群并用rancher管理界面

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 本文记录使用RKE快速搭建一套k8s集群过程,使用的rancher老版本2.5.7(当前最新版为2.7)。适用…