机器学习——多元线性回归升维

news2024/10/1 5:28:29

机器学习升维

    • 升维
    • 使用sklearn库实现特征升维
    • 实现天猫年度销量预测
    • 实现中国人寿保险预测

升维

定义:将原始的数据表示从低维空间映射到高维空间。在线性回归中,升维通常是通过引入额外的特征来实现的,目的是为了更好地捕捉数据的复杂性,特别是当数据之间的关系是非线性的时候。

目的:解决欠拟合问题,提高模型的准确率。为解决因对预测结果考虑因素比较少,而无法准确计算出模型参数问题。

常用方法:将已知维度进行自乘(或相乘)来构建新的维度。

本文主要记录的是线性回归中遇到数据呈现非线性特征时,该如何处理!

切记:对训练集特征升维后也要对测试集、验证集特征数据进行升维操作

数据准备如下:

在这里插入图片描述

如果对其直接进行线性回归,则拟合后的模型如下:

在这里插入图片描述

从上述两图可知,对于具有非线性特征的图像,不对其使用特使的处理,则无法对其产生比较好的模型拟合。

上述图像生成代码:

# 导包
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# 创建数据
X = np.linspace(-1,11,100)
y = (X - 5)**2 + 3*X + 12 + np.random.randn(100)
X = X.reshape(-1,1)
# display(X.shape,y.shape)
plt.scatter(X,y)

# 不升维直接用线性回归解决
model = LinearRegression()
model.fit(X,y)
X_test = np.linspace(-2,12,300).reshape(-1,1)
y_test = model.predict(X_test)
plt.scatter(X,y)
plt.plot(X_test,y_test,color = 'red')

为了使得可以对具有非线性特征的数据进行处理,生成一个较好的模型,可是实现预测的任务,于是便有了升维操作,下举例升维和不升维的区别:

不升维:二维数据x1, x2若不对其进行升维操作,则其拟合的多元线性回归公式为:
y = w 1 ∗ x 1 + w 2 ∗ x 2 + w 0 y = w_1*x_1 + w_2*x_2 + w_0 y=w1x1+w2x2+w0

升维:若对二维数据x1,x2进行升维操作,则其可有5个维度(以自乘为例):x1、x2、x12,x22、x1*x2,在加上一个偏置项w0,一共有六个参数,则其拟合后的多元线性回归公式为:
y = w 0 + w 1 ∗ x 1 + w 2 ∗ x 2 + w 3 ∗ x 1 2 + w 4 ∗ x 2 2 + w 5 ∗ x 1 ∗ x 2 y= w_0+w_1*x_1+w_2*x_2+w_3*x_1^2+w_4*x_2^2+w_5*x_1*x_2 y=w0+w1x1+w2x2+w3x12+w4x22+w5x1x2

若这样,则由原本的一维线性方程转换成了二维函数(最直观的表现),则原本的数据集则可以拟合成下图所示的模型:

在这里插入图片描述

上图生成代码如下:

# 导包
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
X = np.linspace(-1,11,100)
y = (X - 5)**2 + 3*X + 12 + np.random.randn(100)
X = X.reshape(-1,1)

# 升维,可以解决多项式的问题,直观表现为可以让直线进行拐弯
np.set_printoptions(suppress=True)
X2 = np.concatenate([X,X**2], axis= 1)
# 注:只需要对特征进行升维,不需要对目标值进行升维

# 生成测试数据
X_test = np.linspace(-2,12,300).reshape(-1,1) 
model2 = LinearRegression()
model2.fit(X2,y)
X_test2 = np.concatenate([X_test,X_test**2],axis=1)
y_test2 = model2.predict(X_test2)
print('所求的w是\n',model2.coef_)
print('所求的截距b是\n',model2.intercept_)

# 绘制图像的时候要用没升维的数据进行绘制
plt.scatter(X,y,color='green')
plt.plot(X_test,y_test2,color = 'red')

使用sklearn库实现特征升维

在sklearn中具有很多封装好的工具,可以直接调用。

from sklearn.preprocessing import PolynomialFeatures # (多项式)升维的python库

使用方法:

# 特征和特征之间相乘
poly = PolynomialFeatures(interaction_only=True)
A = [[3,2]]
poly.fit_transform(A)
# 生成结果:array([[1., 3., 2., 6.]])

#特征之间乘法,自己和自己自乘(在上述情况下加上自己的乘法)
poly = PolynomialFeatures(interaction_only=False)
A = [[3,2,5]]
poly.fit_transform(A)
# 生成结果:array([[ 1.,  3.,  2.,  5.,  9.,  6., 15.,  4., 10., 25.]])

# 可以通过degree来提高升维的大小
poly = PolynomialFeatures(degree=4,interaction_only=False)# 特征和特征之间相乘
A = [[3,2,5]]
poly.fit_transform(A)
# 生成结果:
# array([[  1.,   3.,   2.,   5.,   9.,   6.,  15.,   4.,  10.,  25.,  27.,
#         18.,  45.,  12.,  30.,  75.,   8.,  20.,  50., 125.,  81.,  54.,
#        135.,  36.,  90., 225.,  24.,  60., 150., 375.,  16.,  40., 100.,
#        250., 625.]])

实现天猫年度销量预测

实现代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures,StandardScaler
from sklearn.linear_model import LinearRegression,SGDRegressor

# 创建数据
X = np.arange(2009,2020).reshape(-1,1) - 2008
y = np.array([0.5,9.36,52,191,350,571,912,1207,1682,2135,2684])
plt.scatter(X,y)
# 创建测试数据
X_test = np.linspace(2009,2020,100).reshape(-1,1) - 2008

# 数据升维
ploy = PolynomialFeatures(degree=2, interaction_only=False)
X2 = ploy.fit_transform(X)
X_test2 = ploy.fit_transform(X_test)

# 模型创建LinearRegression
model = LinearRegression(fit_intercept=False)
model.fit(X2,y)
y_pred = model.predict(X_test2)
print('参数w为:',model.coef_)
print('参数b为:',model.intercept_)

plt.scatter(X,y,color='green')
plt.plot(X_test,y_pred,color='red')
# 使用SGD进行梯度下降,必须要归一化,否则效果会非常不好
# 创建测试数据
X_test = np.linspace(2009,2019,100).reshape(-1,1) - 2008

# 数据升维
ploy = PolynomialFeatures(degree=2, interaction_only=False)
X2 = ploy.fit_transform(X)
X_test2 = ploy.fit_transform(X_test)

#对数据进行归一化操作
standard = StandardScaler()
X2_norm = standard.fit_transform(X2)
X_test2_norm = standard.fit_transform(X_test2)

# 模型创建SGDRegression
model = SGDRegressor(eta0=0.3, max_iter=5000)
model.fit(X2_norm,y)
y_pred = model.predict(X_test2_norm)
print('参数w为:',model.coef_)
print('参数b为:',model.intercept_)

plt.scatter(X,y,color='green')
plt.plot(X_test,y_pred,color='red')

这里需要说明一下情况,如果第二段代码不进行归一化,则呈现的是下图:

在这里插入图片描述

如果进行了归一化,则产生的和法一LinearRegession是一样的图形(基本相同):

在这里插入图片描述

这是什么原因?

  • 线性回归(Linear Regression)和随机梯度下降(SGD)在处理特征尺度不同的问题上有一些不同之处,导致线性回归相对于特征尺度的敏感性较低。
  • SGD的更新规则涉及学习率(η)和梯度。如果不同特征的尺度相差很大,梯度的大小也会受到这种尺度差异的影响。因此在引入高次项或其他非线性特征,需要注意特征的尺度,避免数值上的不稳定性。
  • SGD中的正则化项通常依赖于权重的大小。通过归一化,可以使得正则化项对所有特征的影响更加平衡。

实现中国人寿保险预测

import pandas as pd
import seaborn as sns
import numpy as np
from sklearn.linear_model import LinearRegression,ElasticNet
from sklearn.metrics import mean_squared_error,mean_squared_log_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures

# 读取数据
data_renshou = pd.read_excel('your_path/中国人寿.xlsx')
# 可以通过下式生成图像,查看那些数据是好数据那些是不好的数据——好特征:差别大,容易区分
#sns.kdeplot(data=data_renshou, x="charges",hue="sex",shade=True)
#sns.kdeplot(data=data_renshou, x="charges",hue="smoker",shade=True)
#sns.kdeplot(data=data_renshou, x="charges",hue="region",shade=True)
#sns.kdeplot(data=data_renshou, x="charges",hue="children",shade=True)

# 特征工程,对数据进行处理
data_renshou = data_renshou.drop(['region','sex'],axis = 1)	# 删除不好的特征

# 体重指数,离散化转换,体重两种情况:标准,fat
def conver(df,bmi):
    df['bmi'] = 'fat' if df['bmi'] >= bmi else 'standard'
    return df
data_renshou = data_renshou.apply(conver, axis=1,args=(30,))

# 特征提取,离散转数值型数据
data_renshou = pd.get_dummies(data_renshou)
data_renshou.head()

#特征和目标值提取
# 训练数据
x = data_renshou.drop('charges', axis=1)
# 目标值
y = data_renshou['charges']

# 划分数据
X_train,X_test,y_train,y_test = train_test_split(x,y,test_size=0.2)

# 特征升维(导致了他下面的参数biandu)
poly = PolynomialFeatures(degree=2, include_bias=False)
X_train_poly = poly.fit_transform(X_train)
X_test_poly = poly.fit_transform(X_test)
# 模型训练与评估
np.set_printoptions(suppress=True)
model = LinearRegression()
model.fit(X_train_poly,y_train)
print('测试数据得分:',model.score(X_train_poly,y_train))
print('预测数据得分:',model.score(X_test_poly,y_test))
print('测试数据均方误差:',np.sqrt(mean_squared_error(y_test,model.predict(X_test_poly))))
print('训练数据均方误差:',np.sqrt(mean_squared_error(y_train,model.predict(X_train_poly))))
print('测试数据对数误差:',np.sqrt(mean_squared_log_error(y_test,model.predict(X_test_poly))))
print('训练数据对数误差:',np.sqrt(mean_squared_log_error(y_train,model.predict(X_train_poly))))
print('获得的参数为:',model.coef_.round(2),model.intercept_.round(2))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1261400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MacOS 系统 Flutter开发Android 环境配置

上节我们已经把 开发工具准备齐全,并可以进行Flutter的web开发,本节将做安卓开发环境进行详细说明 接上节这里先说下,系统环境 MacOS14 (Sonoma) 芯片 Apple M3 执行命令:flutter doctor 提示如下&#…

Shell脚本:Linux Shell脚本学习指南(第三部分Shell高级)一

第三部分:Shell高级(一) 这一章讲解 Shell 脚本编程的进阶内容,主要涉及重定向、文件描述符、管道和过滤器、子 Shell、信号等。 本章会使用到一些底层的编程知识,有C语言和 C 编程经验的程序员阅读起来将会更加轻松。…

《微信小程序开发从入门到实战》学习三十三

第四章 云开发 本章云开发技术的功能与使用,包括以下几点: 1.学习使用云开发控制台 2.学习云开发JSON数据库功能 3.学习云开文件存储功能 4.学习云函数功能 5.使用云开发技术实现投票小程序的服务端功能 投票小程序大部分已经实现。需要实现&#…

人工智能-优化算法之凸集

凸性 凸性(convexity)在优化算法的设计中起到至关重要的作用, 这主要是由于在这种情况下对算法进行分析和测试要容易。 换言之,如果算法在凸性条件设定下的效果很差, 那通常我们很难在其他条件下看到好的结果。 此外&…

二叉堆与优先队列

二叉堆与优先队列 1、什么是二叉堆 1.1、初识二叉堆 什么是二叉堆? 二叉堆本质上是一种完全二叉树,它分为两个类型。 最大堆(也叫大顶堆):任意节点的值都大于或等于它的左右孩子节点的值,并且最大的值位…

SHAP(一):具有 Shapley 值的可解释 AI 简介

SHAP(一):具有 Shapley 值的可解释 AI 简介 这是用 Shapley 值解释机器学习模型的介绍。 沙普利值是合作博弈论中广泛使用的方法,具有理想的特性。 本教程旨在帮助您深入了解如何计算和解释基于 Shapley 的机器学习模型解释。 我…

NX二次开发UF_CURVE_create_arc_point_tangent_radius 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_create_arc_point_tangent_radius Defined in: uf_curve.h int UF_CURVE_create_arc_point_tangent_radius(tag_t point, tag_t tangent_object, double radius, UF_CURVE_…

初识前后端数据交互(新手篇)

一个软件项目的开发必然是离不开前端和后端的协作,对于刚入行的新手前端或者新手后端来说,很有必要了解一下对方是在做什么,以及提供给自己什么样的帮助,为什么需要对方共同协作才能完成整个软件项目的开发呢?希望这篇…

Scrapy框架内置管道之图片视频和文件(一篇文章齐全)

1、Scrapy框架初识(点击前往查阅) 2、Scrapy框架持久化存储(点击前往查阅) 3、Scrapy框架内置管道 4、Scrapy框架中间件(点击前往查阅) Scrapy 是一个开源的、基于Python的爬虫框架,它提供了…

2015年五一杯数学建模B题空气污染问题研究解题全过程文档及程序

2015年五一杯数学建模 B题 空气污染问题研究 原题再现 近十年来,我国 GDP 持续快速增长,但经济增长模式相对传统落后,对生态平衡和自然环境造成一定的破坏,空气污染的弊病日益突出,特别是日益加重的雾霾天气已经干扰…

从0开始学习JavaScript--JavaScript对象继承深度解析

JavaScript中的对象继承是构建灵活、可维护代码的关键部分。本文将深入讨论JavaScript中不同的继承方式,包括原型链继承、构造函数继承、组合继承等,并通过丰富的示例代码展示它们的应用和差异。通过详细解释,大家可以更全面地了解如何在Java…

Shopee如何入驻?如何防封?

Shopee作为东南亚领航电商平台,面向东南亚蓝海市场,近年来随着东南亚市场蒸蒸日上,虾皮也吸引了大批量的跨境商家入驻。那么接下来就给想要入驻的虾皮小白一个详细的安全入驻教程。 一、商家如何入驻 虾皮与LAZADA最大的区别就是商家即卖家&…

RT-DETR改进 | 2023 | InnerEIoU、InnerSIoU、InnerWIoU、InnerDIoU等二十余种损失函数

论文地址:官方Inner-IoU论文地址点击即可跳转 官方代码地址:官方代码地址-官方只放出了两种结合方式CIoU、SIoU 本位改进地址: 文末提供完整代码块-包括InnerEIoU、InnerCIoU、InnerDIoU等七种结合方式和其AlphaIoU变种结合起来可以达到二十…

15、矩阵键盘密码锁

矩阵键盘密码锁 main.c #include <REGX52.H> #include "Delay.h" #include "LCD1602.h" #include "MatrixKey.h"//初始化变量 unsigned char KeyNum; unsigned int Password,Count;void main() {//LCD屏幕初始化显示Password:LCD_Init();…

kafka的详细安装部署

简介&#xff1a; Kafka是一个分布式流处理平台&#xff0c;主要用于处理高吞吐量的实时数据流。Kafka最初由LinkedIn公司开发&#xff0c;现在由Apache Software Foundation维护和开发。 Kafka的核心是一个分布式发布-订阅消息系统&#xff0c;它可以处理大量的消息流&#…

matplotlib,DLL load failed: 找不到指定的模块

问题&#xff1a;import matplotlib mportError: DLL load failed: 找不到指定的模块 &#xff08;2023年11月28日&#xff09; 解决方法&#xff1a;具体是matplotlib版本不匹配&#xff0c;而且在线pip install numpy时因为在线下载numpy库中缺少DLL。 应该下载带有mkl的num…

利用ogr2ogr从PostGIS中导出/导入Tab/Dxf/Geojson等格式数据

ogr2ogr Demo Command 先查看下当前gdal支持的全部格式&#xff0c;部分gdal版本可能不支持PostGIS。 如出现PostgreSQL表名支持。 #全部支持的格式 ogrinfo --formats | sort #AVCBin -vector- (rov): Arc/Info Binary Coverage #AVCE00 -vector- (rov): Arc/Info E00 (ASC…

居家适老化设计第三十三条---卫生间之暖风

居家适老化是指为了满足老年人居住需求而进行的住房改造&#xff0c;以提供更加安全、舒适、便利的居住环境。在居家适老化中&#xff0c;暖风系统是一个重要的考虑因素。暖风系统可以提供温暖舒适的室内温度&#xff0c;对老年人来说尤为重要。老年人常常身体机能下降&#xf…

PHPExcel 导出Excel报错:PHPExcel_IOFactory::load()

背景 近期在做 excel文件数据导出时&#xff0c;遇到如下报错&#xff1a; iconv(): Detected an illegal character in input string场景&#xff1a;计划任务后台&#xff0c;分步导出 大数据 excel文件发现在加载文件时&#xff0c;会有报错 报错信息 如下&#xff1a; {&q…

Elasticsearch初识--CentOS7安装ES及Kibana

文章目录 一&#xff0e;前言二&#xff0e;介绍1.Elasticsearch2.Kibana 三&#xff0e;ES安装1.下载安装包2.解压、配置2.1 解压2.2 配置 3.启动3.1增加用户3.2启动 4.解决资源分配太少问题5.启动成功 四&#xff0e;Kibana安装1.下载安装包2.解压、配置2.1 解压2.2 配置2.2 …