深入探讨梯度下降:优化机器学习的关键步骤(二)

news2025/1/20 3:52:32

文章目录

  • 🍀引言
  • 🍀eta参数的调节
  • 🍀sklearn中的梯度下降

🍀引言

承接上篇,这篇主要有两个重点,一个是eta参数的调解;一个是在sklearn中实现梯度下降

在梯度下降算法中,学习率(通常用符号η表示,也称为步长或学习速率)的选择非常重要,因为它直接影响了算法的性能和收敛速度。学习率控制了每次迭代中模型参数更新的幅度。以下是学习率(η)的重要性:

  • 收敛速度:学习率决定了模型在每次迭代中移动多远。如果学习率过大,模型可能会在参数空间中来回摇摆,导致不稳定的收敛或甚至发散。如果学习率过小,模型将收敛得很慢,需要更多的迭代次数才能达到最优解。因此,选择合适的学习率可以加速收敛速度。

  • 稳定性:过大的学习率可能会导致梯度下降算法不稳定,甚至无法收敛。过小的学习率可以使算法更加稳定,但可能需要更多的迭代次数才能达到最优解。因此,合适的学习率可以在稳定性和收敛速度之间取得平衡。

  • 避免局部最小值:选择不同的学习率可能会导致模型陷入不同的局部最小值。通过尝试不同的学习率,您可以更有可能找到全局最小值,而不是被困在局部最小值中。

  • 调优:学习率通常需要调优。您可以尝试不同的学习率值,并监视损失函数的收敛情况。通常,您可以使用学习率衰减策略,逐渐降低学习率以改善收敛性能。

  • 批量大小:学习率的选择也与批量大小有关。通常,小批量梯度下降(Mini-batch Gradient Descent)使用比大批量梯度下降更大的学习率,因为小批量可以提供更稳定的梯度估计。

总之,学习率是梯度下降算法中的关键超参数之一,它需要仔细选择和调整,以在训练过程中实现最佳性能和收敛性。不同的问题和数据集可能需要不同的学习率,因此在实践中,通常需要进行实验和调优来找到最佳的学习率值。


🍀eta参数的调节

在上代码前我们需要知道,如果eta的值过小会造成什么样的结果

在这里插入图片描述
反之如果过大呢

在这里插入图片描述
可见,eta过大过小都会影响效率,所以一个合适的eta对于寻找最优有着至关重要的作用


在上篇的学习中我们已经初步完成的代码,这篇我们将其封装一下
首先需要定义两个函数,一个用来返回thera的历史列表,一个则将其绘制出来

def gradient_descent(eta,initial_theta,epsilon = 1e-8):
    theta = initial_theta
    theta_history = [initial_theta]
    def dj(theta): 
        return 2*(theta-2.5) #  传入theta,求theta点对应的导数
    def j(theta):
        return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值
    while True:
        gradient = dj(theta)
        last_theta = theta
        theta = theta-gradient*eta 
        theta_history.append(theta)
        if np.abs(j(theta)-j(last_theta))<epsilon:
            break
    return theta_history

def plot_gradient(theta_history):
    plt.plot(plt_x,plt_y)
    plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')
    plt.show()

其实就是上篇代码的整合罢了
之后我们需要进行简单的调参了,这里我们分别采用0.10.010.9,这三个参数进行调节

eta = 0.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.01
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.9
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
这三张图与之前的提示很像吧,可见调参的重要性
如果我们将eta改为1.0呢,那么会发生什么

eta = 1.0
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
那改为1.1呢

eta = 1.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
我们从图可以清楚的看到,当eta为1.1的时候是嗷嗷增大的,这种情况我们需要采用异常处理来限制一下,避免报错,处理的方式是限制循环的最大值,且可以在expect中设置inf(正无穷)

def gradient_descent(eta,initial_theta,n_iters=1e3,epsilon = 1e-8):
    theta = initial_theta
    theta_history = [initial_theta]
    i_iter = 1
    def dj(theta):  
        try:
            return 2*(theta-2.5) #  传入theta,求theta点对应的导数
        except:
            return float('inf')
    def j(theta):
        return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值
    while i_iter<=n_iters:
        gradient = dj(theta)
        last_theta = theta
        theta = theta-gradient*eta 
        theta_history.append(theta)
        if np.abs(j(theta)-j(last_theta))<epsilon:
            break
        i_iter+=1
    return theta_history

def plot_gradient(theta_history):
    plt.plot(plt_x,plt_y)
    plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')
    plt.show()

注意:inf表示正无穷大


🍀sklearn中的梯度下降

这里我们还是以波士顿房价为例子
首先导入需要的库

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor

之后取一部分的数据

boston = load_boston()
X = boston.data
y = boston.target
X = X[y<50]
y = y[y<50]

然后进行数据归一化

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y)
std = StandardScaler()
std.fit(X_train)
X_train_std=std.transform(X_train)
X_test_std=std.transform(X_test)
sgd_reg = SGDRegressor()
sgd_reg.fit(X_train_std,y_train)

最后取得score

sgd_reg.score(X_test_std,y_test)

运行结果如下
在这里插入图片描述


请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/966936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

I2C与I3C的对比

I2C与I3C的对比 电气特性 I2C 1.半双工 2.串行数据线(SDA)和串行时钟线(SCL) 3.数据线漏极开路&#xff0c;即I2C接口接上拉电阻 4.I2C总线运行速度&#xff1a;**标准模式100kbit/s&#xff0c;快速模式400kbit/s&#xff0c;快速模式plus 1Mbit/s&#xff0c;**高速模式…

TopSAP天融信 LINUX客户端 CentOS版安装

TopSAP天融信 LINUX客户端 CentOS版安装 下载客户端安装运行 下载客户端 项目需要用到CentOS环境下的天融信客户端&#xff0c;可以下载LINUX版 下载地址 https://app.topsec.com.cn/ X86_64(或AMD64)架构客户端deb包&#xff1a;V3.5.2.36.2 MD5 :EC032529A8D3A645B7368F28E…

Darshan日志分析

标头 darshan-parser 输出的开头显示了有关作业的总体信息的摘要。还可以使用–perf、–file或–total命令行选项生成其他作业级别摘要信息。 darshan log version&#xff1a;Darshan 日志文件的内部版本号。compression method&#xff1a;压缩方法。exe&#xff1a;生成日志…

中文命名实体识别

本文通过people_daily_ner数据集&#xff0c;介绍两段式训练过程&#xff0c;第一阶段是训练下游任务模型&#xff0c;第二阶段是联合训练下游任务模型和预训练模型&#xff0c;来实现中文命名实体识别任务。 一.任务和数据集介绍 1.命名实体识别任务 NER&#xff08;Named En…

开发总结:webpack

webpack官网webpack | webpack 中文文档 | webpack 中文网 一、什么是webpack webpack 可以看做是模块打包机&#xff0c;它所做的事情是&#xff1a;分析你的项目结构&#xff0c;找到JavaScript 模块以及其它的一些浏览器不能直接运行的拓展语言&#xff08;Scss&#xff0…

(位运算) 剑指 Offer 15. 二进制中1的个数 ——【Leetcode每日一题】

❓ 剑指 Offer 15. 二进制中1的个数 难度&#xff1a;简单 编写一个函数&#xff0c;输入是一个无符号整数&#xff08;以二进制串的形式&#xff09;&#xff0c;返回其二进制表达式中数字位数为 ‘1’ 的个数&#xff08;也被称为 汉明重量).&#xff09;。 提示&#xff…

非科班菜鸡算法学习记录 | 代码随想录算法训练营第53天|| 1143.最长公共子序列 1035.不相交的线 53. 最大子序和 动态规划

1143. 最长公共子序列 知识点&#xff1a;动规 状态&#xff1a;不会 思路&#xff1a; 用dpij表示两个串中到i-1和j-1个字符结束的最长公共子序列长度&#xff08;不用特殊初始化&#xff09; class Solution { public:int longestCommonSubsequence(string text1, string …

操作系统的发展和分类

注意&#xff1a;每个阶段的主要优点都是解决了上个阶段的缺点 1.手工操作阶段 概括&#xff1a;一个用户在一段时间内独占全机&#xff0c;导致资源利用率极低&#xff0c;用户输入指令给机器&#xff0c;然后机器运行响应给用户。 2.批处理阶段 2.1单道批处理系统 优点&…

leetcode622-设计循环队列

本题重点&#xff1a; 1. 选择合适的数据结构 2. 针对选择的数据结构判断“空”和“满” 这两点是不分先后次序的&#xff0c;在思考时应该被综合起来。事实上&#xff0c;无论我们选择链表还是数组&#xff0c;最终都能实现题中描述的“循环队列”的功能&#xff0c;只不过…

数学建模-点评笔记 9月3日

1.摘要&#xff1a;关键方法和结论&#xff08;精炼的语言&#xff09;要说明&#xff0c;方法的合理性和意义也可以说明。 评委先通过摘要筛选&#xff08;第一轮&#xff09; 2.时间序列找异常值除了3西格玛还有针对时间序列更合适寻找的方法 3.模型的优缺点要写的详细一点…

编写一个这样的程序,满足五日均线,十日均线,二十日均线,六十天六日均线调头向上的选股代码

编写一个这样的程序&#xff0c;满足五日均线&#xff0c;十日均线&#xff0c;二十日均线&#xff0c;六十天六日均线调头向上的选股代码 以下是一个用C语言编写的程序&#xff0c;可以读取股票数据并筛选出满足条件的股票。程序使用了一个假设的股票数据文件格式&#xff0c…

将帅要避免五个方面的弱点:蛮干、怕死、好名、冲动、溺爱民众

将帅要避免五个方面的弱点&#xff1a;蛮干、怕死、好名、冲动、溺爱民众 【安志强趣讲《孙子兵法》第28讲】 【原文】 是故屈诸侯者以害&#xff0c;役诸侯者以业&#xff0c;趋诸侯者以利。 【注释】 趋&#xff1a;归附、依附。 【趣讲白话】 所以&#xff0c;用祸患威逼诸侯…

IDM2024Internet Download Manager下载器最新版本

IDM&#xff08;Internet Download Manager&#xff09;下载器主窗口的左侧是下载类别的分类&#xff0c;提供了分类功能来组织和管理文件。如果不需要它&#xff0c;可以删除“分类”窗口&#xff0c;并且在下载文件时不选择任何分类。 每个下载类别都有一个名称&#xff0c;…

ARM编程模型-常用指令集

一、ARM指令集 ARM是RISC架构&#xff0c;所有的指令长度都是32位&#xff0c;并且大多数指令都在一个单周期内执行。主要特点&#xff1a;指令是条件执行的&#xff0c;内存访问使用Load/store架构。 二、Thumb 指令集 Thumb是一个16位的指令集&#xff0c;是ARM指令集的功能…

Go实现LogCollect:海量日志收集系统【下篇——开发LogTransfer】

Go实现LogAgent&#xff1a;海量日志收集系统【下篇】 0 前置文章 Go实现LogAgent&#xff1a;海量日志收集系统【上篇——LogAgent实现】 前面的章节我们已经完成了日志收集&#xff08;LogAgent&#xff09;&#xff0c;接下来我们需要将日志写入到kafka中&#xff0c;然后…

后端SpringBoot+前端Vue前后端分离的项目(一)

前言&#xff1a;后端使用SpringBoot框架&#xff0c;前端使用Vue框架&#xff0c;做一个前后端分离的小项目&#xff0c;需求&#xff1a;实现一个表格&#xff0c;具备新增、删除、修改的功能。 一、数据库表的设计 设计了一个merchandise表&#xff0c;id是编号&#xff0c…

基于Matlab利用IRM和RRTstar实现无人机路径规划(附上源码+数据+说明+报告+PPT)

无人机路径规划是无人机应用领域中的关键问题之一。本文提出了一种基于IRM&#xff08;Informed RRTstar Method&#xff09;和RRTstar&#xff08;Rapidly-exploring Random Tree star&#xff09;算法的无人机路径规划方法&#xff0c;并使用Matlab进行实现。该方法通过结合I…

【LeetCode-中等题】994. 腐烂的橘子

文章目录 题目方法一&#xff1a;bfs层序遍历 题目 该题值推荐用bfs&#xff0c;因为是一层一层的感染&#xff0c;而不是一条线走到底的那种&#xff0c;所以深度优先搜索不适合 方法一&#xff1a;bfs层序遍历 广度优先搜索&#xff0c;就是从起点出发&#xff0c;每次都尝…

无涯教程-JavaScript - VARP函数

VARP函数取代了Excel 2010中的VAR.P函数。 描述 该函数根据整个总体计算方差。 语法 VARP (number1,[number2],...)争论 Argument描述Required/OptionalNumber1The first number argument corresponding to a population.RequiredNumber2...Number arguments 2 to 255 cor…

Go实现LogCollect:海量日志收集系统【上篇——LogAgent实现】

Go实现LogCollect&#xff1a;海量日志收集系统【上篇——LogAgent实现】 下篇&#xff1a;Go实现LogCollect&#xff1a;海量日志收集系统【下篇——开发LogTransfer】 项目架构图&#xff1a; 0 项目背景与方案选择 背景 当公司发展的越来越大&#xff0c;业务越来越复杂…