sklearn之线性回归——以上证红利指数为例

news2025/1/23 17:25:38

文章目录

    • 线性回归
      • 概念
      • 使用sklearn实现上证中立指数预测
        • 内置数据集的加载与处理
      • 外部数据集的加载和处理
        • 数据内容
        • 数据加载和处理
      • 开始预测
        • 分割数据集
        • 导入线性回归模型
        • 查看线性回归模型的系数
        • 绘制预测结果
        • 预测效果评估
      • 最终代码

线性回归

线性回归(Linear Regression)模型是最简单的线性模型之一,很具代表性

概念

我们在高中时代其实就学过使用最小二乘法进行线性回归分析

这实际上是统计学部分的内容,会有大量的自变量,或者说解释变量,还有就是对应的因变量,也就是输出结果,回归分析就是找出他们对应的关系,并且使用某个模型描述出来,这样一来给出新的变量,就能利用模型实现预测

这也就是我们一开始介绍机器学习说明的过程,给出输入和输出,找到一个模型T能够很好的拟合这些数据,从而使用T就能预测结果了

从几何层面,回归就是找到具有代表性的直线、曲线、甚至是面,来进行拟合

回归的种类有很多,一元和多元,那么一元其实就是线性回归。我们这里先讨论线性回归,而且我们假设因变量和自变量之间是满足线性关系的,也就是 y = w 0 + w 1 x y=w_0+w_1x y=w0+w1x

这里的 w 0 w_0 w0 w 1 w_1 w1我们称之为回归系数,我们需要拟合的,求出来的就是这两个权值,一个经典的示意图是这样的

image.png

这里的每一个点就是实际的数据,红色的线是我们拟合出来的,很容易可以看得到,有些点离线近,有些点离线远,我们使用残差(Residual)来描述这里的远和近,也就是误差,简单说就是从点向x轴做垂线与拟合线相交的点的距离就是残差, ϵ = ∣ y ^ i − y i ∣ \epsilon=|\hat y_i-y_i| ϵ=y^iyi

这里的小帽子表示的是预测数据,就是不准的意思,没啥难理解的

那么我们的目标就变成了,要求一条拟合的线,让所有的误差最小

这里的思想就是使用最小二乘法(Ordinary Least Squares,OLS)了,就是要让残差的平方和最小即可,那我们的损失函数就可以变成这样了 H = ∑ i = 1 m ( y ^ i − y i ) 2 = ∑ i = 1 m ( y i − w 1 x i − w 0 ) 2 H=\sum_{i=1}^{m}(\hat y_i-y_i)^2=\sum_{i=1}^{m}(y_i-w_1x_i-w_0)^2 H=i=1m(y^iyi)2=i=1m(yiw1xiw0)2

以上就是求解这两个参数,也就是求一个二元函数 H ( w 0 , w 1 ) H(w_0,w_1) H(w0,w1)的最小值,然后取出对应的 w 0 w_0 w0 w 1 w_1 w1即可

事实上我们也可以利用优化算法(随机梯度下降法、牛顿迭代法)来快速逼近最优参数

使用sklearn实现上证中立指数预测

内置数据集的加载与处理

以导入波士顿房价数据集为例

form sklearn.datasets import load_boston

这里的boston可以换成别的数据集

名称数据集
load_boston波士顿房价
load_breast_cancer乳腺癌
load_iris鸢尾花
load_diabetes糖尿病
load_linnerud体能训练
load_wine红酒品类

然后对应的就是数据处理的部分了

boston = load_boston() # boston是一个字典对象,我们可以使用key方法查看他对应的属性值

在取出来字典之后,我们就可以进行数据预处理和分析了

外部数据集的加载和处理

我们首先需要收集数据,这里我们直接在官网可以下载上证红利指数

数据内容

image.png

这里我把所有的非数值类型的数据全部删除了,这里是五年的数据

那么这里的特征值有,开盘,最高,最低,收盘,涨跌,涨跌幅,成交量,成交金额

数据加载和处理

下载之后我们获取到的就是一份表格文件了,下载可能是xlsx格式的,可以另存为csv格式的,方便处理

我这里使用pandas进行读取和预处理工作

import pandas as pd
file_path = './000015perf.csv'

data = pd.read_csv(file_path)

这里使用read_csv直接读取的data是DataFrame类型的数据了

如果我们使用的是内置数据,就要通过pd.DataFrame(boston.data)来转换成DataFrame类型

之后我们可以给他加上标签,数据清洗等操作

# 删除含有缺失值的行
data_clean = data.dropna()

我们直接把有缺失的情况给扔掉

开始预测

分割数据集

正如我们前面所说,我们至少要把整个数据集分割成两部分,训练集和测试集,为了保证数据分割的随机性和专业性,sklearn提供了专门的分割函数,train_test_split

我们直接读取的内容是预测的结果y,我们称之为标签数据,和特征值x,我们称之为特征数据

这个专门的分割函数是要求特征数据和标签数据必须是分开的,我们可以使用pandas的drop方法去除

X = data_clean.drop(columns=['日期Date', '涨跌Change', '涨跌幅(%)Change(%)'])
y = data_clean['涨跌Change']

这里我们删除了日期,因为没啥太大作用,还有可能影响预测结果的涨跌和涨跌幅度,并且把涨跌作为预测的对象

有一个细节是特征数据一般用大写的X,特征值一般用小写的y

接下来就是进行训练集和测试集的分割

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=20)

test_size是表示测试集合所占的比例,random_state表示随机的状态

这里的随机状态其实是相反的意思,就是为了保证某些数据是固定的,因为一旦全随机可能会导致预测不够准确,从而无法调参,而固定下一部分数据作为训练集和测试集的话是提供了一定的稳定性,这种稳定性也方便了调参的进行

导入线性回归模型

在数据分割完成之后,我们就可以导入线性回归模型,训练数据并且进行模型预测了

from sklearn.linear_model import LinearRegression

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

# 使用模型进行预测
y_pred = model.predict(X_test)

在sklearn中,训练模型的方法统称为fit

回归分析属于监督学习,所以fit提供两个参数,前者是特征数据,后者是标签数据

查看线性回归模型的系数

线性回归的核心目的就是找到关键的参数,我们可以通过print直接输出查看每个特征的权值

print("w0 = ", model.intercept_)
print("W = ", model.coef_)

image.png

我们之前所有的特征值去除影响之后共计6个特征,所以至少有6个权值,再加上w0是截距,应该是7个权值

对于这些权值我们也可以做出一些解释,例如第一个权值对应的是开盘价格,那其实说明开盘价格越高跌的概率就越大,第四个是收盘价格,那其实也很好说明问题了

出现这样直白的结果其实是由于我们特征数据类型收集的不够多,或者是不够具备我们想要研究的特征数据

绘制预测结果

我们可以使用matplotlib来进行预测涨跌和实际涨跌的对比

plt.figure(figsize=(10, 6))
sns.regplot(x=y_test.values, y=y_pred, scatter_kws={'color': 'blue'}, line_kws={'color': 'red', 'linewidth': 2})
plt.xlabel('Actual Stock Change')
plt.ylabel('Predicted Stock Change')
plt.title('Seaborn Regression Plot of Actual vs Predicted Stock Change')
plt.grid(True)
plt.show()

image.png

从结果上看基本都集中在直线附近,还是比较准确的

预测效果评估

由于回归分析的目标值是连续的,所以我们不能用准确率来评估,而应该比较预测值和实际值的差值评估,其中均方根误差(root-mean-square error、RMSE)是最常见的评估标准之一

R M S E = ∑ i = 1 n ( P r e d i c t i − A c t u a l i ) 2 n RMSE=\sqrt\frac{\sum_{i=1}^{n}(Predict_i-Actual_i)^2}{n} RMSE=ni=1n(PredictiActuali)2

还有一个是R方分数,表示预测数据和实际数据的相关性,范围是从0到1,越大表示相关性越好

$ R^2 = 1 - \frac{SS_{res}}{SS_{tot}} $
其中:

  • ( S S r e s SS_{res} SSres ) 是残差平方和(Sum of Squares of the Residuals),它衡量了模型预测值与实际值之间的差异。
  • ( S t o t S_{tot} Stot ) 是总平方和(Total Sum of Squares),它衡量了实际值与平均值之间的差异。
    更具体地说,这些平方和的计算方式如下:
    S S r e s = ∑ ( y i − y ^ i ) 2 SS_{res} = \sum (y_i - \hat{y}_i)^2 SSres=(yiy^i)2
    S S t o t = ∑ ( y i − y ˉ ) 2 SS_{tot} = \sum (y_i - \bar{y})^2 SStot=(yiyˉ)2
from sklearn.metrics import mean_squared_error, r2_score

print(f'Mean Squared Error (MSE): {mse}')
print(f'R^2 Score: {r2}')

image.png

当然如果我们想查看线性回归输出的预测涨跌和实际涨跌的对比情况,也可以很容易的实现

df = pd.DataFrame({'实际涨跌':y_test, '预测涨跌':y_pred})
print(df)

image.png

最终代码

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns

file_path = './000015perf.csv'

data = pd.read_csv(file_path)

# 删除含有缺失值的行
data_clean = data.dropna()

# 删除日期列,因为它对预测可能没有直接作用
X = data_clean.drop(columns=['日期Date', '涨跌Change', '涨跌幅(%)Change(%)'])
y = data_clean['涨跌Change']

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=20)

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

# 使用模型进行预测
y_pred = model.predict(X_test)

# 计算预测的均方误差和R^2分数
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

plt.figure(figsize=(10, 6))
sns.regplot(x=y_test.values, y=y_pred, scatter_kws={'color': 'blue'}, line_kws={'color': 'red', 'linewidth': 2})
plt.xlabel('Actual Stock Change')
plt.ylabel('Predicted Stock Change')
plt.title('Seaborn Regression Plot of Actual vs Predicted Stock Change')
plt.grid(True)
plt.show()

print("w0 = ", model.intercept_)
print("W = ", model.coef_)


print(f'Mean Squared Error (MSE): {mse}')
print(f'R^2 Score: {r2}')

df = pd.DataFrame({'实际涨跌':y_test, '预测涨跌':y_pred})
print(df)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1665902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux·基本指令

从本节开始将新开一个关于Linux操作系统的板块,其实Linux也没什么太神秘的,就是一个操作系统(OS)嘛,跟Windows操作系统是一个概念,只不过Windows中的大部分操作都是用光标点击来进行人机交互,但是Linux是通过输入命令行…

AIGC、LLM 加持下的地图特征笔记内容生产系统架构设计

文章目录 背景构建自动化内容生产平台系统架构设计架构详细设计流程介绍笔记来源笔记抓取干预 笔记 AIGC 赋能笔记 Rule 改写笔记特征库构建 附录Bash Cron 定时任务Golang 与 Pyhon AIGC 实践 小结 背景 在大模型的浪潮下,ChatGPT、Sora、Gemini、文言一心 等新技…

LoRaWAN入门

1.文档资料 飞书云文档 (feishu.cn) G43室内LoRaWAN网关 - doc.alinkwise.com > LoRaWAN网关(基站) > G4x > G43室内LoRaWAN网关 2.简介 LoRa: 远距离无线电(long rang radio), 它最大特点就是在同样的功耗条件下比其他无线方式…

《构建合同中台系统:实现合同管理的集成化与智能化》

随着企业数字化转型的深入推进,合同管理作为企业日常运营的重要组成部分,也在不断演进与升级。传统的合同管理方式已经无法满足企业对于效率、合规性和智能化的需求,因此,构建合同中台系统成为了当下企业迫切需要解决的问题。 **1…

Vue中进行粘贴板粘贴数据(图片、文字等)

在页面中如果需要进行粘贴数据,那么就要读取系统粘贴板clipboard,通过此Api来进行粘贴板数据的操作。 目录: 一.封装相关函数1.示例代码:2.代码解释: 二.页面中进行粘贴1.代码示例:2.代码解释: 三.运行结果…

C数据结构:队列

目录 队列是什么? 队列的实现 队列的数据结构 队列的初始化 队列的插入 队列的删除 获取队列队头元素 获取队列队尾元素 获取队列元素个数 检查队列是否为空 队列的销毁 队列的使用 完整代码 队列是什么? 队列也是顺序表中的一种 队列和栈…

Python-VBA函数之旅-staticmethod函数

目录 一、staticmethod函数的常见应用场景 二、staticmethod函数使用注意事项 三、如何用好staticmethod函数? 1、staticmethod函数: 1-1、Python: 1-2、VBA: 2、推荐阅读: 个人主页: https://blog…

Linux(Ubuntu24.04) 安装 MinIO

本文所使用的 Ubuntu 系统版本是 Ubuntu 24.04 ! # 1、下载 MinIO wget https://dl.min.io/server/minio/release/linux-amd64/minio# 2、添加可执行权限 chmod x minio# 3、导出环境变量,用于设置账号密码,我设置的账号和密码都是 minioadmin export MI…

【数据结构练习题】Map与Set——1.只出过一次的数字2.复制带随机指针的链表3.宝石与石头4.坏键盘打字

♥♥♥♥♥个人主页♥♥♥♥♥ ♥♥♥♥♥数据结构练习题总结专栏♥♥♥♥♥ ♥♥♥♥♥【数据结构练习题】堆——top-k问题♥♥♥♥♥ 文章目录 1.只出过一次的数字1.1问题描述1.2思路分析1.3绘图分析1.4代码实现2.复制带随机指针的链表2.1问题描述2.2思路分析2.3绘图分析2.4代…

Android解放双手的利器之ViewBinding

文章目录 1. 背景2. ViewBinding是什么3. 开启ViewBinding功能4. 生成绑定类5. 使用ViewBinding5.1Activity 中使用5.2 Fragment 中使用5.3 ViewHolder 中使用 6. ViewBinding的优点7. 与 dataBinding 对比 1. 背景 写代码最繁琐的是什么?重复的机械操作。我们刚接…

【AI+老照片焕新】母亲节用AI把时间的印记变成暖心礼物

想念是一张泛黄的照片,藏在抽屉里的笑容,总是那么亲切。今天是母亲节,是不是想给妈妈来点不一样的惊喜?用AI技术,把那些老照片瞬间焕新,让妈妈的青春记忆重放光华! 想象一下,妈妈年…

vue3vue3vue3vue3vue3vue3vue3vue3vue3vue3vue3vue3

纯vue3的语法 一.创建(基于vite) 1.在指定目录下运行 npm create vuelatest 项目名称:英文小写下划线数字回车表示确定是、否 左右切换路由、pina、单元测试、端到端的测试、开启eslint控制代码质量 先选择no,学的时候自己手动…

4---自动化构建代码(逻辑梳理,轻松理解)

一、需求引出: 在使用编译器编译代码时,无论我们在一个项目中写了多少个文件(包括头文件、源文件),我们都可以一键完成编译,编译器会自动处理各个文件之间的包含,调用关系。但是在Linux中,我们在一个目录下…

Docker in Docker(DinD)原理与实战

🐇明明跟你说过:个人主页 🏅个人专栏:《Docker幻想曲:从零开始,征服容器宇宙》 🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、Docker简介 2、Docker …

Kubernetes学习-深入Pod篇(一) 创建Pod,Pod配置文件详解

🏷️个人主页:牵着猫散步的鼠鼠 🏷️系列专栏:Kubernetes渐进式学习-专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 1.前言 我们在前面的文章讲解了Kubernetes的核心概念和服务部署&#x…

08.3.grafana自定义图形

grafana自定义图形 找插件里面的zabbix 点击update 数据源—zabbix数据源,添加zabbix数据源 选择zabbix类型 我这里配置的是本地,所以URL直接localhost 这里配置zabbix登录账号密码Admin/zabbix 然后点击保存并测试,会直接显示版本 导入模板&…

JavaSE——集合框架一(1/7)-集合体系概述(集合体系结构,Collection集合体系)、Collection的常用方法(介绍,实例演示,代码)

目录 集合体系概述 集合体系结构 Collection集合体系 Collection的常用方法 介绍 实例演示 完整代码 集合体系概述 集合体系结构 集合是一种容器,用来装数据的,类似于数组,但集合的大小可变,开发中也非常常用。 为了满足…

ACM 的代码编码示例

写在最前面的 实践的顺序, 应该是先将基础的 数据结构题目类型给实现。 然后再开始尝试 实现对应类型的算法题目,如回溯算法, 贪心算法, 动态规划, 图论; 基础的数据结构, 推荐卡尔的&#xff…

【C++】vector的底层原理讲解及其实现

目录 一、认识vector底层结构 二、初始化vector的函数 构造函数拷贝构造赋值构造initializer_list构造迭代器区间构造 三、迭代器 四、数据的访问 五、容量相关的函数 六、关于数据的增删查改操作 一、认识vector底层结构 STL库中实现vector其实是用三个指针来完成的&#x…

PY32F403系列单片机,32位M4内核MCU,主频最高144MHZ

PY32F403系列单片机是基于Arm Cortex-M4核的32位通用微控制器产品。内置的FPU和DSP功能支持浮点运算和全部DSP指令。通过平衡成本,性能,功耗来获得更好的用户体验。 PY32F403单片机典型工作频率可达144MHZ,内置高速存储器,丰富的…