【Python机器学习】交互特征与多项式特征

news2024/10/7 18:25:24

对于线性模型来说,想要丰富特征,还有一种方法是添加原始数据的交互特征和多项式特征。这种特征工程通常用于统计建模,但也经常用于实际的机器学习应用中。

交互特征

上一篇的例子里,线性模型对wave数据集的的每个箱子都学到一个常数值,但线性模型不仅可以学习偏移,还可以学习斜率。想要向分箱数据上的线性模型添加斜率,一种方法是重新假如原始特征,这样就会得到11维的数据:

from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

X,y=mglearn.datasets.make_wave(n_samples=100)
line=np.linspace(-3,3,1000,endpoint=False).reshape(-1,1)
bins=np.linspace(-3,3,11)
which_bin=np.digitize(X,bins=bins)
#使用OneHotEncoder进行变换
encoder=OneHotEncoder(sparse=False)
encoder.fit(which_bin)
#transform创建one-hot编码
X_binned=encoder.transform(which_bin)
line_binned=encoder.transform(np.digitize(line,bins=bins))
X_combined=np.hstack([X,X_binned])
print(X_combined.shape)

reg=LinearRegression().fit(X_combined,y)

line_combined=np.hstack([line,line_binned])
plt.plot(line,reg.predict(line_combined),label='线性回归()')

for bin in bins:
    plt.plot([bin,bin],[-3,3],':',c='k')

plt.legend(loc='best')
plt.xlabel('输入特征')
plt.ylabel('回归 输出')
plt.plot(X[:,0],y,'o',c='k')

plt.show()

在这个例子中,模型在每个箱子都学到一个偏移,还学到一个斜率。学到的斜率是向下的,而且每个箱子都相同,也就是只有一个x轴特征,只有一个斜率。因为斜率在所有箱子都是相同的,所以它似乎不是很有用,我们更希望每个箱子都有一个不同的斜率。

为了实现这一点,我们可以添加交互特征或乘积特征,用来表示数据点所在的箱子以及数据点在x轴的位置。这个特征就是箱子指示符与原始特征的乘积。创建数据集:

X_product=np.hstack([X_binned,X*X_binned])
print(X_product.shape)

这个数据集现在有20个特征:数据点所在箱子的指示符与原始特征和箱子指示符的乘积。

我们可以将乘积特征看作每个箱子x轴特征的单独副本。它在箱子内等于原始特征,在其位置等于0:

reg=LinearRegression().fit(X_product,y)
line_combined=np.hstack([line_binned,line*line_binned])
plt.plot(line,reg.predict(line_combined),label='线性回归()')

for bin in bins:
    plt.plot([bin,bin],[-3,3],':',c='k')

plt.legend(loc='best')
plt.xlabel('输入特征')
plt.ylabel('回归 输出')
plt.plot(X[:,0],y,'o',c='k')

plt.show()

可以看到,现在这个模型中,每个箱子都有自己的偏移和斜率。

多项式特征

使用分箱是扩展连续特征的一种方法。另一种方法是使用原始特征的多项式。对于给定特征x,我们可以考虑x**2、x**3、x**4等等,这在preprocessing模块的PolynomialFeatures中实现:

from sklearn.preprocessing import PolynomialFeatures


#包含直到x**10的多项式
ploy=PolynomialFeatures(degree=10,include_bias=False)
ploy.fit(X)
X_ploy=ploy.transform(X)
print(X_ploy.shape)

多项式的次数为10,因此生成了10个特征。

我们比较X_ploy和X的元素:

print('Entries of X:\n{}'.format(X[:5]))
print('Entries of X_ploy:\n{}'.format(X_ploy[:5]))

我们还可以通过调用get_feature_names_out方法来获取特征的语义,来给出每个特征的指数:

print('Polynomial feature names:{}'.format(ploy.get_feature_names_out()))

可以看到,X_ploy的第一列与X完全对应,而其他列则是第一列的幂。

将多项式特征与线性回归模型一起使用,可以得到经典的多项式回归模型:

reg=LinearRegression().fit(X_ploy,y)
line_ploy=ploy.transform(line)
plt.plot(line,reg.predict(line_ploy),label='线性回归(多项式回归)')

plt.plot(X[:,0],y,'o',c='k')
plt.legend(loc='best')
plt.xlabel('输入特征')
plt.ylabel('回归 输出')

plt.show()

可以看到,多项式特征在这个一维数据上得到了非常平滑的拟合,但高次多项式在边界或数据很少的区域可能会有极端的表现。

作为对比,下面是原始数据上学到的核SVM模型,没有做任何变换:

from sklearn.svm import SVR

for gamma in [1,10]:
    svr=SVR(gamma=gamma).fit(X,y)
    plt.plot(line,svr.predict(line),label='SVR gamma={}'.format(gamma))


plt.plot(X[:,0],y,'o',c='k')
plt.legend(loc='best')
plt.xlabel('输入特征')
plt.ylabel('回归 输出')

plt.show()

使用更复杂的模型(也就是核SVM),我们能够学习到一个与多项式回归的复杂度类似的预测结果,且不需要进行显式的特征变换。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1868176.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于稀疏矩阵方法的剪枝压缩模型方案总结

1.简介 1.1目的 在过去的一段时间里,对基于剪枝的模型压缩的算法进行了一系列的实现和实验,特别有引入的稀疏矩阵的方法实现了对模型大小的压缩,以及在部分环节中实现了模型前向算法的加速效果,但是总体上模型加速效果不理想。所…

从零到一打造自己的大模型:模型训练

前言 最近看了很多大模型,也使用了很多大模型。对于大模型理论似乎很了解,但是好像又缺点什么,思来想去决定自己动手实现一个 toy 级别的模型,在实践中加深对大语言模型的理解。 在这个系列的文章中,我将通过亲手实践…

【面试题】Spring面试题

目录 Spring Framework 中有多少个模块,它们分别是什么?Spring框架的设计目标、设计理念?核心是什么?Spring框架中都用到了哪些设计模式?Spring的核心机制是什么?什么是Spring IOC容器?什么是依…

竞赛选题 python区块链实现 - proof of work工作量证明共识算法

文章目录 0 前言1 区块链基础1.1 比特币内部结构1.2 实现的区块链数据结构1.3 注意点1.4 区块链的核心-工作量证明算法1.4.1 拜占庭将军问题1.4.2 解决办法1.4.3 代码实现 2 快速实现一个区块链2.1 什么是区块链2.2 一个完整的快包含什么2.3 什么是挖矿2.4 工作量证明算法&…

鸿蒙面试心得

自疫情过后,java和web前端都进入了冰河时代。年龄、薪资、学历都成了找工作路上躲不开的门槛。 年龄太大pass 薪资要高了pass 学历大专pass 好多好多pass 找工作的路上明明阳关普照,却有一种凄凄惨惨戚戚说不清道不明的“优雅”意境。 如何破局&am…

修复:cannot execute binary file --- ppc64le 系统架构

前言: 修复node_exporter,引用pprof包,对源码编译后在 Linux 系统下执行程序运行时,发生了报错,报错信息:cannot execute binary file: Exec format error。 开始以为编译有问题,检查发现;该l…

正规的外盘期货开户指南避坑!

一:最正规最靠谱的外盘期货开户方式。那就是直开香港账户,需要基本证件、护照、境外卡等。 如果你满足以上条件,可以直接在香港外盘期货公司的营业部或线上官网开户。 优点:安全正规,银期转账。 缺点:保…

Java - 程序员面试笔记记录 实现 - Part1

社招又来学习 Java 啦,这次选了何昊老师的程序员面试笔记作为主要资料,记录一下一些学习过程。 1.1 Java 程序初始化 Java 程序初始化遵循规则:静态变量优于动态变量;父类优于子类;成员变量的定义顺序; …

1. jenkins持续集成交付

jenkins持续集成交付 一、jenkins介绍二、jenkins的安装部署1、下载jenkins2、安装jenkins3、修改插件下载地址4、初始化jenkins 一、jenkins介绍 持续集成交付, CI/CD 偏开发、项目编译、部署、更新 二、jenkins的安装部署 1、下载jenkins [rootjenkins ~]# wge…

LLM 推理:Nvidia TensorRT-LLM 与 Triton Inference Server

随着LLM越来越热门,LLM的推理服务也得到越来越多的关注与探索。在推理框架方面,tensorrt-llm是非常主流的开源框架,在Nvidia GPU上提供了多种优化,加速大语言模型的推理。但是,tensorrt-llm仅是一个推理框架&#xff0…

算法设计与分析--分布式系统作业及答案

分布式系统 作业参考答案2.1 分析在同步和异步模型下,convergecast 算法的时间复杂性。2.2 G 里一结点从 pr 可达当且仅当它曾设置过自己的 parent 变量。2.3 证明 Alg2.3 构造一棵以 Pr 为根的 DFS 树。2.4 证明 Alg2.3 的时间复杂度为 O(m)。2.5 修改 Alg2.3 获得…

限域传质分离膜兼具高渗透性、高选择性特点 未来应用前景广阔

限域传质分离膜兼具高渗透性、高选择性特点 未来应用前景广阔 分离膜是一种具有选择性透过功能的薄层材料。限域传质分离膜是基于限域传质机制的分离膜,兼具高渗透性、高选择性的特点。限域传质是流体分子通过与其运动自由程相当传质空间的过程,流体分子…

网络安全 DVWA通关指南 Cross Site Request Forgery (CSRF)

DVWA Cross Site Request Forgery (CSRF) 文章目录 DVWA Cross Site Request Forgery (CSRF)DVWA Low 级别 CSRFDVWA Medium 级别 CSRFDVWA High 级别 CSRFDVWA Impossible 级别 CSRF CSRF是跨站请求伪造攻击,由客户端发起,是由于没有在执行关键操作时&a…

推荐一个shp修复工具

我们在《如何解决ArcGIS中数据显示乱码问题》一文中,为你分享过打开shp文件的乱码问题。 现在再为你分享一个shp文件的修复工具,你可以在文末查看该工具的领取方式。 shp文件修复工具 Shapefile(简称SHP)是Esri推出的一种广泛使…

新能源行业知识体系-------蒙西电网需求侧响应

新能源行业知识体系-------主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/139946830 目录 一、背景介绍二、需求响应电能量收益介绍三、超额回收需求响应减免收益介绍四、参与需求侧响应五、蒙西电力现货特点六、交易中…

好消息!终于解决了!Coze工作流错误中断问题终于得到解决!

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 解决方案 📒📝 常见的工作流中断问题📝 好消息来了!⚓️ 相关链接 ⚓️📖 介绍 📖 大家是否曾经遇到过这样的问题:在Coze平台辛辛苦苦设计的一个工作流,尤其是流程非常复杂和长的情况下,只要中间一个环节出错,整…

红海云签约联东集团,引领产业园区领军企业人力资源数字化新范式

北京联东投资(集团)有限公司(以下简称“联东集团”)是集产业园区运营、模板钢结构和投资业务为一体的集团化公司。联东集团独创了产业聚合U模式,致力于打造产业集聚平台,服务于实体企业成长和地区经济发展。…

[SD必备知识18]修图扩图AI神器:ComfyUI+Krita加速修手抽卡,告别低效抽卡还原光滑细腻双手,写真无需隐藏手势

🌹大家好!我是安琪!感谢大家的支持与鼓励。 krita-ai-diffusion简介 在AIGC图像生成领域的迅猛发展下,当前的AI绘图工具如Midjourney、Stable Diffusion都能够近乎完美的生成逼真富有艺术视觉效果的图像质量。然而,针…

基于大语言模型的多意图增强搜索

随着人工智能技术的蓬勃发展,大语言模型(LLM)如Claude等在多个领域展现出了卓越的能力。如何利用这些模型的语义分析能力,优化传统业务系统中的搜索性能是个很好的研究方向。 在传统业务系统中,数据匹配和检索常常面临…

SpringMVC 请求参数接收

目录 请求 传递单个参数 基本类型参数传递 未传递参数 传递参数类型不匹配 传递多个参数 传递对象 后端参数重命名 传递数组 传递集合 传递JSON数据 JSON是什么 JSON的优点 传递JSON对象 获取URL中的参数 文件上传 在浏览器与程序进行交互时,主要分为…