Python中随机梯度下降法

news2024/12/23 13:09:51

随机梯度下降法

批量梯度下降使用全部的训练样本来计算梯度,并更新模型参数,因此它的每一次迭代计算量较大,但对于凸优化问题,可以保证每次迭代都朝着全局最优解的方向前进,收敛速度较快,最终收敛到的结果也比较稳定。

随机梯度下降则每次迭代仅使用一个样本来计算梯度,并更新模型参数,因此每次迭代的计算量较小,但收敛速度较慢,最终收敛结果也不够稳定,可能会陷入局部最优解。

在实际应用中,批量梯度下降通常用于训练数据量较小、维度较高的情况,而随机梯度下降通常用于训练数据量较大、维度较低的情况。同时也可以采用一种介于两者之间的小批量梯度下降(Mini-Batch Gradient Descent),即每次迭代使用一定数量的随机样本来计算梯度,并更新模型参数,这种方法在训练大规模数据集时也比较实用。
我们用一个函数举例子:
在这里插入图片描述
这里需要求它的梯度:
在这里插入图片描述

随机梯度下降法的代码实现:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
np.random.seed(42)  # 设置随机种子,保证每次运行结果相同
x = np.random.randn(100, 3)
y = x.dot(np.array([4, 5, 6])) + np.random.randn(100) * 0.1
def loss_function(w, x, y):
    return 0.5 * np.mean((np.dot(x, w) - y) ** 2)

def gradient_function(w, x, y):
    return np.dot(x.T, np.dot(x, w) - y) / len(y)
def SGD(x, y, w_init, alpha, max_iter):
    w = w_init
    for i in range(max_iter):
        rand_idx = np.random.randint(len(y))
        x_i = x[rand_idx, :].reshape(1, -1)
        y_i = y[rand_idx]
        grad_i = gradient_function(w, x_i, y_i)
        w = w - alpha * grad_i
    return w
fig = plt.figure()
ax = Axes3D(fig)
W0 = np.arange(0, 10, 0.1)
W1 = np.arange(0, 10, 0.1)
W0, W1 = np.meshgrid(W0, W1)
W2 = np.array([SGD(x, y, np.array([w0, w1, 0]), 0.01, 1000)[2] for w0, w1 in zip(np.ravel(W0), np.ravel(W1))])
W2 = W2.reshape(W0.shape)
ax.plot_surface(W0, W1, W2, cmap='coolwarm')
ax.set_xlabel('w0')
ax.set_ylabel('w1')
ax.set_zlabel('loss')
plt.show()



在这里插入图片描述
这个图像表示了随机梯度下降算法在二元线性回归问题中对于不同初始权重( x x x y y y)的收敛情况。横轴和纵轴分别表示 x x x y y y的值,而竖轴表示算法收敛的迭代次数。不同颜色的曲面表示不同的收敛路径。可以看到,算法在不同的初始权重下都能收敛到大致相同的最优权重,这也验证了随机梯度下降算法的鲁棒性和适用性。此外,可以看到不同初始权重下的收敛速度和路径都不同,这也说明了随机梯度下降算法的随机性。
可以更换max_iter的值,以下是迭代10000此的结果:
在这里插入图片描述
迭代100000的结果:
在这里插入图片描述

在这段代码中的数学含义代表如下:

w_init:表示模型参数的初始值。这里是一个一维数组,包含三个元素,分别对应截距、x1系数和x2系数。

x:表示输入特征矩阵,每行包含两个特征x1和x2。

y:表示输出标签,是一个一维数组。

alpha:表示学习率,是一个超参数,决定了每一次迭代参数的更新幅度。

max_iter:表示最大迭代次数,也是一个超参数。

rand_idx = np.random.randint(len(y)):在每一次迭代中,随机选择一个样本。

x_i = x[rand_idx, :].reshape(1, -1):根据随机选择的样本索引,选取相应的特征向量,并将其变为一个1行2列的矩阵。

y_i = y[rand_idx]:根据随机选择的样本索引,选取相应的输出标签。

grad_i = gradient_function(w, x_i, y_i):计算该样本的梯度。

w = w - alpha * grad_i:更新模型参数。

return w:返回最终权重。

x = np.insert(x, 0, 1, axis=1):将截距添加到输入特征矩阵x的第一列。

np.dot(x, w):计算模型的预测值。

np.dot(x, w) - y:计算模型预测值与真实标签的误差。

np.dot(x.T, np.dot(x, w) - y):计算误差的梯度。

len(y):计算样本的数量。

return np.dot(x.T, np.dot(x, w) - y) / len(y):返回梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/448668.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

synchronized原理:

vm中每个对象都会有一个监视器Monitor,监视器和对象一起创建、销毁。监视器相当于一个用来监视这些线程进入的特殊房间,其义务是保证(同一时间)只有一个线程可以访问被保护的临界区代码块。每一个锁都对应一个monitor对象&#xf…

如何训练自己的大型语言模型

如何使用 Databricks、Hugging Face 和 MosaicML 训练大型语言模型 (LLM) 介绍 大型语言模型,如 OpenAI 的 GPT-4 或谷歌的 PaLM,已经席卷了人工智能世界。然而,大多数公司目前没有能力训练这些模型,并且完全依赖少数大型科技公司…

LaoCat带你认识容器与镜像之Docker网络

近期比较忙,心思也比较乱,难得今天休息,闲来无事,借机更新一下系列 ~ 系列目录 LaoCat带你认识容器与镜像(一) LaoCat带你认识容器与镜像(二【一章】) LaoCat带你认识容器与镜像&…

软件安全性与隐私保护的最佳实践

在当今数字化时代,随着软件使用的普及和信息技术的发展,软件安全性和隐私保护越来越成为了IT领域关注的热点问题。在此,本文将探讨软件安全性和隐私保护的最佳实践,以帮助大家更好地保护自己的信息安全。 一、软件安全性最佳实践…

数据结构与算法八 优先队列

一 优先队列 普通的队列是一种先进先出的数据结构,元素在队列尾追加,而从队列头删除。在某些情况下,我们可能需要找出队列中的最大值或者最小值,例如使用一个队列保存计算机的任务,一般情况下计算机的任务都是有优先级…

C语言-malloc、free、memset、realloc、strcpy

malloc()开辟指定内存空间 函数原型 void *malloc(size_t size) C 库函数 void *malloc(size_t size) 分配所需的内存空间,并返回一个指向它的指针。 free 释放内存空间 free C 库函数 void free(void *ptr) 释放之前调用 calloc、malloc 或 realloc 所分配的…

紧跟时代潮流,如用ChatGPT速成自媒体达人

每一个选题是否成为爆款和热门,这个就占了80%,为什么?因为我看到你的标题,我可以点进去啊,不管内容如何,至少让人眼前一亮,有点进去的欲望,至少浏览量会很大,这就成功了一…

【LeetCode: 1043. 分隔数组以得到最大和 | 暴力递归=>记忆化搜索=>动态规划 | 线性dp 区间dp】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

使用ETL工具Sqoop,将MySQL数据库db03中的10张表的表结构和数据导入(同步)到大数据平台的Hive中

在MySQL中,创建一个用户,用户名为sqoop03,密码为:123456 启动MySQL:support-files/mysql.server start 进入MySQL:mysql -u root -p 创建用户sqoop03:grant all on *.* to sqoop03% identifi…

5.5 高斯型求积公式简历

学习目标: 我会按照以下步骤学习高斯求积公式简介: 理解积分的概念:学习什么是积分以及积分的几何和物理意义,如面积、质量、电荷等概念。 掌握基本的积分技巧:掌握基本的积分公式和技巧,如换元法、分部积…

重要通知!报表控件FastReport VCL将停止支持旧的 Delphi 版本

FastReport 是功能齐全的报表控件,可以帮助开发者可以快速并高效地为.NET,VCL,COM,ActiveX应用程序添加报表支持,由于其独特的编程原则,现在已经成为了Delphi平台最优秀的报表控件,支持将编程开…

视频批量剪辑:如何给视频添加上下黑边并压缩视频容量。

视频太多了,要如何进行给视频添加上下黑边并压缩视频容量?今天就由小编来教教大家要如何进行操作,感兴趣的小伙伴们可以来看看。 首先,我们要进入视频剪辑高手主页面,并在上方板块栏里选择“批量剪辑视频”板块&#…

PX4无人机调参

文章目录 前言一、滤波参数二、PID参数自动调参手动调参角速率环姿态环 前言 PX4 1.13.2 日志分析软件:flight review https://logs.px4.io/ 一、滤波参数 调参时可以用自稳模式飞行 在调滤波器参数之前,可以先大致调一下PID的参数,角度率…

4-log打印

1.相关文件 2.示例 #include <stdbool.h> #include <stdint.h> #include <stdio.h> #include "nrf.h" #include "nrf_delay.h" #include "app_error.h" #include "nrf_log.h" #include "nrf_log_ctrl.h" …

WPF教程(九)--数据绑定(2)--绑定模式

一、绑定模式 绑定模式以及模式的使用效果。 示例如下是根据ListBox中的选中项&#xff0c;去改变TextBlock的背景色。将 TextBlock 的背景色绑定到在 ListBox 中选择的颜色。在下面的代码中针对TextBlock的 Background 属性使用绑定语法绑定从 ListBox 中选择的值。代码如下。…

typeScript的安装及基础使用示例

4.1.安装typescript npm 包&#xff1a; npm install -g typescript 2.查看安装好的版本检验&#xff1a; tsc -v 3.编译一个typescript 文件&#xff1a;tsc hello.ts 4.运行一个ts文件&#xff1a; 首先安装ts-node &#xff0c;ts-node需要在全局去安装。这里要用 npm…

【LeetCode】剑指 Offer 64. 求1+2+…+n p307 -- Java Version

题目链接&#xff1a;https://leetcode.cn/problems/qiu-12n-lcof/ 1. 题目介绍&#xff08;64. 求12…n&#xff09; 求 12...n &#xff0c;要求不能使用乘除法、for、while、if、else、switch、case等关键字及条件判断语句&#xff08;A?B:C&#xff09;。 【测试用例】&a…

android aidl

本文只是记录个人学习aidl的实现&#xff0c;如需学习请参考下面两篇教程 官方文档介绍Android 接口定义语言 (AIDL) | Android 开发者 | Android Developers 本文参考文档Android进阶——AIDL详解_android aidl_Yawn__的博客-CSDN博客 AIDL定义&#xff1a;Android 接口…

实验五 视图与完整性约束

实验五 视图与完整性约束 目录 实验五 视图与完整性约束选择题sql评测题1、SQl视图&#xff1a;建立视图CJ\_STUDENT题目代码题解 2、SQL视图&#xff1a;建立视图AVG\_CJ题目代码 3、SQL视图&#xff1a;建立视图IS\_STUDENT题目代码题解 4、SQL视图&#xff1a;根据视图CJ\_S…

pcle接口详解用途说明

PCIE (peripheral component interconnect express) 中文名&#xff1a;高速串行计算机扩展总线标准&#xff0c;它原来的名称为“3GIO”&#xff0c;由英特尔在2001年提出。 PCIE 有 12345代 和x1/x4/x8/x16插槽 1、PCIE x1/x4/x8/x16插槽模式&#xff0c;的区别和用处 pcel …