深入浅出机器学习中的梯度下降算法

news2024/12/27 2:28:20

大家好,在机器学习中,梯度下降算法(Gradient Descent)是一个重要的概念。它是一种优化算法,用于最小化目标函数,通常是损失函数。梯度下降可以帮助找到一个模型最优的参数,使得模型的预测更加准确,本文将介绍梯度下降算法的原理、公式以及在Python中实现这一算法。

1. 梯度下降算法的理论基础

在数学中,梯度是一个向量,表示函数在某一点的变化率和方向。在多维空间中,梯度指向函数上升最快的方向。

图片

可以通过梯度来找到函数的最小值或最大值,对于损失函数关注的是最小值。

梯度下降的核心思想是通过不断调整参数,沿着损失函数的梯度方向移动,从而逐步逼近最小值。具体步骤如下:

(1) 初始化参数:随机选择参数的初始值。

(2) 计算梯度:计算损失函数对每个参数的梯度。

(3) 更新参数:根据梯度信息调整参数,更新规则为:

其中:\theta是要优化的参数;\alpha是学习率(step size),决定每次更新的幅度;\triangledown J(\theta )是损失函数关于参数的梯度。

(4) 重复步骤:重复计算梯度和更新参数,直到收敛(即损失函数的变化非常小)。

假设我们有一个简单的线性回归问题,目标是最小化均方误差(MSE)损失函数: 

其中\hat{Y}_{i} = \theta _{0} + \theta _{1}X_{i}是模型的预测值。为了使用梯度下降,我们需要计算损失函数关于参数的梯度: 

通过求导,可以得到梯度表达式,并利用它来更新参数。

2. Python 实现梯度下降算法

接下来将通过一个简单的线性回归示例来实现梯度下降算法,以下是实现代码:

import numpy as np
import matplotlib.pyplot as plt

生成一些随机数据来模拟房屋面积与房价之间的线性关系:

# 生成数据
np.random.seed(0)

# 生成自变量 X(房屋面积),范围从50到200平方米
X = 50 + 150 * np.random.rand(100)  # 生成从50到200的100个点

# 生成因变量 Y(房价),假设房价与房屋面积的关系
Y = 300000 + 2000 * X + np.random.randn(100) * 20000  # 线性关系加上噪声,价格范围在30万到50万之间

# 绘制生成的散点图
plt.scatter(X, Y, color='blue', alpha=0.5)
plt.title('房屋面积与房价的关系')
plt.xlabel('房屋面积 (平方米)')
plt.ylabel('房价 (人民币)')
plt.grid()
plt.show()

图片

实现梯度下降算法的核心部分:

# 将数据标准化,帮助梯度下降更快收敛
X = (X - np.mean(X)) / np.std(X)
Y = (Y - np.mean(Y)) / np.std(Y)

# 梯度下降参数
alpha = 0.01  # 学习率
num_iterations = 1000  # 迭代次数
m = len(Y)  # 样本数量

# 初始化参数
theta_0 = 0  # 截距
theta_1 = 0  # 斜率

# 存储损失值
losses = []

# 梯度下降算法实现
for i in range(num_iterations):
    # 计算预测值
    Y_pred = theta_0 + theta_1 * X
    
    # 计算损失函数 (MSE)
    loss = (1/m) * np.sum((Y - Y_pred) ** 2)
    losses.append(loss)
    
    # 计算梯度
    gradient_0 = -(2/m) * np.sum(Y - Y_pred)  # 截距的梯度
    gradient_1 = -(2/m) * np.sum((Y - Y_pred) * X)  # 斜率的梯度
    
    # 更新参数
    theta_0 -= alpha * gradient_0
    theta_1 -= alpha * gradient_1

print(f'截距 (θ0): {theta_0:.4f}, 斜率 (θ1): {theta_1:.4f}')

截距 (θ0): 0.0000, 斜率 (θ1): 0.9743

通过绘制损失函数随迭代次数变化的曲线,观察梯度下降的收敛过程。

# 绘制损失函数变化曲线
plt.figure()
plt.plot(range(num_iterations), losses, color='blue')
plt.title('损失函数随迭代次数的变化')
plt.xlabel('迭代次数')
plt.ylabel('损失值 (MSE)')
plt.grid()
plt.show()

最后,我们可以将训练好的回归线可视化,以观察模型的效果。​​​​​​​​​​​​​​

# 可视化回归线
plt.figure()
plt.scatter(X, Y, color='blue', alpha=0.5)
plt.plot(X, theta_0 + theta_1 * X, color='red', linewidth=2)
plt.title('梯度下降后的线性回归拟合')
plt.xlabel('房屋面积 (标准化)')
plt.ylabel('房价 (标准化)')
plt.grid()

plt.tight_layout()  # 调整子图间距
plt.show()

图片

梯度下降算法在许多机器学习算法中得到了广泛应用,比如线性回归、逻辑回归、神经网络等,可以用于分类问题,通过优化对数损失函数,也可以用于深度学习,反向传播算法依赖于梯度下降来更新权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PotPlayer 最新版本支持使用 Whisper 自动识别语音生成字幕

PotPlayer 最新版本支持使用 Whisper 自动识别语音生成字幕 设置使用下载地址 设置 使用 下载地址 https://www.videohelp.com/software/PotPlayer

【0x0001】HCI_Set_Event_Mask详解

目录 一、命令概述 二、命令格式 三、命令参数说明 四、返回参数说明 五、命令执行流程 5.1. 主机准备阶段 5.2. 命令发送阶段 5.3. 控制器接收与处理阶段 5.4. 事件过滤与反馈阶段 5.5. 主机处理(主机端) 5.6. 示例代码 六、命令应用场景 …

可解释机器学习 | Python实现LGBM-SHAP可解释机器学习

机器学习 | Python实现GBDT梯度提升树模型设计 目录 机器学习 | Python实现GBDT梯度提升树模型设计基本介绍模型使用参考资料基本介绍 LightGBM(Light Gradient Boosting Machine)是一种基于决策树的梯度提升框架,是一种高效的机器学习模型。SHAP(SHapley Additive exPlan…

【Rust在WASM中实现pdf文件的生成】

Rust在WASM中实现pdf文件的生成 概念和依赖问题描述分步实现最后 概念和依赖 . WASM WebAssembly(简称WASM)是一个虚拟指令集体系架构(virtual ISA),旨在为C/C等语言编写的程序提供一种高效的二进制格式,使…

Java开发每日一课:Java开发岗位要求

找过工作的朋友应该知道,现在招聘Java开发工程师的公司特别多。那么Java开发这个岗位具体是做什么的?又有什么招聘要求呢? 我毕业的时候也面试过很多公司,当时对Java开发这个岗位的要求不甚了解,因为懂Java语法知识就能…

Spring Boot日志总结

文章目录 1.我们的日志2.日志的作用3.使用日志对象打印日志4.日志框架介绍5.深入理解门面模式(外观模式)6.日志格式的说明7.日志级别7.1日志级别分类7.2配置文件添加日志级别 8.日志持久化9.日志文件的拆分9.1官方文档9.2IDEA演示文件分割 10.日志格式的配置11.更简单的日志输入…

2025年Java面试八股文大全

很多人会问Java面试八股文有必要背吗? 我的回答是:很有必要。你可以讨厌这模式,但你一定要去背,因为不背你就进不了大厂。 国内的互联网面试,恐怕是现存的、最接近科举考试的制度。 而且,我国的八股文确…

DPDK用户态协议栈-Tcp Posix API 1

和udp一样&#xff0c;我们需要实现和系统调用一样的接口来实现我们的tcp server。先来看看我们之前写的unix_tcp使用了哪些接口&#xff0c;这边我加上两个系统调用&#xff0c;分别是接收数据和发送数据。 #include <stdio.h> #include <arpa/inet.h> #include …

记一次搞校园网的经历

接教室的校园网&#xff0c;到另一个屋子玩电脑&#xff0c;隔墙想放大一下AP的信号&#xff0c;发现死活不行 这是现状 由于校园网认证的存在&#xff0c;无法用桥接&#xff0c;桥接需要路由器有IP&#xff0c;而这个IP无法用未刷机的路由器来打开校园网页面认证 解决 将一…

RTC 实时时钟实验

利用 ALIENTEK 2.8 寸 TFTLCD 模块来显示日期和时间&#xff0c;实现一个简单的时钟。 STM32F1 RTC 时钟简介 STM32 的实时时钟&#xff08; RTC &#xff09;是一个独立的定时器。 STM32 的 RTC 模块拥有一组连续计数 的计数器&#xff0c;在相应软件配置下&#xf…

接口性能优化宝典:解决性能瓶颈的策略与实践

目录 一、直面索引 &#xff08;一&#xff09;索引优化的常见场景 &#xff08;二&#xff09;如何检查索引的使用情况 &#xff08;三&#xff09;如何避免索引失效 &#xff08;四&#xff09;强制选择索引 二、提升 SQL 执行效率 &#xff08;一&#xff09;避免不必…

2021陇剑杯-内存取证

内存分析&#xff08;问1&#xff09; 网管小王制作了一个虚拟机文件&#xff0c;让您来分析后作答&#xff1a; 虚拟机的密码是_____________。&#xff08;密码中为flag{xxxx}&#xff0c;含有空格&#xff0c;提交时不要去掉&#xff09;。 mimikatz一把梭了 flag{W31C0M3…

Ubuntu 安装 MariaDB

安装 MariaDB具体步骤 1、更新软件包索引&#xff1a; sudo apt update2、安装 MariaDB 服务器&#xff1a; sudo apt install mariadb-server3、启动 MariaDB 服务&#xff08;如果未自动启动&#xff09;&#xff1a; sudo systemctl start mariadb4、设置 MariaDB 开机启…

深度学习Python基础(2)

二 数据处理 一般来说PyTorch中深度学习训练的流程是这样的&#xff1a; 1. 创建Dateset 2. Dataset传递给DataLoader 3. DataLoader迭代产生训练数据提供给模型 对应的一般都会有这三部分代码 # 创建Dateset(可以自定义) dataset face_dataset # Dataset部分自定义过的…

Linux下的三种 IO 复用

目录 一、Select 1、函数 API 2、使用限制 3、使用 Demo 二、Poll 三、epoll 0、 实现原理 1、函数 API 2、简单代码模板 3、LT/ET 使用过程 &#xff08;1&#xff09;LT 水平触发 &#xff08;2&#xff09;ET边沿触发 4、使用 Demo 四、参考链接 一、Select 在…

Windows常用DOS指令(附案例)

文章目录 1.dir 查看当前目录2.cd 进入指定目录3.md 创建指定目录4.cd> 创建指定文件5.rd 删除指定空目录6.del 删除指定文件7.copy 复制文件8.xcopy 批量复制9.ren 改名10.type 在命令行空窗口打开文件11.cls 清空DOS命令窗口12.chkdsk 检查磁盘使用情况13.time 显示和设置…

【Linux】匿名管道通信场景——进程池

&#x1f525; 个人主页&#xff1a;大耳朵土土垚 &#x1f525; 所属专栏&#xff1a;Linux系统编程 这里将会不定期更新有关Linux的内容&#xff0c;欢迎大家点赞&#xff0c;收藏&#xff0c;评论&#x1f973;&#x1f973;&#x1f389;&#x1f389;&#x1f389; 文章目…

C#基础之集合讲解

文章目录 1 集合1.1 数组1.1.1 简介1.1.2 声明使用1.1.2.1 声明 & 初始化1.1.2.2 赋值给数组1.1.2.3 访问数组元素 1.1.3 多维数组1.1.3.1 声明1.1.3.2 初始化二维数组1.1.3.3 访问二维数组元素 1.1.4 交错数组1.1.5 传递数组给函数1.1.6 Array1.1.6.1 简介1.1.6.2 属性1.1…

Azure DevOps Server:使用甘特图Gantt展示需求进度

自从Azure DevOps Server取消与Project Server的集成后&#xff0c;许多用户都在关注如何使用甘特图来展示项目进度。 在Azure DevOps Server开放扩展Extension功能后&#xff0c;许多开发者或专业开发团队做了很多甘特图Gantt相关的开发工作&#xff0c;使用比较多的是(GANTT …

数据湖的概念(包含数据中台、数据湖、数据仓库、数据集市的区别)--了解数据湖,这一篇就够了

文章目录 一、数据湖概念1、企业对数据的困扰2、什么是数据湖3、数据中台、数据湖、数据仓库、数据集市的区别 网上看了好多有关数据湖的帖子&#xff0c;还有数据中台、数据湖、数据仓库、数据集市的区别的帖子&#xff0c;发现帖子写的都很多&#xff0c;而且专业名词很多&am…