AI算法15-弹性网络回归算法Elastic Net Regression | ENR

news2024/9/25 23:19:52

弹性网络回归算法简介

在机器学习领域中,弹性网络(Elastic Net)是一种结合了L1范数(套索回归)和L2范数(岭回归)的正则化方法。它综合了两者的优点,既可以实现特征选择,又可以处理多重共线性。弹性网络在实际应用中具有广泛的用途,因此,在这篇文章中我们将探讨弹性网络正则化的公式、应用场景、优势以及如何调节超参数等方面。一般线性Elastic Net模型的目标函数

目标函数的第一行与传统线性回归模型完全相同,即我们希望得到相应的自变量系数β,以此最小化实际因变量y与预测应变量βx之间的误差平方和。 而线性Elastic Net与线性回归的不同之处就在于有无第二行的这个约束,线性Elastic Net希望得到的自变量系数是在由t控制的一个范围内。 这一约束也是Elastic Net模型能进行复杂度调整,LASSO回归能进行变量筛选和复杂度调整的原因。 我们可以通过下面的这张图来解释这个道理:

先看左图,假设一个二维模型对应的系数是β1和β2,然后是最小化误差平方和的点,即用传统线性回归得到的自变量系数。 但我们想让这个系数点必须落在蓝色的正方形内,所以就有了一系列围绕的同心椭圆,其中最先与蓝色正方形接触的点,就是符合约束同时最小化误差平方和的点。 这个点就是同一个问题LASSO回归得到的自变量系数。 因为约束是一个正方形,所以除非相切,正方形与同心椭圆的接触点往往在正方形顶点上。而顶点又落在坐标轴上,这就意味着符合约束的自变量系数有一个值是0。 所以这里传统线性回归得到的是β1和β2都起作用的模型,而LASSO回归得到的是只有β2有作用的模型,这就是LASSO回归能筛选变量的原因。

而正方形的大小就决定了复杂度调整的程度。假设这个正方形极小,近似于一个点,那么LASSO回归得到的就是一个只有常量(intercept)而其他自变量系数都为0的模型,这是模型简化的极端情况。 由此我们可以明白,控制复杂度调整程度的λ值与约束大小t是呈反比的,即λ值越大对参数较多的线性模型的惩罚力度就越大,越容易得到一个简单的模型。

另外,我们之前提到的参数α就决定了这个约束的形状。刚才提到LASSO回归(α=1)的约束是一个正方形,所以更容易让约束后的系数点落在顶点上,从而起到变量筛选或者说降维的目的。 而Ridge回归(α=0)的约束是一个圆,与同心椭圆的相切点会在圆上的任何位置,所以Ridge回归并没有变量筛选的功能。 相应的,当几个自变量高度相关时,LASSO回归会倾向于选出其中的任意一个加入到筛选后的模型中,而Ridge回归则会把这一组自变量都挑选出来。 至于一般的Elastic Net模型(0<α<1),其约束的形状介于正方形与圆形之间,所以其特点就是在任意选出一个自变量或者一组自变量之间权衡。

弹性网络回归算法模型介绍

弹性网络回归算法的代价函数结合了Lasso回归和岭回归的正则化方法,通过两个参数 λ 和 ρ 来控制惩罚项的大小。

同样是求使得代价函数最小时 w 的大小:

可以看到,当 ρ = 0 时,其代价函数就等同于岭回归的代价函数,当 ρ = 1 时,其代价函数就等同于 Lasso 回归的代价函数。与 Lasso 回归一样代价函数中有绝对值存在,不是处处可导的,所以就没办法通过直接求导的方式来直接得到 w 的解析解,不过还是可以用坐标下降法2(coordinate descent)来求解 w。

弹性网络回归算法步骤

坐标下降法

坐标下降法的求解方法与 Lasso 回归所用到的步骤一样,唯一的区别只是代价函数不一样。

具体步骤:

  1. 初始化权重系数 w,例如初始化为零向量。
  2. 遍历所有权重系数,依次将其中一个权重系数当作变量,其他权重系数固定为上一次计算的结果当作常量,求出当前条件下只有一个权重系数变量的情况下的最优解。在第 k 次迭代时,更新权重系数的方法如下:

  3. 步骤2为一次完整迭代,当所有权重系数的变化不大或者到达最大迭代次数时,结束迭代。

弹性网络回归算法代码实现

使用 Python 实现弹性网络回归算法(坐标下降法):

def elasticNet(X, y, lambdas=0.1, rhos=0.5, max_iter=1000, tol=1e-4):
    """
    弹性网络回归,使用坐标下降法(coordinate descent)
    args:
        X - 训练数据集
        y - 目标标签值
        lambdas - 惩罚项系数
        rhos - 混合参数,取值范围[0,1]
        max_iter - 最大迭代次数
        tol - 变化量容忍值
    return:
        w - 权重系数
    """
    # 初始化 w 为零向量
    w = np.zeros(X.shape[1])
    for it in range(max_iter):
        done = True
        # 遍历所有自变量
        for i in range(0, len(w)):
            # 记录上一轮系数
            weight = W[i]
            # 求出当前条件下的最佳系数
            w[i] = down(X, y, w, i, lambdas, rhos)
            # 当其中一个系数变化量未到达其容忍值,继续循环
            if (np.abs(weight - w[i]) > tol):
                done = False
        # 所有系数都变化不大时,结束循环
        if (done):
            break
    return w


def down(X, y, w, index, lambdas=0.1, rhos=0.5):
    """
    cost(w) = (x1 * w1 + x2 * w2 + ... - y)^2 / 2n + ... + λ * ρ * (|w1| + |w2| + ...) + [λ * (1 - ρ) / 2] * (w1^2 + w2^2 + ...)
    假设 w1 是变量,这时其他的值均为常数,带入上式后,其代价函数是关于 w1 的一元二次函数,可以写成下式:
    cost(w1) = (a * w1 + b)^2 / 2n + ... + λρ|w1| + [λ(1 - ρ)/2] * w1^2 + c (a,b,c,λ 均为常数)
    => 展开后
    cost(w1) = [aa / 2n + λ(1 - ρ)/2] * w1^2 + (ab / n) * w1 + λρ|w1| + c (aa,ab,c,λ 均为常数)
    """
    # 展开后的二次项的系数之和
    aa = 0
    # 展开后的一次项的系数之和
    ab = 0
    for i in range(X.shape[0]):
        # 括号内一次项的系数
        a = X[i][index]
        # 括号内常数项的系数
        b = X[i][:].dot(w) - a * w[index] - y[i]
        # 可以很容易的得到展开后的二次项的系数为括号内一次项的系数平方的和
        aa = aa + a * a
        # 可以很容易的得到展开后的一次项的系数为括号内一次项的系数乘以括号内常数项的和
        ab = ab + a * b
    # 由于是一元二次函数,当导数为零是,函数值最小值,只需要关注二次项系数、一次项系数和 λ
    return det(aa, ab, X.shape[0], lambdas, rhos)


def det(aa, ab, n, lambdas=0.1, rhos=0.5):
    """
    通过代价函数的导数求 w,当 w = 0 时,不可导
    det(w) = [aa / n + λ(1 - ρ)] * w + ab / n + λρ = 0 (w > 0)
    => w = - (ab / n + λρ) / [aa / n  + λ(1 - ρ)]

    det(w) = [aa / n + λ(1 - ρ)] * w + ab / n  - λρ = 0 (w < 0)
    => w = - (ab / n - λρ) / [aa / n  + λ(1 - ρ)]

    det(w) = NaN (w = 0)
    => w = 0
    """
    w = -(ab / n + lambdas * rhos) / (aa / n + lambdas * (1 - rhos))
    if w < 0:
        w = -(ab / n - lambdas * rhos) / (aa / n + lambdas * (1 - rhos))
        if w > 0:
            w = 0
    return w

弹性网络回归算法动画演示

下面动图展示了不同的 ρ 对弹性网络回归的影响,当 ρ 逐渐增大时,L1正则项占据主导地位,代价函数越接近Lasso回归,当 ρ 逐渐减小时,L2正则项占据主导地位,代价函数越接近岭回归。

下面动图展示了Lasso回归与弹性网络回归对比,虚线表示Lasso回归的十个特征,实线表示弹性网络回归的十个特征,每一个颜色表示一个自变量的权重系数(训练数据来源于sklearn diabetes datasets)

Lasso回归与弹性网络回归对比

弹性网络回归算法优势和劣势

  1. 优势:弹性网络综合了套索回归和岭回归的优点,既能实现特征选择,又能处理多重共线性;对噪声具有一定的鲁棒性;适用于高维数据和具有相关特征的情况。
  2. 劣势:需要调节两个超参数,增加了模型的复杂度;在特征高度相关的情况下,可能无法有效区分特征的重要性。

弹性网络回归算法的应用场景

  1. 处理高维数据:当数据集具有大量特征时,弹性网络可以帮助筛选出最重要的特征,避免过拟合问题和提高模型泛化能力。
  2. 处理共线性:弹性网络能够处理特征之间存在较强相关性的情况,通过综合考虑L1和L2正则化的效果,可以更好地稳定模型参数估计。
  3. 噪声较多的情况:弹性网络对噪声具有一定的鲁棒性,可以减小噪声对模型的影响,提高模型的预测准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923869.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从“卷模型”到“卷应用”:AI时代的价值重塑与个性化智能探索

&#x1f308;所属专栏&#xff1a;【其它】✨作者主页&#xff1a; Mr.Zwq✔️个人简介&#xff1a;一个正在努力学技术的Python领域创作者&#xff0c;擅长爬虫&#xff0c;逆向&#xff0c;全栈方向&#xff0c;专注基础和实战分享&#xff0c;欢迎咨询&#xff01; 您的点…

Java基础之集合

集合和数组的类比 数组: 长度固定可以存基本数据类型和引用数据类型 集合: 长度可变只能存引用数据类型存储基本数据类型要把他转化为对应的包装类 ArrayList集合 ArrayList成员方法 添加元素 删除元素 索引删除 查询 遍历数组

mqtt.fx连接阿里云

本文主要是记述一下如何使用mqtt.fx连接在阿里云上创建好的MQTT服务。 1 根据MQTT填写对应端口即可 找到设备信息&#xff0c;里面有MQTT连接参数 2 使用物模型通信Topic&#xff0c;注意这里的post说设备上报&#xff0c;那也就是意味着云端订阅post&#xff1b;set则意味着设…

向量索引【草稿】

用「向量」化数据表示「概念」。 向量表达:概念上更为接近的点在空间中更为聚集,而概念上更为不同的点,则距离更远。 向量数学表达:以坐标原点为起点,这些坐标点重点。 在语言上应用–词向量。 一个训练恰当的词向量集合,将和指代的事物之间的向量集合十分接近。有利于自…

记一次 .NET某上位视觉程序 离奇崩溃分析

一&#xff1a;背景 1. 讲故事 前段时间有位朋友找到我&#xff0c;说他们有一个崩溃的dump让我帮忙看下怎么回事&#xff0c;确实有太多的人在网上找各种故障分析最后联系到了我&#xff0c;还好我一直都是免费分析&#xff0c;不收取任何费用&#xff0c;造福社区。 话不多…

IDEA启动Web项目总是提示端口占用

IDEA启动Web项目总是提示端口占用 一、前言 1.场景 IDEA启动Web项目总是提示端口占用&#xff1a; 确实是端口被占用&#xff0c;比如&#xff1a;没有正常关闭 Springboot 项目导致Springboot 项目换任何端口都提示端口占用&#xff0c;而且找不到占用端口的程序 2.环境 …

Qt中https的使用,报错TLS initialization failed和不能打开ssl.lib问题解决

前言 在现代应用程序中&#xff0c;安全地传输数据变得越来越重要。Qt提供了一套完整的网络API来支持HTTP和HTTPS通信。然而&#xff0c;在实际开发过程中&#xff0c;开发者可能会遇到SSL相关的错误&#xff0c;例如“TLS initialization failed”&#xff0c;cantt open ssl…

要注意!Google账号提示活动异常就要注意了,很可能下一步就是真•停用

很多朋友&#xff0c;在主动或被动登录谷歌账号时&#xff0c;被提醒账号活动异常&#xff0c;要验证手机号才能进一步使用谷歌账号&#xff0c;这是什么原因呢&#xff1f;如果不及时验证会出现什么状况呢&#xff0c;该如何解决这个问题呢&#xff1f;如果验证提示手机无法用…

一篇文章教你掌握——Pytorch深度学习实践基础

一篇文章教你掌握——Pytorch深度学习实践 1. Overview 概述1.1 Rule-based systems 基于规则的系统1.2 Classic machine learning 经典机器学习1.3 Representation learning 表征学习1.4 Brief history of neural networks 神经网络简史 2. 配置环境2.1 安装Anaconda2.2 创建虚…

[stm32f407]定时器使用

1.定时器定时串口打印 main.c #include "stm32f4xx.h" // Device header #include "serial.h" #include "delay.h" #include "tim.h"extern uint16_t count;int main(void) {Serial_Init();TIM_Init();printf(&quo…

通过AIGC赋能创意设计发展

随着人工智能技术的飞速发展&#xff0c;AIGC&#xff08;Artificial Intelligence Generated Content&#xff09;正逐渐成为创意设计领域的新引擎。AIGC通过智能算法与大数据的深度融合&#xff0c;不仅为设计师们提供了前所未有的创意灵感&#xff0c;还在设计方案优化等方面…

云计算数据中心(一)

目录 一、云数据中心的特征二、云数据中心网络部署&#xff08;一&#xff09;改进型树结构&#xff08;二&#xff09;递归层次结构&#xff08;三&#xff09;光交换网络&#xff08;四&#xff09;无线数据中心网络&#xff08;五&#xff09;软件定义网络 一、云数据中心的…

光明乳业首推公益主题数字资产,用爱助力青少年健康成长

作为一直秉持“温暖如一”的企业价值观的百年乳企&#xff0c;光明乳业始终关注青少年儿童的健康成长&#xff0c;积极投身公益活动&#xff0c;用爱心和行动温暖他们的世界。 今年六月&#xff0c;适逢儿童节与全国爱眼日&#xff0c;光明乳业先后参与“童心筑爱 品牌赋能”公…

神经网络 | Transformer 基本原理

目录 1 为什么使用 Transformer&#xff1f;2 Attention 注意力机制2.1 什么是 Q、K、V 矩阵&#xff1f;2.2 Attention Value 计算流程2.3 Self-Attention 自注意力机制2.3 Multi-Head Attention 多头注意力机制 3 Transformer 模型架构3.1 Positional Encoding 位置编…

晏子春秋-读书笔记二

“橘生淮南则为橘&#xff0c;生于淮北则为枳&#xff0c;叶徒相似&#xff0c;其实味不同。所以然者何&#xff1f;水土异也。今民生长于齐不盗&#xff0c;入楚则盗&#xff0c;得无楚之水土使民善盗耶&#xff1f;” 这段话的大意是说&#xff0c;橘树生长在淮河以南就是甜美…

【触摸屏】【地震知识宣传系统】功能模块:视频 + 知识问答

项目背景 鉴于地震知识的普及对于提升公众防灾减灾意识的重要性&#xff0c;客户希望开发一套互动性强、易于理解的地震学习系统&#xff0c;面向公众、学生及专业人员进行地震知识教育与应急技能培训。 产品功能 系统风格&#xff1a;严谨的设计风格和准确的信息呈现&#…

PointCloudLib ISS关键点提取 C++版本

测试效果 算法简介 PCL(Point Cloud Library)中的内部形状描述子(ISS,Intrinsic Shape Signatures)关键点提取是一种在3D点云中提取显著几何特征点的方法。这种方法非常适用于需要高质量点云配准、对象识别和分类等任务。以下是对PCL内部形状描述子(ISS)关键点提取的详…

企业国产操作系统选型适配实施方案

【摘要】企业在推动国产化进程时&#xff0c;需选择一款主流、稳定且安全的服务器操作系统作为其系统软件。在产品投入实际生产环境前&#xff0c;对上游软硬件的适配情况有深入了解至关重要。本文将重点介绍银河麒麟高级服务器操作系统V10&#xff08;以下简称麒麟V10&#xf…

深度学习中的FLOPs补充

学习了博主的介绍&#xff08;深度学习中的FLOPs介绍及计算(注意区分FLOPS)-CSDN博客&#xff09;后&#xff0c;对我不理解的内容做了一点补充。 链接放到下边啦 https://blog.csdn.net/qq_41834400/article/details/120283103 FLOPs&#xff1a;注意s小写&#xff0c;是floa…