机器学习——弹性网估计

news2024/11/25 14:41:22

机器学习——弹性网估计

文章目录

  • 机器学习——弹性网估计
    • @[toc]
    • 1 模型介绍
    • 2 模型设定
    • 3 弹性网估计

1 模型介绍

弹性网估计属于惩罚回归,常见的惩罚回归包括岭回归(ridge)、套索回归(lasso)和弹性网(elasticnet)回归等。

岭回归用于缓解高维数据可能的多重共线性问题:当样本容量 n n n小于特征变量数 p p p时,采用普通最小二乘法会出现多重共线问题,无法识别特征变量对标签的影响。如果 n n n不是远大于 p p p,即使能识别各特征变量的影响系数,但外推能力较差。岭回归在损失函数基础上加上待估参数的L2范数,通过最小化使各变量系数向原点收缩。给定收缩参数,能识别出某些特征变量对标签的影响可能并不稳健。

套索回归在损失函数基础上加上待估参数的L1范数,这使得损失函数变得不可微,但某些变量的影响系数可能刚好等于0,使得损失函数最小化。这使套索回归具备筛选变量的功能。

弹性网估计是岭回归和套索回归的混合,尽管lasso可以筛选变量,但对于具有高度相关的多个变量,lasso会任意进行筛选,导致经济解释不足。由于ridge基本不会出现‘稀疏解’,将ridge与lasso结合,即L1范数和L2范数均融入损失函数中形成弹性网估计。


2 模型设定

给定多元线性回归模型
y = X β + ε \mathbf{y}=\mathbf{X} \boldsymbol{\beta}+\boldsymbol{\varepsilon} y=Xβ+ε
其中标签或响应变量为 y ≡ ( y 1 y 2 ⋯ y n ) ′ \mathbf{y} \equiv\left(y_1 y_2 \cdots y_n\right)^{\prime} y(y1y2yn) X \mathbf{X} X为变量向量, β \boldsymbol{\beta} β是参数向量, ε \boldsymbol{\varepsilon} ε是残差向量。为估计参数向量 β \boldsymbol{\beta} β,使用普通最小二乘法OLS得
min ⁡ β L ( β ) = ( y − X β ) ′ ( y − X β ) ⏟ S S R \min _{\boldsymbol{\beta}} L(\boldsymbol{\beta})=\underbrace{(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})^{\prime}(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})}_{S S R} βminL(β)=SSR (yXβ)(yXβ)
求解得
β ^ O L S ≡ ( X ′ X ) − 1 X ′ y \hat{\boldsymbol{\beta}}_{O L S} \equiv\left(\mathbf{X}^{\prime} \mathbf{X}\right)^{-1} \mathbf{X}^{\prime} \mathbf{y} β^OLS(XX)1Xy
这里必须假设 ( X ′ X ) − 1 \left(\mathbf{X}^{\prime} \mathbf{X}\right)^{-1} (XX)1存在。对于高维数据,OLS适应性下降,因此考虑对损失函数 L ( β ) L(\boldsymbol{\beta}) L(β)加以改进,加入 β \boldsymbol{\beta} β的L2范数
min ⁡ β L ( β ) = ( y − X β ) ′ ( y − X β ) ⏟ S S R + λ ∥ β ∥ 2 2 ⏟ penalty  \min _{\boldsymbol{\beta}} L(\boldsymbol{\beta})=\underbrace{(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})^{\prime}(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})}_{S S R}+\underbrace{\lambda\|\boldsymbol{\beta}\|_2^2}_{\text {penalty }} βminL(β)=SSR (yXβ)(yXβ)+penalty  λβ22
其中 λ \lambda λ称为惩罚系数、调节系数或收缩因子;惩罚系数大,参数向量向原点越靠近。计入L2范数后,损失函数不仅要考虑预测误差平方和最小,也要兼顾参数向量的平方和大小:如果参数过大,那么损失函数不一定最小。上述 β \boldsymbol{\beta} β的估计量即为 β r i d g e \boldsymbol{\beta}_{ridge} βridge。同理,也可以加入 β \boldsymbol{\beta} β的L1范数
min ⁡ β L ( β ) = ( y − X β ) ′ ( y − X β ) ⏟ S S R + λ ∥ β ∥ 1 ⏟ penalty  \min _{\boldsymbol{\beta}} L(\boldsymbol{\beta})=\underbrace{(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})^{\prime}(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})}_{S S R}+\underbrace{\lambda\|\boldsymbol{\beta}\|_1}_{\text {penalty }} βminL(β)=SSR (yXβ)(yXβ)+penalty  λβ1
上述参数估计量称为 β l a s s o \boldsymbol{\beta}_{lasso} βlasso。将两种范数相结合
min ⁡ β ( y − X β ) ′ ( y − X β ) + λ [ α ∥ β ∥ 1 + ( 1 − α ) ∥ β ∥ 2 2 ] \min _{\boldsymbol{\beta}}(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})^{\prime}(\mathbf{y}-\mathbf{X} \boldsymbol{\beta})+\lambda\left[\alpha\|\boldsymbol{\beta}\|_1+(1-\alpha)\|\boldsymbol{\beta}\|_2^2\right] βmin(yXβ)(yXβ)+λ[αβ1+(1α)β22]
其中 λ \lambda λ为惩罚系数, α \alpha α为混合系数,即L1范数惩罚所占比例, 1 − α 1-\alpha 1α为L2惩罚所占比例。 α ∈ ( 0 , 1 ) \alpha \in(0,1) α(0,1)时,上述参数向量估计量即为弹性网估计量 β e l a s t i c n e t \boldsymbol{\beta}_{elasticnet} βelasticnet α = 0 \alpha = 0 α=0退化为ridge回归; α = 1 \alpha = 1 α=1退化为lasso回归。参数 λ \lambda λ α \alpha α通过交叉验证最小化均方误差获得。


3 弹性网估计

下面是Python代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import ElasticNetCV
from sklearn.linear_model import enet_path
# 使用boston数据集
boston = pd.read_csv('boston.csv')
# 特征变量提取
X_raw = boston.iloc[:, :-1]
# 响应变量
y = boston.iloc[:, -1]
# 特征变量标准化
scaler = StandardScaler()
X = scaler.fit_transform(X_raw)
# 模型估计,惩罚系数任意(0,1)
# l1_ratio = 0表示岭回归,l1_ratio = l为套索回归;alpha惩罚系数
model = ElasticNet(alpha=0.1, l1_ratio=0.5)
model.fit(X, y)
print('模型得分:\n',model.score(X, y))
result = pd.DataFrame({'变量':X_raw.columns,'系数':model.coef_})
print(f'最优回归系数:\n',result)
# 收缩路径
alphas, coefs, _ = enet_path(X, y, eps=1e-4, l1_ratio = 0.5)
ax = plt.gca()
ax.plot(alphas, coefs.T)
ax.set_xscale('log')
plt.xlabel(r'$\alpha$ (log scale)')
plt.ylabel('Coefficients')
plt.title('Elastic Net Cofficient Path (l1_ratio = 0.5)')
plt.axhline(0, linestyle='--', linewidth=1, color='k')
plt.legend(X_raw.columns)
plt.grid()
plt.show()

在这里插入图片描述

使用十折交叉验证惩罚参数和混合参数

#alpha = 0.001-1;lambda =  0.001-1
alphas = np.logspace(-3, 0, 100)
kfold = KFold(n_splits=10, shuffle=True, random_state=1)
model = ElasticNetCV(cv=kfold, alphas=alphas, l1_ratio=np.logspace(-3, 0, 100))
model.fit(X, y)
print('最优惩罚系数:\n',model.alpha_)
print('最优混合系数:\n',model.l1_ratio_)
# 最优惩罚系数:0.02848035868435802
#最优混合系数:1.0

混合系数 α = 1 \alpha = 1 α=1,此时选择模型lasso回归,最优惩罚参数为 λ = 0.028 \lambda = 0.028 λ=0.028

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=2344)
fit = Lasso(alpha=model.alpha_)
fit.fit(X_train, y_train)
score = model.score(X_train, y_train)
print('模型得分:',score)
result = pd.DataFrame({'变量':X_raw.columns,'系数':model.coef_})
print(f'回归系数:\n',result)
# 模型得分: 0.7237822155914061
# 回归系数:
#           变量        系数
# 0      CRIM -0.846146
# 1        ZN  0.965785
# 2     INDUS -0.000000
# 3      CHAS  0.680701
# 4       NOX -1.886944
# 5        RM  2.713469
# 6       AGE -0.000000
# 7       DIS -2.935723
# 8       RAD  2.203538
# 9       TAX -1.658672
# 10  PTRATIO -2.011514

-END-

陈强,《机器学习及Python应用》高等教育出版社, 2021年3月

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/513868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

aws exam

Route 53 Route 53 是AWS的一个服务,它的主要功能如下,下面会一一介绍每个功能 Domain registration(域名注册)DNS management(DNS管理)Health check(健康检查)Routing polices&am…

K8S系列之NetworkPolicy

什么是NetworkPolicy IP 地址或端口层面(OSI 第 3 层或第 4 层)控制网络流量, 则你可以考虑为集群中特定应用使用 Kubernetes 网络策略(NetworkPolicy)。 NetworkPolicy 是一种以应用为中心的结构,允许你设…

浅析EasyCVR视频汇聚技术在城市智慧文旅数智平台中的应用意义

一、背景分析 根据文化和旅游部4月21日公布的2023年一季度国内旅游数据情况抽样调查统计结果显示,2023年一季度,国内旅游总人次12.16亿,比上年同期增加3.86亿,同比增长46.5%。其中,城镇居民国内旅游人次9.44亿&#x…

消息和消息队列、以及作用场景(一)

“消息”是在两台计算机间传送的数据单位。消息可以非常简单,例如只包含文本字符串;也可以更复杂,可能包含嵌入对象。 “消息队列”是在消息的传输过程中保存消息的容器。 目前的消息队列有很多,例如:Kafka、RabbitMQ…

包装三年经验拿21K,试用期没过完就被裁了....

最近翻了一些网站的招聘信息,把一线大厂和大型互联网公司看了个遍,发现市场还是挺火热的,虽说铜三铁四,但是软件测试岗位并没有削减多少,建议大家有空还是多关注和多投简历,不要闭门造车,错过好…

Redis命令详解

Redis是一个高性能的内存键值数据库,它支持多种数据结构,包括字符串、哈希、列表、集合、有序集合等。Redis通过提供一组命令来实现对数据的操作,这些命令可以通过Redis客户端发送给Redis服务器,从而对数据库进行操作。 Redis的一…

阿里云刘伟光:2 万字解读金融级云原生

作者:刘伟光,阿里云智能新金融&互联网行业总裁、中国金融四十人论坛常务理事,毕业于清华大学电子工程系 01 前言 2015年云原生理念提出的时候,彼时全球金融百年发展形成的信息化到数字化的背后,金融级的技术服务…

好用工具第1期:手机电脑同屏QtScrcpy

QtScrcpy 可以通过 USB / 网络连接Android设备,并进行显示和控制。无需root权限。 同时支持 GNU/Linux ,Windows 和 MacOS 三大主流桌面平台。 QtScrcpy 是一个开源项目, 项目地址是: https://github.com/barry-ran/QtScrcpy 它专注于: 精致 (仅显示设…

Java 责任链模式详解

责任链模式(Chain of Responsibility Pattern)是一种行为型设计模式,它用于将请求的发送者和接收者解耦,使得多个对象都有机会处理这个请求。在责任链模式中,有一个请求处理链条,每个处理请求的对象都是一个…

mysql数据库基础知识,mysql数据库简介(一看就懂,一学就会)

目录 一、MySQL学习路线二、MySQL常见操作1、查看所有数据库show databases。2、MySQL 创建数据库3、删除数据库4、选择数据库use databasename5、查看该数据库下所有表show tables6、创建数据库表7、删除数据库 三、增删改查1、插入数据2、查询数据3、where子句4、更新语句5、…

微前端应用(qiankun+umi+antd)

1.微前端介绍以应用选型 1.1什么是微前端? 微前端是一种前端架构模式,它将前端应用程序拆分成多个小型的、独立开发、独立部署的子应用,然后将这些子应用组合成一个大型的、复杂的前端应用。每个子应用都有自己的技术栈、独立的代码库、独立的开发、测…

Linux快捷命令

目录 一、快捷排序——sort 常用选项: 示例 二、快捷去重——uniq 常用选项: 示例: ​编辑 ​编辑 ​编辑 三、快捷替换——tr 用于windows的编写的脚本格式转换为Linux格 方法一: 方法二: 四、快速裁…

JAVA double精度丢失问题

double类型精度丢失问题: 0.1*0.1使用计算器计算是0.01,代码里却是0.010000000000000002 public class HelloWorld {public static void main(String []args) {double number1 0.1;double number2 0.1;double result number1 * number2 ;System.o…

CSP-S 2022 提高级 第一轮 阅读程序(1) 第16-21题

【题目】 CSP-S 2022 提高级 第一轮 阅读程序&#xff08;1&#xff09; 第16-21题 01 #include <iostream> 02 #include <string> 03 #include <vector> 04 05 using namespace std; 06 07 int f(const string &s, const string &t) 08 { …

关于cartographer建立正确关系树的理解

正确的TF关系map----odom----base_link----laser base_link是固定在机器人本体上的坐标系&#xff0c;通常选择飞控 其中map–odom 的链接是由cartographer中lua文件配置完成的 map_frame "map", tracking_frame "base_link", published_frame "b…

Ubuntu 20.04 安装 mysql8 并配置远程访问

文章目录 一、使用 apt-get 安装 mysql 服务二、初始化 mysql 数据库管理员用户密码三、配置远程访问 一、使用 apt-get 安装 mysql 服务 # 更新软件源 apt-get install update# 安装mysql服务 apt-get install mysql-server# 使用mysqladmin工具查看mysql版本 mysqladmin --v…

一文解析Linux进程的睡眠和唤醒

Linux进程的睡眠和唤醒 在Linux中&#xff0c;仅等待CPU时间的进程称为就绪进程&#xff0c;它们被放置在一个运行队列中&#xff0c;一个就绪进程的状 态标志位为 TASK_RUNNING。一旦一个运行中的进程时间片用完&#xff0c; Linux 内核的调度器会剥夺这个进程对CPU的控制权&…

燃气巡检二维码

对燃气公司的输气管道和阀井等设施的巡检工作的管理目标是能降低成本、提高工作效率以及管理水平。但用纸质记录的方式进行燃气设备巡检有以下缺点&#xff1a; 1、难保证巡检真实性 无法客观、方便地掌握巡检人员巡检的到位情况&#xff0c;因而无法有效地保证巡检工作人员按计…

软件兼容性测试如何进行?怎么选择靠谱的软件检测公司?

软件兼容性测试是一项非常重要的工作&#xff0c;能够确保在不同的操作系统、设备、浏览器以及其他软件环境下&#xff0c;软件应用都能够正常运行。 一、软件兼容性测试如何进行? 确定测试的环境&#xff0c;包括操作系统、设备、浏览器等&#xff0c;并建立测试用例和测试…

Maven必要知识

参考笔记&#xff1a; https://www.wolai.com/arAiYJYCr6Kkfi2kZ8HxE8 1. Maven 概述 1.1 什么是 Maven Maven 是 Apache 软件基金会组织维护的一款专门为 Java 项目提供构建和依赖管理支持的工具。 Maven 作为依赖管理工具 jar 包的管理jar 包的来源jar 包之间的依赖关系…