最小二乘法的原理及实现

news2025/1/11 14:21:20

1.最小二乘法的原理及实现

笔记来源于《白话机器学习的数学》

1.1 最小二乘法的原理

预测一个变量 x x x与一个变量 y y y的关系
例如:广告费 x x x与点击量 y y y
用直线拟合数据

1.2 最小二乘法的实现

广告费x和点击量y,找到一条直线表达式,输入广告费来预测点击量
训练数据如下:

import numpy as np
import matplotlib.pyplot as plt

# 读入训练数据
train = np.loadtxt('click.csv', delimiter=',', dtype='int', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

数据预处理步骤之一:对训练数据进行标准化 / 归一化,目的使得参数收敛会更快
计算出数据中所有x的均值 μ \mu μ和标准差 σ \sigma σ,每个数值x按照下列式子进行标准化,数据y也进行类似标准化

# 标准化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)
# 展示标准化后的数据
plt.plot(train_z, train_y, 'o')
plt.show()

横轴范围变为了[-2,2]

假设函数 f θ ( x ) f_{\theta}(x) fθ(x)、目标函数 E ( θ ) E(\theta) E(θ)

# 参数初始化
theta0 = np.random.rand()
theta1 = np.random.rand()
# 预测函数
def f(x):
    return theta0 + theta1 * x
# 目标函数
def E(x, y):
    return 0.5 * np.sum((y - f(x)) ** 2)
# 学习率
ETA = 1e-3
# 初始化误差的差值,作为循环结束依据
diff = 1
# 初始化更新次数
count = 0

梯度下降法对目标函数进行优化
参数更新表达式

# 直到误差的差值小于 0.01 为止,重复参数更新
error = E(train_z, train_y)
while diff > 1e-2:
    # 更新结果保存到临时变量
    tmp_theta0 = theta0 - ETA * np.sum((f(train_z) - train_y))
    tmp_theta1 = theta1 - ETA * np.sum((f(train_z) - train_y) * train_z)

    # 更新参数
    theta0 = tmp_theta0
    theta1 = tmp_theta1

    # 计算与上一次误差的差值
    current_error = E(train_z, train_y)
    diff = error - current_error
    error = current_error

    # 输出日志
    count += 1
    log = '第 {} 次 : theta0 = {:.3f}, theta1 = {:.3f}, 差值 = {:.4f}'
    print(log.format(count, theta0, theta1, diff))

# 绘图确认
x = np.linspace(-3, 3, 100)
plt.plot(train_z, train_y, 'o')
plt.plot(x, f(x))
plt.show()

对目标函数的优化方法:
最速下降法(梯度下降法/Gradient Descent) 对所有训练数据都重复进行计算
梯度下降法的问题是训练数据越多,循环次数越多,计算时间越长,选用随机数作为初始值,每次初始值都会变,进而导致陷入局部最优解的问题

在GD算法中,每次的梯度都是从所有样本中累计获取的,这种情况最容易导致梯度方向过于稳定一致,且更新次数过少,容易陷入局部最优–摘自:防止梯度下降陷入局部最优的三种方法


随机梯度下降法(Stochastic GD)随机选择一个训练数据,并使用它来更新参数
k是被随机选中的数据索引
θ j : = θ j − η ( f θ ( x ( k ) ) − y ( k ) ) x j ( k ) \theta_j:=\theta_j-\eta\big(f_{\boldsymbol{\theta}}(\boldsymbol{x}^{(k)})-y^{(k)}\big)x_j^{(k)} θj:=θjη(fθ(x(k))y(k))xj(k)
注意这里没有求和,因为只使用了一个训练数据
为什么随机梯度下降法不容易陷入局部最优?

stochastic GD是GD的另一种极端更新方式,其每次都只使用一个样本进行参数更新,这样更新次数大大增加,更新参数时使用的又是选择数据时的梯度,每次梯度方向不同,也就不容易陷入局部最优。–摘自:防止梯度下降陷入局部最优的三种方法

小批量梯度下降法(Mini-Batch GD)随机选择m个训练数据来更新参数
θ j : = θ j − η ∑ k ∈ K ( f θ ( x ( k ) ) − y ( k ) ) x j ( k ) \theta_j:=\theta_j-\eta\sum_{k\in K}\big(f_{\boldsymbol{\theta}}(\boldsymbol{x}^{(k)})-y^{(k)}\big)x_j^{(k)} θj:=θjηkK(fθ(x(k))y(k))xj(k)
∑ k ∈ K \sum_{k\in K} kK代表将集合K中的所有元素相加

Mini-Batch GD便是两种极端的折中,即每次更新使用一小批样本进行参数更新。Mini-Batch GD是目前最常用的优化算法,严格意义上Mini-Batch GD也叫做stochastic GD,所以很多深度学习框架上都叫做SGD。–摘自:防止梯度下降陷入局部最优的三种方法

动量(Momentum)

动量也是GD中常用的方式之一,SGD的更新方式虽然有效,但每次只依赖于当前批样本的梯度方向,这样的梯度方向依然很可能很随机。动量就是用来减少随机,增加稳定性。其思想是模仿物理学的动量方式,每次更新前加入部分上一次的梯度量,这样整个梯度方向就不容易过于随机。一些常见情况时,如上次梯度过大,导致进入局部最小点时,下一次更新能很容易借助上次的大梯度跳出局部最小点。–摘自:防止梯度下降陷入局部最优的三种方法

无论使用哪种优化方法对目标函数进行优化,我们都必须考虑学习率 η \eta η设置为合适的值很重要,这个问题比较难,可以通过反复尝试来找到合适的值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/688001.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于matlab多运动目标跟踪监测算法实现(附源码)

一、前言 此示例演示如何对来自固定摄像机的视频中的移动对象执行自动检测和基于运动的跟踪。 二、介绍 移动物体检测和基于运动的跟踪是许多计算机视觉应用的重要组成部分,包括活动识别、交通监控和汽车安全。基于运动的对象跟踪问题可以分为两部分: 检…

【KitBash3D Cargo插件】向UE中直接导入免费模型

步骤 1. 进入KitBash3D官网,点击右上角按钮来下载Cargo 2. 下载好后是个压缩包,需要进行解压 3. 解压后运行安装程序 4. 我就安装到默认的安装路径 5. 安装好后打开软件,注册账号(如果点击创建账户按钮没反应就去KitBash3D官网注…

VS依赖注入(DI)构造函数自动生成局部私有变量

前言 依赖注入(DI)在开发中既是常见的也是必需的技术。它帮助我们优化了代码结构,使得应用更加灵活、易于扩展,同时也降低了各个模块之间的耦合度,更容易进行单元测试,提高了编码效率和质量。我们经常会先定义局部变量&#xff0…

OpenCL编程指南-6.2程序对象

创建和构建程序 要创建程序对象,可以传入OpenCL C源代码文本,或者利用程序二进制码来创建。由OpenCL C源代码创建程序对象是开发人员创建程序对象的一般做法。OpenCL C程序的源代码放在一个外部文件中(例如,就像我们的示例代码中…

【网络知识面试】初识协议栈和套接字及连接阶段的三次握手

接上一篇:【网络面试必问】浏览器如何委托协议栈完成消息的收发 1. 协议栈 一直对操作系统系统的内核协议栈理解的模模糊糊,借着这一篇博客做一下简单梳理。 我觉得最直白的理解,内核协议栈就是操作系统中的一个网络控制软件,就是…

【git】git常用指令(项目一般使用流程示例)

文章目录 创建开发环境clone到本地查看分支创建自己的开发分支切换到开发分支 开发完成上传到仓库判断目前本地仓库的状态新内容提交到暂存区新内容更新到本地仓库新内容推到远端仓库dev1.0并入主分支1.切换到主分支2.合并3.推主分支上远端仓库 回退版本主分支更新了&#xff0…

软件产品登记测试为何如此重要?

软件产品登记测试为何如此重要? 软件产品登记测试报告,是对客户的软件产品进行功能性的检测和验证,确保这些功能都得以实现并能正常运行,可作为国家高新、增值税退税、双软评估、首套台软件的检测证明材料。 软件登记测试是“双软…

three.js中聚光灯及其属性介绍

一、聚光灯及其属性介绍 Three.js中的聚光灯(SpotLight)是一种用于在场景中创建聚焦光照的光源类型。它有以下属性: color:聚光灯的颜色。 intensity:聚光灯的强度。 distance:聚光灯的有效距离。 angl…

知识管理工具:在信息时代下的组织智慧管理

随着信息时代的到来,企业面临着前所未有的信息爆炸和快速变化的挑战。如何高效地管理和利用这些信息已经成为了企业生存和发展的关键。在这种背景下,知识管理工具应运而生,为企业提供了优秀的解决方案。 知识管理工具的定义与特点 知识管理的…

DAMA数据治理CDGA/CDGP认证考试备考经验分享

一,关于DAMA中国和CDGA/CDGP考试 国际数据管理协会(DAMA国际)是一个全球性的专业组织,由数据管理和相关的专业人士组成,非营利性机构,厂商中立。协会自1980年成立以来,一直致力于数据管理和数字…

gralylog介绍与安装

介绍 Graylog是一个开源的日志管理和分析平台,用于收集、存储、分析和可视化大量日志数据。它提供了一个集中化的解决方案,可以帮助组织有效地处理分散在各种系统和应用程序中的日志信息。 以下是Graylog的主要特点和功能: 日志收集&#x…

【AI工具】-MockingBird-语音合成语音克隆

简介 MockingBird: 英文翻译:反舌鸟,也可能来自《杀死一只知更鸟》(英语:To Kill a Mockingbird),台译“梅冈城故事”,中国大陆译“杀死一只知更鸟”,直译应为“杀死一…

【Python】python进阶篇之数据库操作

数据库操作 pip3安装mysql依赖 pip3 list|grep mysqlpip3 install mysql-connector-python #指定版本 pip3 install mysql-connector-python版本号 #升降版本 pip3 install --upgrade mysql-connector-python版本号原生SQL操作 操作mysql可以使用pymsql或mysql-connector-py…

基于html+css的图展示138

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

第三章 决策树

文章目录 第三章 决策树3.1基本流程3.2划分选择3.2.1信息增益3.2.2增益率3.2.3基尼指数 3.3剪枝处理3.3.1预剪枝3.3.2后剪枝 3.4连续与缺失值3.4.1连续值处理3.4.2缺失值处理 3.5多变量决策树3.7实验 第三章 决策树 3.1基本流程 决策过程: 基本算法: …

灵雀云获Gartner® 首份《DevOps平台魔力象限报告》“荣誉提及”

随着平台工程理念的崛起,企业使用的独立的DevOps工具链逐渐向更先进、更便捷的DevOps平台演进。Gartner发布了首份DevOps平台魔力象限报告(Gartner Magic Quadrant for DevOps Platforms)。在这个备受关注的报告中,中国云原生厂商…

大势智慧软硬件技术答疑第五期

1.控制点误差表达到多少就可以? 答:水平和高程误差在0.01左右就可以,图示精度是满足的。 2.三维影像有颜色,为什么生成的是二维影像是黑色的? 答:使用dasviewer的工具-输出正射图再试试。 3.最新模方对ps版…

JMeter中常见的四种参数化实现方式是什么?

1 参数化释义 什么是参数化?从字面上去理解的话,就是事先准备好数据(广义上来说,可以是具体的数据值,也可以是数据生成规则),而非在脚本中写死,脚本执行时从准备好的数据中取值。 参…

浅析移动警务App中的技术痛点与挑战

移动警务是指警务机关利用移动通信技术和移动设备,实现警务信息化、智能化和移动化的一种工作模式。通过移动警务,警务人员可以随时随地进行警务工作,提高警务反应速度和效率。 移动警务通常包括以下方面的内容: 移动巡逻&#x…

【ArcGIS Pro二次开发】(43):线闭合

当我们需要将多段线【polyline】转为面【polygon】的时候,必须保证线是闭合的,不然是无法生成面的,如下图: 如果cad线段,可以在属性里将闭合选项设置为是,实现线的闭合: 但如果是在ArcGIS Pro里…