深入探讨梯度下降:优化机器学习的关键步骤(三)

news2024/11/27 2:26:49

文章目录

  • 🍀引言
  • 🍀随机、批量梯度下降的差异
  • 🍀随机梯度下降的实现
  • 🍀随机梯度下降的调试

🍀引言

随机梯度下降是一种优化方法,主要作用是提高迭代速度,避免陷入庞大计算量的泥沼。在每次更新时,随机梯度下降只使用一个样本中的一个例子来近似所有的样本,来调整参数,虽然不是全局最优解,但很多时候是可接受的。

前两篇主要介绍了一下批量梯度下降,本节前部分主要介绍一下随机梯度下降


🍀随机、批量梯度下降的差异

随机梯度下降和批量梯度下降都是常用的优化方法,它们在处理大规模数据集时都有自己的优点和缺点。以下是它们的不同点:

  • 相同点:
    两种方法都用于优化目标函数,通过迭代地更新参数来最小化目标函数。在每一步迭代中,它们都会根据当前参数的梯度来更新参数。

  • 不同点:
    (1)样本的使用方式:在随机梯度下降中,每次迭代只使用**一个样本**来计算梯度;而在批量梯度下降中,每次迭代会使用整个数据集来计算梯度。因此,随机梯度下降在处理大规模数据集时更高效,因为它不需要加载整个数据集到内存中。

    (2)收敛速度:由于随机梯度下降每次只使用一个样本来计算梯度,因此它的收敛速度通常比批量梯度下降更快。但是,随机梯度下降的收敛可能更加波动,因为每次迭代的样本可能不同。

    (3)准确度:批量梯度下降的准确度通常比随机梯度下降更高。因为批量梯度下降会使用整个数据集来计算梯度,因此它的更新更精确。但是,在处理大规模数据集时,批量梯度下降可能会遇到内存不足的问题。

这里可以通过下列图来进行简单的说明
请添加图片描述

上面这种图是批量梯度下降的主要公式,前两篇文章已经介绍了
请添加图片描述
上面的这张图指的就是随机梯度下降的主要公式了,我们可以看到求个符号消失了

🍀随机梯度下降的实现

导入必要的库

import numpy as np

选取100000个数据作为测试数据

m = 100000
x = np.random.random(size=m)
y = x*3+4+np.random.normal(size=m)  # 后面的添加的噪音

注意:后面加了一个噪音目的是使得原有的数据添加一些随机性,省的太假了~
之后我们需要编写两个函数,前一个函数主要是用来计算样本的梯度,后一个函数主要包括计算学习率以及循环判断

def sgd(X_b,y,initial_theta,n_iters,epsilon=1e-8):
    def learning_rate(i_iter):
        t0=5
        t1 = 50
        return t0/(i_iter+t1)
    theta = initial_theta
    i_iter = 1
    while i_iter<=n_iters:
        index=np.random.randint(0,len(X_b))
        x_i = X_b[index]
        y_i = y[index]
        gradient = dj_sgd(theta,x_i,y_i)
        theta = theta-gradient*learning_rate(i_iter)
        i_iter+=1
    return theta

注意:在学习率的计算采用模拟退火思想,目的是为了控制参数的变化来影响行为,从而达到更好的优化效果。
请添加图片描述
之后我们需要使用numpy库中的hstack函数在x左侧添加一列

X_b = np.hstack([np.ones((len(x),1)),x])  # 左测增加一列

在添加前,我们需要将x转成矩阵

x = x.reshape(-1,1)

运行结果如下
在这里插入图片描述
之后我们需要设置initial_theta初始值

initial_theta = np.zeros(X_b.shape[1])

前提的准备做完就可以验证了

%%time 
sgd(X_b,y,initial_theta,n_iters=m//4)

运行结果如下
在这里插入图片描述
返回的值,分别近似截距和系数


我们可以将代码再优化一下

def sgd(X_b, y, initial_theta, n_iters, epsilon=1e-8):
    def learning_rate(i_iter):
        t0 = 5
        t1 = 50
        return t0 / (i_iter + t1)

    theta = initial_theta  # 初始化模型参数
    m = len(X_b)  # 样本数量

    for cur_iter in range(n_iters):  # 迭代n_iters次,每轮迭代看一遍整个样本
        random_indexs = np.random.permutation(m)  # 随机打乱样本的顺序,用于随机梯度下降
        X_random = X_b[random_indexs]  # 打乱后的特征数据
        y_random = y[random_indexs]  # 打乱后的标签数据

        for i in range(m):  # 遍历每个样本
            # 使用学习率learning_rate(cur_iter*m+i)来更新模型参数theta,通过梯度dj_sgd计算
            theta = theta - learning_rate(cur_iter * m + i) * dj_sgd(theta, X_random[i], y_random[i])

    return theta  # 返回优化后的模型参数

这个函数使用了随机梯度下降算法来更新模型参数,通过不断地随机选择一个样本进行参数更新,逐渐优化模型以适应训练数据。学习率随着迭代次数变化,初始较大然后逐渐减小,以有利于收敛到最优解。


🍀随机梯度下降的调试

首先还是做前期的准备

import numpy as np
X = np.random.random(size=(1000,10))
X_b = np.hstack([np.ones((len(X),1)),X])
true_theta = np.arange(1,12,dtype='float') # 这里代表有11个特征值(10个系数,1个截距)
y = X_b.dot(true_theta) + np.random.normal(size=len(X))

之后我们分别才有两种方法进行调试
首先是dj_math

这个函数用于计算线性回归中的成本函数(通常是均方误差)相对于参数 theta 的梯度,采用了矢量化的方法。这是数学公式:

在这里插入图片描述

  • X_b 是包含偏置项的特征矩阵(通常是原始特征矩阵的一列加上全部为 1 的列)。
  • y 是目标向量。
  • theta 是待更新的参数向量。
  • m 是训练样本的数量。
def dj_math(theta,X_b,y):
    return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)

其次是dj_debug

这个函数使用数值逼近方法来计算成本函数相对于参数的梯度。它通过轻微地扰动每个参数 theta[i] 并测量成本函数 j 的变化来估计梯度。这是数学公式:

在这里插入图片描述

  • theta 是参数向量。
  • X_b 是包含偏置项的特征矩阵。
  • y 是目标向量。
  • i 是被扰动的参数的索引。
  • epsilon 是用于扰动的小值。
def dj_debug(theta,X_b,y):
    res=np.empty(len(theta))
    epsilon = 0.01
    for i in range(len(theta)):
        theta1 = theta.copy()
        theta2 = theta.copy()
        theta1[i] +=epsilon
        theta2[i] -=epsilon
        res[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)
    return res

这种数值逼近通常用于调试和验证梯度计算的正确性,特别是在梯度下降等基于梯度的优化算法中,有助于优化参数 theta 的训练过程

完整代码如下

def j(theta,X_b,y):
    try:
        return np.sum((X_b.dot(theta)-y)**2)/len(X_b)
    except:
        return float('inf')

def dj_math(theta,X_b,y):
    return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)

def dj_debug(theta,X_b,y):
    res=np.empty(len(theta))
    epsilon = 0.01
    for i in range(len(theta)):
        theta1 = theta.copy()
        theta2 = theta.copy()
        theta1[i] +=epsilon
        theta2[i] -=epsilon
        res[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)
    return res


def gradient_descent(dj,X_b,y,eta,initial_theta,n_iters=1e4,epsilon=1e-8):
    theta = initial_theta
    i_iter = 1
    while i_iter<n_iters:
        last_theta = theta
        theta =theta- eta*dj(theta,X_b,y)
        if abs(j(theta,X_b,y)-j(last_theta,X_b,y))<epsilon:
            break
        i_iter+=1
    return theta

可以分别进行测试一下,显然前者更快一点
在这里插入图片描述

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/992960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【图神经网络 01】

图的基本构成&#xff1a; V&#xff1a;Vertex (or node) attributes E&#xff1a;Edge (or link) attributes and directions U&#xff1a;Global (or master node) attributes 图的邻接矩阵&#xff1a;文本数据也可以表示图的形式&#xff0c;邻接矩阵表示的连接关系。 以…

计算机竞赛 基于深度学的图像修复 图像补全

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学的图像修复 图像补全 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐&#xff01; &#x1f9ff; 更多资料, 项目分享&#xff1a; https://gitee.com/dancheng-se…

牛客网项目-第一章-笔记

牛客网项目-第一章 环境配置 java maven idea Spring Intializr 搜索jar包的网站&#xff1a;https://mvnrepository.com/ https://start.spring.io/ 缺少的aop包&#xff0c;手动在pom.xml中加入依赖 <dependency><groupId>org.springframework.boot</gro…

OpenRoads Designer导入文本格式水平路线、路线纵断面

ORD可以用以文本文件进行定义水平几何路线及纵断面几何路线直接导入来完成几何路线的定义&#xff1a; 水平路线 平面几何路线示例 “平面几何路线.txt”文件内容&#xff1a; BP-1,54376.169,1816.914 BP,54376.101,1817.912 JD01,54358.369,2081.452,0 JD02,54810.789,477…

linux 基础命令 cd /xxx 和 cd xxx 的区别

cd 命令&#xff1a;用于改变当前工作目录的命令&#xff0c;作用&#xff1a;切换当前目录至其它目录 用 cd 命令去 home目录&#xff1a; cd home/ 用cd 命令 去tony 目录下 cd ../ 返回上级目录 cd ../ tony / 返回上级目录进入和hom 同级的tony 目录 这里要讲 linux …

树(一)树和二叉树的基本概念

文章目录 一、树1、什么是树2、树的相关概念3、树的表示 二、二叉树1、二叉树的概念2、二叉树的几种情况3、特殊二叉树4、二叉树的性质5、二叉树的存储结构 一、树 1、什么是树 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个…

C++学习——vector类的使用

目录 vector类的介绍&#xff1a; vector类的构造函数: operator operator [ ] begin & end size & resize capacity & reserve push_back & pop_back insert & erase vector类的介绍&#xff1a; vector是C标准模板库中的部分内容&#xff0c;中文偶尔…

【Python】OpenCV立体相机配准与三角化代码实现

下面的介绍了使用python和OpenCV对两个相机进行标定、配准,同时实现人体关键点三角化的过程 import cv2 as cv import glob import numpy as np import matplotlib.pyplot as pltdef calibrate_camera(images_folder):images_names = glob.glob(images_folder

css画一条渐变的虚线

效果展示 原理&#xff1a;给元素设置一个渐变的背景色&#xff0c;画一条白色的虚线盖住背景&#xff0c;就达到了渐变虚线的效果 代码&#xff1a; <div class"pending-line"></div>.pending-line{width: 101px;border-top: 2px dashed #fff; // do…

C++算法 —— 动态规划(3)多状态

文章目录 1、动规思路简介2、按摩师3、打家劫舍Ⅱ4、删除并获得点数5、粉刷房子6、买卖股票的最佳时机含冷冻期7、买卖股票的最佳时机含手续费8、买卖股票的最佳时机Ⅲ9、买卖股票的最佳时间Ⅳ 每一种算法都最好看完第一篇再去找要看的博客&#xff0c;因为这样会帮你梳理好思路…

正式支持 NVIDIA A100,吞吐量提高 10 倍的Milvus Cloud2.3 使用指南

Milvus 2.3 正式支持 NVIDIA A100! 作为为数不多的支持 GPU 的向量数据库产品,Milvus 2.3 在吞吐量和低延迟方面都带来了显著的变化,尤其是与此前的 CPU 版本相比,不仅吞吐量提高了 10 倍,还能将延迟控制在极低的水准。 不过,正如我前面提到的,鲜有向量数据库支持 GPU,…

必须收藏 | 如何完全卸载ArcGIS

好多小伙伴在卸载ArcGIS过程都遇到了卸载不彻底无法重新安装新版本&#xff0c;卸载残留的注册表找不到等一系列问题&#xff0c;今天小编为大家整理了几个如何完全卸载ArcGIS的方法&#xff0c;希望能够帮到大家&#xff01; #1快捷版 1、开始>控制面板>添加删除程序&…

MR源码解析和join案例

MR源码解析 new Job(): 读取本地文件, xml配置job.start(): 启动线程job的run():线程方法 runTasks(): 传入对应的接口&#xff0c;启动map或者reduceMapTask类的run(): 设置map阶段的参数&#xff0c;初始化任务&#xff0c;创建上下文对象 创建读取器LineRecordReader判断是…

【计算机网络】HTTPS

文章目录 1. HTTPS的概念2. 加密常见的加密方式对称加密非对称加密 3. HTTPS的工作过程的探究方案1 —— 只使用对称加密方案2 —— 只使用 非对称加密方案3 —— 双方都是用非对称加密方案4 —— 非对称加密对称加密中间人攻击引入证书CA认证理解数据签名 方案5 —— 非对称加…

【Redis】1、NoSQL之Redis的配置及优化

关系数据库与非关系数据库 关系型数据库 关系型数据库是一个结构化的数据库&#xff0c;创建在关系模型&#xff08;二维表格模型&#xff09;基础上&#xff0c;一般面向于记录。 SQL 语句&#xff08;标准数据查询语言&#xff09;就是一种基于关系型数据库的语言&a…

WebGL 绘制矩形

上一节绘制了圆点&#xff0c;调用的绘制方法如下&#xff1a;gl.drawArrays(gl.POINTS, 0, 1); 第一个参数明显是个枚举类型&#xff0c;肯定还有其他值&#xff0c;如下所示&#xff1a; POINTS 可视的点LINES 单独线段LINE_STRIP 线条LINE_LOOP 闭合线条TRIANGLES 单独三…

【Redis7】--1.概述、安装和配置

文章目录 1.Redis概述1.1Redis是什么1.2Redis与MySQL的关系1.3Redis功能1.4Redis优势 2.Redis的安装和配置 1.Redis概述 1.1Redis是什么 Redis全称 远程字典服务器&#xff08;Remote Dictionary Server&#xff09;&#xff0c;它是完全开源的&#xff0c;使用ANSIC语言编写…

算法-26. 删除有序数组中的重复项-⭐

给你一个 升序排列 的数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使每个元素 只出现一次 &#xff0c;返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数量为 k &#xff0c;你需要做…

【数据分析】Python:处理缺失值的常见方法

在数据分析和机器学习中&#xff0c;缺失值是一种常见的现象。在实际数据集中&#xff0c;某些变量的某些条目可能没有可用的值。处理缺失值是一个重要的数据预处理步骤。在本文中&#xff0c;我们将介绍如何在 Pandas 中处理缺失值。 我们将探讨以下内容&#xff1a; 什么是缺…

php将数组中的最后一个元素放到第一个

array_unshift($firstStepResult, array_pop($firstStepResult)); 转换之后