【机器学习】线性回归算法:原理、公式推导、损失函数、似然函数、梯度下降

news2025/1/22 12:14:06

1. 概念简述

        线性回归是通过一个或多个自变量与因变量之间进行建模的回归分析,其特点为一个或多个称为回归系数的模型参数的线性组合。如下图所示,样本点为历史数据,回归曲线要能最贴切的模拟样本点的趋势,将误差降到最小


2. 线性回归方程

        线形回归方程,就是有 n 个特征,然后每个特征 Xi 都有相应的系数 Wi ,并且在所有特征值为0的情况下,目标值有一个默认值 W0 ,因此:

线性回归方程为

h(w)=w_{0} + w_{1}*x_{1}+w_{2}*x_{2}+...+w_{n}*x_{n}

整合后的公式为:

h(w)=\sum_{i}^{n}w_{i}*x_{i} = \theta ^{T}*x


3. 损失函数

        损失函数是一个贯穿整个机器学习的一个重要概念,大部分机器学习算法都有误差,我们需要通过显性的公式来描述这个误差,并将这个误差优化到最小值。假设现在真实的值y预测的值h 。

损失函数公式为:

J(\theta )=\frac{1}{2}*\sum_{i}^{n}( y^{(i)} - \theta ^{T}*x^{(i)} )^{2}

也就是所有误差和的平方。损失函数值越小,说明误差越小,这个损失函数也称最小二乘法


4. 损失函数推导过程

4.1 公式转换

首先我们有一个线性回归方程h(\theta)=\theta_{0} + \theta_{1}*x_{1}+\theta_{2}*x_{2}+...+\theta_{n}*x_{n} 

为了方便计算计算,我们将线性回归方程转换成两个矩阵相乘的形式,将原式的 \theta _{0} 后面乘一个 x_{0}

此时的 x0=1,因此将线性回归方程转变成 h(\theta)=\sum_{i}^{n}\theta_{i}*x_{i},其中 \theta _{i} 和 x_{i} 可以写成矩阵

h(\theta)=\theta_{0} + \theta_{1}*x_{1}+...+\theta_{n}*x_{n} = \left [ \theta _{0} \; \theta _{1}\; \theta _{2}\; ... \right ]*\begin{bmatrix} x _{0}\\ x _{1}\\ x _{2}\\ ...\\ \end{bmatrix}=\sum_{i}^{n}\theta_{i}*x_{i} = \theta ^{T}*x

4.2 误差公式

以上求得的只是一个预测的值,而不是真实的值,他们之间肯定会存在误差,因此会有以下公式:

y_{i} = \theta _{i}*x_{i}+\epsilon_{i}

我们需要找出真实值 y_{i} 与预测值 \theta _{i}*x_{i} 之间的最小误差 \epsilon_{i} ,使预测值和真实值的差距最小。将这个公式转换成寻找不同的 \theta _{i} 使误差达到最小。

4.3 转化为 \theta 求解

由于 \epsilon_{i} 既存在正数也存在负数,所以可以简单的把这个数据集,看作是一个服从均值 \theta ,方差\sigma ^{2} 的正态分布。

所以 \epsilon_{i} 出现的概率满足概率密度函数

p(\epsilon _{i} ) = \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(\epsilon _{i})^{2}}{2\sigma ^{2}}

把 \epsilon_{i} =y_{i}- \theta _{i}*x_{i} 代入到以上的高斯分布函数(即正态分布)中,变成以下式子: 

p(\epsilon _{i} ) = \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

到此,我们将对误差 \epsilon _{i} 的求解转换成对 \theta_{i} 的求解了。

在求解这个公式时,我们要得到的是误差 \epsilon _{i} 最小,也就是求概率 p(\epsilon _{i}) 最大的。因为误差 \epsilon _{i} 满足正态分布,因此在正太曲线中央高峰部的概率 p(\epsilon _{i}) 是最大的,此时标准差\sigma为0误差是最小的。

尽管在生活中标准差肯定是不为0的,没关系,我们只需要去找到误差值出现的概率最大的点。现在,问题就变成了怎么去找误差出现概率最大的点,只要找到,那我们就能求出\theta _{i}

4.4 似然函数求 \theta

似然函数的主要作用是,在已经知道变量 x 的情况下,调整 \theta,使概率 y 的值最大。

似然函数理解:

以抛硬币为例,正常情况硬币出现正反面的概率都是0.5,假设你在不确定这枚硬币的材质、重量分布的情况下,需要判断其是否真的是均匀分布。在这里我们假设这枚硬币有 \theta 的概率会正面朝上,有 1-\theta 的概率会反面朝上

为了获得 \theta 的值,将硬币抛10次,H为正面,T为反面,得到一个正反序列 x = HHTTHTHHHH,此次实验满足二项分布,这个序列出现的概率\theta \theta (1-\theta )(1-\theta ) \theta(1-\theta ) \theta \theta \theta \theta= \theta^{7}(1-\theta )^{3},我们根据一次简单的二项分布实验,得到了一个关于 \theta 的函数,这实际上是一个似然函数,根据不同的 \theta 值绘制一条曲线,曲线就是\theta的似然函数,y轴是这一现象出现的概率。

从图中可见,当 \theta 等于 0.7 时,该序列出现的概率是最大的,因此我们确定该硬币正面朝上的概率是0.7。

因此,回到正题,我们要求的是误差出现概率 p(\epsilon _{i}) 的最大值,那就做很多次实验,对误差出现概率累乘,得出似然函数,带入不同的 \theta ,\theta是多少时,出现的概率是最大的,即可确定\theta的值。

综上,我们得出求 \theta 的似然函数为:

L( \theta ) = \prod_{i}^{m} \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

4.5 对数似然

由于上述的累乘的方法不太方便我们去求解 \theta,我们可以转换成对数似然,将以上公式放到对数中,然后就可以转换成一个加法运算。取对数以后会改变结果值,但不会改变结果的大小顺序。我们只关心\theta等于什么的时候,似然函数有最大值,不用管最大值是多少,即,不是求极值而是求极值点。注:此处log的底数为e。

对数似然公式如下:

\log (L( \theta )) =\log \prod_{i}^{m} \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}} = \sum_{i}^{n}\log \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

对以上公式化简得:

\log (L( \theta )) =n*\log \frac{1}{\sigma\sqrt{2\pi }} - \frac{1}{2\sigma ^{2}}\sum_{i}^{n} (y_{i}- \theta _{i}*x_{i})^{2}

4.6 损失函数

我们需要把上面那个式子求得最大值,然后再获取最大值时的 \theta 值。 而上式中 n*\log \frac{1}{\sigma\sqrt{2\pi }} 是一个常数项,所以我们只需要把减号后面那个式子变得最小就可以了,而减号后面那个部分,可以把常数项 \frac{1}{\sigma ^{2}} 去掉,因此我们得到最终的损失函数如下,现在只需要求损失函数的最小值。

J (\theta ) = \frac{1}{2}\sum_{i}^{n} (y_{i}- \theta _{i}*x_{i})^{2}

注:保留 \frac{1}{2} 是为了后期求偏导数。

损失函数越小,说明预测值越接近真实值,这个损失函数也叫最小二乘法。


5. 梯度下降

损失函数中 xiyi 都是给定的值,能调整的只有 \theta,如果随机的调整,数据量很大,会花费很长时间,每次调整都不清楚我调整的是高了还是低了。我们需要根据指定的路径去调节,每次调节一个,范围就减少一点,有目标有计划去调节。梯度下降相当于是去找到一条路径,让我们去调整\theta

梯度下降的通俗理解就是,把对以上损失函数最小值的求解,比喻成梯子,然后不断地下降,直到找到最低的值。

5.1 批量梯度下降(BGD)

批量梯度下降,是在每次求解过程中,把所有数据都进行考察,因此损失函数因该要在原来的损失函数的基础之上加上一个m:数据量,来求平均值

J (\theta ) = \frac{1}{2m}\sum_{i}^{m} (y_{i}- \theta _{i}*x_{i})^{2}

因为现在针对所有的数据做了一次损失函数的求解,比如我现在对100万条数据都做了损失函数的求解,数据量结果太大,除以数据量100万,求损失函数的平均值。

然后,我们需要去求一个点的方向,也就是去求它的斜率。对这个点求导数,就是它的斜率,因此我们只需要求出 J(\theta ) 的导数,就知道它要往哪个方向下降了。它的方向先对所有分支方向求导再找出它们的合方向。

J(\theta ) 的导数为:

\frac{\partial J (\theta)}{\partial \theta _{j}} = -\frac{1}{m}\sum_{i}^{m} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

由于导数的方向是上升的,现在我们需要梯度下降,因此在上式前面加一个负号,就得到了下降方向,而下降是在当前点的基础上下降的。

批量梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha \frac{1}{m}\sum_{i}^{m} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

新点是在原点的基础上往下走一点点,斜率表示梯度下降的方向,\alpha 表示要下降多少。由于不同点的斜率是不一样的,以此循环,找到最低点。

批量梯度下降的特点:每次向下走一点点都需要将所有的点拿来运算,如果数据量大非常耗时间。


5.2 随机梯度下降(SGD)

随机梯度下降是通过每个样本来迭代更新一次。对比批量梯度下降,迭代一次需要用到所有的样本,一次迭代不可能最优,如果迭代10次就需要遍历整个样本10次。SGD每次取一个点来计算下降方向。但是,随机梯度下降的噪音比批量梯度下降要多,使得随机梯度下降并不是每次迭代都向着整体最优化方向

随机梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

每次随机一个点计算,不需要把所有点拿来求平均值,梯度下降路径弯弯曲曲趋势不太好。


5.3 mini-batch 小批量梯度下降(MBGO)

我们从上面两个梯度下降方法中可以看出,他们各自有优缺点。小批量梯度下降法在这两种方法中取得了一个折衷,算法的训练过程比较快,而且也要保证最终参数训练的准确率。

假设现在有10万条数据,MBGO一次性拿几百几千条数据来计算,能保证大体方向上还是下降的。

小批量梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha \frac{1}{n}\sum_{i}^{n} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

\alpha 用来表示学习速率,即每次下降多少。已经求出斜率了,但是往下走多少合适呢,\alpha值需要去调节,太大的话下降方向会偏离整体方向,太小会导致学习效率很慢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1214892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第二证券:产业资本真金白银传递市场信心

本年以来,A股商场继续颤抖,但工业本钱纷繁行为,拿出大笔真金白银掀起增持回购潮。Wind数据闪现,到11月15日记者发稿,本年以来已有逾千家公司发布了股票回购预案,拟回购金额上限估计超1200亿元。同期&#x…

《洛谷深入浅出基础篇》P1551亲戚——集合——并查集P1551亲戚

上链接:P1551 亲戚 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)https://www.luogu.com.cn/problem/P1551 上题干: 题目背景 若某个家族人员过于庞大,要判断两个是否是亲戚,确实还很不容易,现在给出某个亲戚关系图…

LINMP搭建wordpress-数据库不分离

目录 一、nginx部署 1.安装nginx前的系统依赖环境检查 2.下载nginx源代码包 3.解压缩源码包 4.创建普通的nginx用户 5.开始编译安装nginx服务 6.创建一个软连接以供集中管理 7.配置nginx环境变量 二、mysql 1.创建普通mysql用户 2.下载mysql二进制代码包 3.创建mys…

windows的远程桌面服务RDS存在弱加密证书的漏洞处理

背景 漏洞扫描检测windows服务器的远程桌面服务使用了弱加密的ssl证书 思路 按照报告描述,试图使用强加密的新证书更换默认证书 解决 生成证书 通过openssl1.1.1生成(linux自带openssl,windows安装的是openssl1.1.1w)&#x…

基于Vue+SpringBoot的农村物流配送系统 开源项目

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统登录、注册界面2.2 系统功能2.2.1 快递信息管理:2.2.2 位置信息管理:2.2.3 配送人员分配:2.2.4 路线规划:2.2.5 个人中心:2.2.6 退换快递处理:…

智能电网短路故障接地故障模拟柜

智能电网短路故障接地故障模拟柜是一种用于模拟智能电网中短路故障和接地故障的设备,它可以模拟电网中的各种故障情况,帮助电力工程师进行故障诊断和维修。智能电网中的短路故障是指电路中出现了异常的电流路径,导致电流过大,可能…

利用ffmpeg实现rtmp和rtsp推流

环境说明 windows11 : ffmpeg VLC Linux Unbuntu20.04 : SRS MediaMTX 可选:GStreamer win11下载ffmpeg和ffplay ffmpeg官网 添加环境变量:添加ffmpeg/bin所在的路径。 D:\ffmpeg\ffmpeg-master-latest-win64-lgpl-shared\bin win11查看本机电脑的设备…

基于单片机的水位检测系统仿真设计

**单片机设计介绍, 基于单片机的水位检测系统仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的水位检测系统仿真系统是一种用于模拟水位检测系统的工作过程,以验证设计方案的可行性和优…

赢麻了……腾讯1面核心9问,小伙伴过了提42W offer

说在前面 在40岁老架构师尼恩的(50)读者社群中,经常有小伙伴,需要面试腾讯、美团、京东、阿里、 百度、头条等大厂。 下面是一个小伙伴成功拿到通过了腾讯面试,并且最终拿到offer,一毕业就年薪42W&#x…

UniApp中的数据存储与获取指南

目录 介绍 数据存储方案 1. 本地存储 2. 数据库存储 3. 网络存储 实战演练 1. 本地存储实例 2. 数据库存储实例 3. 网络存储实例 注意事项与最佳实践 结语 介绍 在移动应用开发中,数据的存储和获取是至关重要的一部分。UniApp作为一款跨平台应用开发框架…

C++入门(2)—函数重载、引用

目录 一、函数重载 1、参数类型不同 2、参数个数不同 3、参数顺序不同 4、 链接中如何区分函数重载 二、引用 1、规则 2、特征 3、使用场景 做参数 做返回值 4、常引用 5、传值、传引用效率比较 6、引用和指针的区别 接上一小节C入门(1)—命名空间、缺省参数 一…

Nutz框架如何自定义SQL?

Nutz框架基本的简单sql已经封装了,但是一些叫为复杂的sql需要手动去写,那如何实现像Mybatis那样通过配置文件编写呢?如有不明白详见官方文档:自定义 SQL - Nutzhttps://nutzam.com/core/dao/customized_sql.html#ndoc-4 一 新建…

Navicat for mysql 无法连接到虚拟机的linux系统下的mysql

原创/朱季谦 最近在linux Centos7版本的虚拟机上安装了一个MySql数据库,发现本地可以正常ping通虚拟机,但Navicat则无法正常连接到虚拟机里的MySql数据库,经过一番琢磨,发现解决这个问题的方式,很简单,总共…

CTFhub-RCE-过滤cat

查看当前目录:输入:127.0.0.1|ls 127.0.0.1|cat flag_42211411527984.php 无输出内容 使用单引号绕过 127.0.0.1|cat flag_42211411527984.php|base 64 使用双引号绕过 127.0.0.1|c""at flag_42211411527984.php|base64 使用特殊变量绕过 127.0.0.…

2016Outlook显示正在启动无法进入Outlook

2016Outlook显示正在启动无法进入Outlook 故障现象: 因上次非正常关闭,导致Outlook启动时,一直处于启动界面,无法进入主界面正常工作 故障截图: 故障原因: 数据文件异常导致 解决方案: 1、关…

asp.net core mvc 之 依赖注入

一、视图中使用依赖注入 1、core目录下添加 LogHelperService.cs 类 public class LogHelperService{public void Add(){}public string Read(){return "日志读取";}} 2、Startup.cs 文件中 注入依赖注入 3、Views目录中 _ViewImports.cshtml 添加引用 4、视图使用…

HTML5+CSS3小实例:悬停放大图片的旅游画廊

实例:悬停放大图片的旅游画廊 技术栈:HTML+CSS 效果: 源码: 【HTML】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" content=&…

【linux】nmon 工具使用

nmon 介绍 nmon是奈杰尔的性能监视器的缩写&#xff0c;适用于POWER、x86、x86_64、Mainframe和现在的ARM&#xff08;Raspberry Pi&#xff09;上的Linux。同样适用于nmon for AIX的工具&#xff08;与IBM的AIX一起提供&#xff09;。njmon与之类似&#xff0c;但将数据保存为…

波束形成中的主瓣宽度

阵列信号处理相关基础知识及主瓣宽度 导向矢量阵列方向图确知波束形成普通波束形成主瓣宽度确知波束形成主瓣宽度普通波束形成主瓣宽度 在讨论主瓣宽度之前&#xff0c;首先得了解导向矢量、波束形成、阵列方向图的概念&#xff0c;这些是阵列信号处理中最基础的知识。 导向矢量…