Python批量梯度下降法的举例

news2024/10/5 14:08:11

梯度下降法

梯度下降法是一种常用的优化算法,用于求解目标函数的最小值。其基本思想是,通过不断地朝着函数梯度下降的方向更新参数,直到找到函数的最小值。

具体来说,假设我们有一个可导的目标函数 f ( x ) f(x) f(x),我们要求它的最小值。首先,我们随机初始化一个参数向量 x 0 x_0 x0,然后计算该点处的梯度 g ( x 0 ) = ∇ f ( x 0 ) g(x_0) = \nabla f(x_0) g(x0)=f(x0)。接着,我们沿着梯度的负方向更新参数,即 x 1 = x 0 − η g ( x 0 ) x_{1} = x_{0} - \eta g(x_0) x1=x0ηg(x0),其中 η \eta η 是学习率,它控制了每一步更新的幅度。然后,我们继续计算 x 1 x_1 x1 处的梯度,重复上述更新过程,直到找到目标函数的最小值。

梯度下降法有两种常用的变体:批量梯度下降法和随机梯度下降法。批量梯度下降法在每次更新参数时都要计算全部样本的梯度,因此它的计算开销比较大,但是更新方向比较稳定,收敛速度比较慢。随机梯度下降法在每次更新参数时只考虑一个样本的梯度,因此它的计算开销比较小,但是更新方向比较不稳定,收敛速度比较快。

梯度下降法的数学公式推导如下:

这里我们以批量梯度下降法为例:

import numpy as np
import matplotlib.pyplot as plt

# 生成样本数据
np.random.seed(0) 
X = np.random.rand(50, 1)
# 生成目标值
y = 2 * X + np.random.randn(50, 1) * 0.1

# 定义损失函数
def loss_function(X, y, w):
    m = len(y)
    J = 1 / (2 * m) * np.sum((X.dot(w) - y) ** 2)
    return J

# 定义梯度下降函数
def gradient_descent(X, y, w, alpha, num_iters):
    m = len(y)
    J_history = np.zeros((num_iters, 1))
    for i in range(num_iters):
        w = w - alpha / m * X.T.dot(X.dot(w) - y)
        J_history[i] = loss_function(X, y, w)
        print("Iteration {}, w = {}, loss = {}".format(i, w.ravel(), J_history[i, 0]))
    return w, J_history

# 初始化参数
w = np.zeros((2, 1))
alpha = 0.1
num_iters = 10000000

# 添加一列偏置项
X_b = np.c_[np.ones((len(X), 1)), X]

# 运行梯度下降算法
w, J_history = gradient_descent(X_b, y, w, alpha, num_iters)

# 输出最终的参数值和损失函数值
print("Final parameters: w = {}, loss = {}".format(w.ravel(), J_history[-1, 0]))

# 绘制样本数据散点图
plt.scatter(X, y, alpha=0.5)

# 生成拟合直线的点坐标
x_line = np.array([[0], [1]])
y_line = x_line * w[1, 0] + w[0, 0]

# 绘制拟合直线
plt.plot(x_line, y_line, color='r')

# 显示图像
plt.show()

在Jupyter里面运行之后我们发现输出如下:

Output exceeds the size limit. Open the full output data in a text editorIteration 0, w = [0.10586793 0.07155122], loss = 0.5557491892500293
Iteration 1, w = [0.19729987 0.13480597], loss = 0.44015806935473656
Iteration 2, w = [0.27618572 0.19084247], loss = 0.3525877150650007
Iteration 3, w = [0.34416842 0.24059807], loss = 0.28619553470776754
Iteration 4, w = [0.40267618 0.28488764], loss = 0.23581044813963176
Iteration 5, w = [0.45295053 0.32441962], loss = 0.19752456266928253
Iteration 6, w = [0.49607077 0.35980988], loss = 0.16838459622747565
Iteration 7, w = [0.53297511 0.39159387], loss = 0.1461586810761424
Iteration 8, w = [0.56447915 0.42023707], loss = 0.12916013375789542
Iteration 9, w = [0.59129188 0.44614419], loss = 0.11611427531530936
Iteration 10, w = [0.61402962 0.46966706], loss = 0.10605778526555067
Iteration 11, w = [0.63322814 0.49111157], loss = 0.09826264183760337
Iteration 12, w = [0.64935317 0.51074369], loss = 0.09217864242878422
Iteration 13, w = [0.66280956 0.52879465], loss = 0.08738996542064487
Iteration 14, w = [0.67394923 0.54546548], loss = 0.08358234326756732
Iteration 15, w = [0.6830781  0.56093099], loss = 0.0805182546884386
Iteration 16, w = [0.69046209 0.57534318], loss = 0.07801817701869833
Iteration 17, w = [0.69633236 0.5888342 ], loss = 0.07594641831943069
Iteration 18, w = [0.70088983 0.60151897], loss = 0.07420041047978741
Iteration 19, w = [0.70430916 0.61349743], loss = 0.07270261784568362
Iteration 20, w = [0.70674215 0.62485649], loss = 0.07139442244222281
Iteration 21, w = [0.70832077 0.63567171], loss = 0.07023150293863561
Iteration 22, w = [0.70915971 0.64600884], loss = 0.06918034245759412
Iteration 23, w = [0.70935865 0.65592505], loss = 0.0682155894697291
Iteration 24, w = [0.70900424 0.66547007], loss = 0.06731806337787473
...
Iteration 999997, w = [-7.21008413e-04  1.96927329e+00], loss = 0.004277637843402933
Iteration 999998, w = [-7.21008413e-04  1.96927329e+00], loss = 0.004277637843402933
Iteration 999999, w = [-7.21008413e-04  1.96927329e+00], loss = 0.004277637843402933
Final parameters: w = [-7.21008413e-04  1.96927329e+00], loss = 0.004277637843402933

可以看到我们这里是迭代了100万次,它的loss已经下降到了千分之四的水平
得益于较低的loss率,我们可以看到线性回归的图像表现比较良好。

在这里插入图片描述
给定训练集 ( x ( 1 ) , y ( 1 ) ) , ( x ( 2 ) , y ( 2 ) ) , ⋯   , ( x ( m ) , y ( m ) ) {(x^{(1)}, y^{(1)}), (x^{(2)}, y^{(2)}), \cdots, (x^{(m)}, y^{(m)})} (x(1),y(1)),(x(2),y(2)),,(x(m),y(m)),其中 x ( i ) ∈ R n + 1 x^{(i)} \in \mathbb{R}^{n+1} x(i)Rn+1 y ( i ) ∈ R y^{(i)} \in \mathbb{R} y(i)R i = 1 , 2 , ⋯   , m i = 1, 2, \cdots, m i=1,2,,m,假设 y ( i ) y^{(i)} y(i) x ( i ) x^{(i)} x(i) 满足如下关系:

在这里插入图片描述

其中 w ∈ R n + 1 w \in \mathbb{R}^{n+1} wRn+1 是待求解的参数, ϵ ( i ) \epsilon^{(i)} ϵ(i) 是噪声项。我们的目标是找到一个 w w w 使得训练集上的损失函数最小:

在这里插入图片描述
其中 h w ( x ) = w T x h_w(x) = w^Tx hw(x)=wTx 是预测函数, m m m 是训练集的大小。
使用梯度批量下降法求解 w w w,更新公式为:
在这里插入图片描述
其中 α \alpha α 是学习率, m m m 是批量大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/447006.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目五:使用路由器构建园区网

使用路由器构建园区网 1、新建拓扑2、配置交换机与主机3、配置路由交换机并进行通信4、通信测试5、配置路由器并进行通信测试1、配置路由器R-12、配置路由器R-2、R-33、通信测试 1、新建拓扑 依次添加四台主机,两台交换机,型号为S3700。两台路由交换机&…

体制内干部职务职级及领导干部排序对应关系大全

请点击↑关注、收藏,本博客免费为你获取精彩知识分享!有惊喜哟!! 一、公务员级别对应关系 (一)综合管理公务员职务与职级 1、职务分为10级,包括:正国职、副国职、正部职、副部职、正…

【WSN定位】基于加权双曲线的Dvhop定位算法【Matlab代码#16】

文章目录 1. 原始Dvhop定位2. 基于双曲线的Dvhop定位3. 对原始模型加权4. 部分代码5. 结果展示6. 资源获取7. 参考文献 1. 原始Dvhop定位 可参考Dvhop定位算法 2. 基于双曲线的Dvhop定位 双曲线定位算法是一种通过将待定位节点定位在以锚节点为焦点、两锚节点之间距离为焦距…

字符集与字符编码(ASCII、GBK、UNICODE)

1 常见编码 1.1 单字节编码:ASCII ASCII使用1个字节(8个bit)来记录一组常用字符,见下表: 例如其中字母a的二进制位:1100 001 97,那么a在计算机中就可以用1100001来保存。 注意上表中其实只…

【02-Java Web先导课】-Tomcat服务器的下载与安装

文章目录 前言一、Tomcat服务器(apache-tomcat-8.5.28)的 下载1、下载地址 二、Tomcat服务器的安装1、Tomcat目录结构2、Tomcat的启动与停止4、Tomcat启动成功后的测试 免责声明: 前言 Tomcat主要实现了Java EE中的Servlet、JSP规范&#xf…

【Python爬虫项目实战三】Ddddocr识别Ocr过开放猫验证码(接Authorization认证更新)

目录 🍇前言🍍验证码识别的几个方法🥥百度AI开放平台🥥Ddddocr🦑分析验证码位数🦑获取验证码接口🦑算法识别匹配🦑请求登陆接口 🍋总结: 🍇前言 …

Doris(13):数据模型

在 Doris 中,数据以表(Table)的形式进行逻辑上的描述。一张表包括行(Row)和列(Column)。Row即用户的一行数据。Column 用于描述一行数据中不同的字段。 Column可以分为两大类:Key&a…

Java双亲委派和类加载器

Java双亲委派和类加载器 Java类生命周期主要内容类加载器的分类Bootstrap ClassLoader非Bootstrap ClassLoaderExtension ClassLoaderApplication ClassLoaderUser ClassLoader 类加载的命名空间问题提出双亲委派机制问题解答 破坏双亲委派破坏双亲委派-第一次破坏双亲委派-第二…

【MySQL】GROUP BY分组子句与联合查询的使用详解

目录 前篇都在这里喔~ MySQL的增删改查 MySQL数据库约束和聚合函数的使用 1.GROUP BY子句 练习表如下: 1.查询不包含董事长的平均工资 2.按照角色分组计算平均工资 3.过滤掉平均工资大于一万的角色 4.♥过滤数据♥ 2.联合查询 以下列表作为依据 1.内连接 …

(十二)rk3568 NPU 中部署自己训练的模型,(1)使用yolov5训练自己的数据集-环境搭建部分

rk3568中带有0.8T算力的NPU,可以完成一些轻量级的图像识别任务。 本文向零基础人员介绍从windows中搭建训练环境,模型训练、模型转换到rknn模型部署到电路板上全部过程。 rk3568npu支持caffe、darknet、onnx、pytorch、tensorflow等多种框架。 本人使用…

springboot+vue企业人事人力资源管理系统java公司员工出差考勤办公OA系统

“简易云”是这个系统的名字 (6)系统管理:主要下拉分为角色管理、菜单管理; 角色管理:此页面可对角色进行增删改查操作,可修改不同角色的权限; 菜单管理:此页面可配置系统可展示的菜…

linux学习记录 和文件系统相关的命令

记录过程,会有错误,硬链接与软链接哪里可能没有说清楚 文件,目录操作命令 pwd 获取当前处于哪个目录当中,返回的是绝对路径 [rootlocalhost home]# pwd /homecd cd 相对/绝对路径 切换目录的,change directory .代表当前目录 …代表上一级…

【C++学习】类和对象--对象特性(1)

构造函数和析构函数 对象的初始化和清理是两个非常重要的安全问题 一个对象或变量没有初始状态,对其使用后果是未知的 使用完一个对象或变量,没有及时清理,也会造成一定的安全问题 C利用构造函数和析构函数解决上述问题,这两个函数…

排序 Comparable 和 Comparator 区别所在

在 Java 中,Comparable 和Comparator 都是用来元素排序的,但是本质不用。我们从几点开始分析。 1.字面含义 Comparable 中文翻译是”比较“,以 able 结尾 说明它具有某种能力。 Comparator 中文翻译是”比较器“,以 or 结尾 表明…

【C++ 二十】STL:遍历、查找、排序、拷贝和替换、算术生成、集合算法

STL:遍历、查找、排序、拷贝和替换、算术生成、集合算法 文章目录 STL:遍历、查找、排序、拷贝和替换、算术生成、集合算法前言1 常用遍历算法1.1 for_each1.2 transform 2 常用查找算法2.1 find2.2 find_if2.3 adjacent_find2.4 binary_search2.5 count…

室内人员定位系统源码,采用java语言+UWB定位技术开发

运用UWB定位技术开发的人员定位系统源码 文末获取联系 本套系统运用UWB定位技术开发的高精度人员定位系统,通过独特的射频处理,配合先进的位置算法,可以有效计算复杂环境下的人员与物品的活动信息。 系统提供位置实时显示、历史轨迹回放、人…

循序渐进,学会用pyecharts绘制瀑布图

循序渐进,学会用pyecharts绘制瀑布图 瀑布图简介 瀑布图(Waterfall Plot)是由麦肯锡顾问公司所独创的图表类型,因为形似瀑布流水而称之为瀑布图。 瀑布图采用绝对值与相对值结合的方式,适用于表达多个特定数值之间的数量变化关系。当用户想…

本地Nacos设置脚本命令启动

一、起因: 每次启动都要找到位置写一遍命令费劲。 1、可设置开机启动 2、可设置脚本自动 二、配置脚本: 1、这是我nacos的位置 用bat命令启动一个cmd命令行,然后在里面执行两天命令。 ①命令一:打开指定路径 ②命令二&#xf…

Java图书借阅管理系统详细设计和实现

基于JavaSpringHtml的图书借阅管理系统详细设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取源…

cmd连接本地mysql数据库和远程服务器mysql数据库

1.在cmd窗口里连接本地的mysql数据库 打开运行窗口,输入cmd,确定 windowsr 或在左下角windows图标处鼠标右键,点击运行按钮打开运行窗口 格式: mysql -u用户名 -p密码 mysql -uroot -p123456 成功进入mysql 2. 在cmd窗口里连接远…