线性回归算法(含示例代码)

news2024/11/16 3:38:10

1 知识点讲解

1.1 线性回归

线性回归是一种常见的机器学习算法,用于预测连续型变量。该算法的目标是建立一个线性模型,根据输入的自变量来预测一个连续型的因变量

在线性回归中,我们假设因变量(也称为响应变量)与自变量之间存在线性关系。这意味着我们可以使用一条直线来拟合数据,并使用这条直线来进行预测。线性回归可以用于单个自变量或多个自变量的情况。

线性回归的目标是通过最小化预测值与实际值之间的差异来确定最佳拟合直线的参数。这个差异通常被称为残差,我们使用最小二乘法来计算残差和拟合直线的参数。最小二乘法可以使得残差的平方和最小,从而得到最佳的拟合直线。

最小二乘法是一种常用的线性回归算法,用于确定最佳拟合直线的参数。它的目标是通过最小化预测值与实际值之间的差异来确定最佳拟合直线的参数。

设有 n n n 个样本,每个样本都有一个自变量 x i x_i xi 和一个因变量 y i y_i yi。我们假设因变量 y i y_i yi 与自变量 x i x_i xi 之间存在线性关系,即

y i = β 0 + β 1 x i + ϵ i y_i = \beta_0 + \beta_1 x_i + \epsilon_i yi=β0+β1xi+ϵi
其中 β 0 \beta_0 β0 β 1 \beta_1 β1 是我们要求解的拟合直线的参数, ϵ i \epsilon_i ϵi 是模型中的误差项。我们的目标是找到最佳的 β 0 \beta_0 β0 β 1 \beta_1 β1,使得预测值与实际值之间的差异最小。

最小二乘法的基本思想是最小化残差平方和,即

∑ i = 1 n ( y i − y i ^ ) 2 \sum_{i=1}^{n} (y_i - \hat{y_i})^2 i=1n(yiyi^)2
其中 y i ^ \hat{y_i} yi^ 表示用拟合直线预测的值,可以表示为

y i ^ = β 0 + β 1 x i \hat{y_i} = \beta_0 + \beta_1 x_i yi^=β0+β1xi
我们要求解的是使得残差平方和最小的 β 0 \beta_0 β0 β 1 \beta_1 β1。为了求解这个问题,我们需要对残差平方和进行求导,并令导数等于零。这样就可以得最小化残差平方和的解,即

β 1 = ∑ i = 1 n ( x i − x ˉ ) ( y i − y ˉ ) ∑ i = 1 n ( x i − x ˉ ) 2 \beta_1 = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2} β1=i=1n(xixˉ)2i=1n(xixˉ)(yiyˉ)

β 0 = y ˉ − β 1 x ˉ \beta_0 = \bar{y} - \beta_1 \bar{x} β0=yˉβ1xˉ

其中, x ˉ \bar{x} xˉ y ˉ \bar{y} yˉ 分别为自变量和因变量的平均值。

这些公式可以用矩阵形式表示为:

[ β 0 β 1 ] = ( X T X ) − 1 X T Y \begin{bmatrix} \beta_0 \\ \beta_1 \end{bmatrix} = (X^TX)^{-1}X^TY [β0β1]=(XTX)1XTY

其中,X 是一个nX2的矩阵,第一列为 1,第二列为自变量 $;Y 是 nX1 的矩阵,表示因变量 y。

1.2 多元线性回归及评估

1.2.1 多元线性回归模型

多元线性回归是一种用于分析多个自变量和一个因变量之间关系的线性回归模型。它可以用于预测因变量的值,给定多个自变量的值。多元线性回归的模型可以表示为:
y = β 0 + β 1 x 1 + β 2 x 2 + ⋯ + β n x n + ϵ y = \beta_0 + \beta_1x_1 + \beta_2x_2 + \cdots + \beta_nx_n + \epsilon y=β0+β1x1+β2x2++βnxn+ϵ
其中, y y y 是因变量, x 1 , x 2 , ⋯   , x n x_1, x_2, \cdots, x_n x1,x2,,xn n n n 个自变量, β 0 , β 1 , β 2 , ⋯   , β n \beta_0, \beta_1, \beta_2, \cdots, \beta_n β0,β1,β2,,βn 是模型的系数, ϵ \epsilon ϵ 是误差项。

多元线性回归的目标是找到最佳的系数 β 0 , β 1 , β 2 , ⋯   , β n \beta_0, \beta_1, \beta_2, \cdots, \beta_n β0,β1,β2,,βn,使得模型预测值与真实值之间的平方误差最小。这个过程通常使用最小二乘法进行求解。

1.2.2 评估

涉及总离差平方和的分解为回归平方和+残差平方和,这部分知识可根据需要查阅,

  • 拟合优度R^2(判定系数、决定系数、样本可决系数)

R 2 = 1 − ∑ i = 1 n ( y i − y i ^ ) 2 ∑ i = 1 n ( y i − y ˉ ) 2 R^2 = 1 - \frac{\sum_{i=1}^{n}(y_i-\hat{y_i})^2}{\sum_{i=1}^{n}(y_i-\bar{y})^2} R2=1i=1n(yiyˉ)2i=1n(yiyi^)2

范围:0到1(闭区间)

越接近1,说明样本回归线对样本值的拟合优度约好,X对Y的解释能力越强。

Adj.R-square:Adj为adjust的缩写,代表调整之后的决定系数。对模型复杂度进行了惩罚,以更准确地评估模型的预测能力。
A d j . R 2 = 1 − ( 1 − R 2 ) ( n − 1 ) n − p − 1 Adj. R^2 = 1 - \frac{(1-R^2)(n-1)}{n-p-1} Adj.R2=1np1(1R2)(n1)
其中, n n n 是样本数量, p p p 是自变量数量。Adj. R 2 R^2 R2 具有以下特点:

  1. 当模型中的自变量数量增加时,Adj. R 2 R^2 R2 的值会减小,因为增加自变量会增加模型的复杂度,需要对其进行惩罚。

  2. 当模型中的自变量数量为1时,Adj. R 2 R^2 R2 R 2 R^2 R2 的值相等。

  3. R 2 R^2 R2 一样,Adj. R 2 R^2 R2 的取值范围为 [ 0 , 1 ] [0,1] [0,1],其值越接近1,表示模型对数据的拟合越好。

Adj. R 2 R^2 R2 是一种常用的模型评估指标,在多元线性回归中有广泛的应用。它可以帮助我们更加准确地评估模型的预测能力,避免过度拟合和欠拟合等问题。

  • 显著性检验

这部分涉及到数理统计这门课程,头大!快去了解一下吧!

2 简单示例及函数介绍

2.1 sklearn.linear_model&np.array().reshape()

sklearn.linear_model 是 Scikit-learn 模块中的一个子模块,用于实现各种线性模型。该模块中包含许多用于线性回归、岭回归、Lasso回归等任务的类和函数,是Scikit-learn中最基础、最常用的模块之一。
在这里插入图片描述
这里我们使用第一个最小二乘法线性回归模型,下方看个示例。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
# 数据集
X = np.array([2, 4, 6, 8, 10]).reshape(-1,1)
Y = [5, 9, 11, 14, 18]
# 模型
model = LinearRegression()
model.fit(X,Y)
# 预测
y_pred=model.predict(X)
# 方程参数
print("系数:",model.coef_)
print("截距:",model.intercept_)
# 绘图
plt.scatter(X,Y)
plt.plot(X,y_pred)
plt.show()
plt.close()

在这里插入图片描述
可以发现,上面第四行使用了numpy中的reshape函数,解释下。

在这个代码中,使用 reshape() 函数是为了将一维的自变量 x 转换为二维数组。这是因为 Scikit-learn 中的线性回归模型要求自变量是一个二维数组,其中第一维表示样本数,第二维表示特征数。在一元线性回归的情况下,样本数为自变量的长度,特征数为1。

因此,我们需要将自变量 x 转换为一个列向量,即将原数组的行数不变,将列数变为1。这可以通过 reshape(-1, 1) 来实现,其中 -1 表示自动计算行数。同理,看下方示例:

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
print(arr.shape)      # 输出 (5, )
arr_reshaped = arr.reshape(-1, 1)
print(arr_reshaped.shape)  # 输出 (5, 1)

arr = np.array([1, 2, 3, 4, 5])
print(arr.shape)      # 输出 (5, )
arr_reshaped = arr.reshape(1, -1)
print(arr_reshaped.shape)  # 输出 (1, 5)

2.2 statemodels多元线性回归

import numpy as np
import statsmodels.api as sm
# 创建一个随机的多元线性回归数据集
np.random.seed(0)
n = 100
p = 3
X = np.random.randn(n, p)
y = X.dot(np.array([1, 2, 3])) + np.random.randn(n)
# 向 X 中添加常数项
X = sm.add_constant(X)
# 创建一个多元线性回归模型
model = sm.OLS(y, X)
# 拟合模型并输出结果
results = model.fit()
print(results.summary())

在这里插入图片描述

3 实践案例

案例描述:

Walmart是全球最大的连锁超市之一,在全球范围内拥有数千家门店。为了更好地管理门店和制定营销策略,Walmart希望能够预测未来每个门店的销售额,以便做出更好的管理和决策。
数据集:

本实验使用的数据集是Kaggle上的Walmart Dataset (Retail)数据集,包括Walmart公司在2010年2月到2012年10月期间45个店铺的历史销售数据。该数据集包含了多个特征,包括店铺编号、日期、每个店铺每周销售额、是否假期、当日温度、地区燃油价格、消费价格指数(CPI)和失业率等信息。

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
# 读取数据
df = pd.read_csv('data.csv', parse_dates = True, infer_datetime_format = True)
# 整理数据
df.Date=pd.to_datetime(df.Date)
df['weekday'] = df.Date.dt.weekday
df['month'] = df.Date.dt.month
df['year'] = df.Date.dt.year
df.drop(['Date','weekday','year'], axis=1, inplace=True)
target = 'Weekly_Sales'
# 相关性分析和特征选择
sns.pairplot(df, x_vars=['Holiday_Flag', 'Temperature', 'Fuel_Price', 'CPI', 'Unemployment','month'],
            y_vars=['Weekly_Sales'], height=4, aspect=0.7, kind='reg')
plt.show()

在这里插入图片描述

df1=df.copy(deep=True)
df1.drop(['Store'], axis=1, inplace=True)
corr_matrix = df1.corr()
sns.heatmap(corr_matrix, annot=True, cmap='Reds')
plt.show()

在这里插入图片描述

# 发现Fuel_Price和Weekly_Sales的相关系数小于0.01,认为该特征和销售量关系不大
df.drop(['Fuel_Price'], axis=1, inplace=True)
# 分类数据处理
df = pd.get_dummies(df, columns=[ 'Store','month'])
# 提取特征和标签
X = df.drop([target],axis=1)
y = df[target]
# 划分训练集和测试集
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=100,test_size=0.2)
# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# 训练多元线性回归模型
model = LinearRegression()
model.fit(X_train,y_train)
# 测试集预测
y_pred = model.predict(X_test)
# 评估模型预测效果
print("系数:",model.coef_)
print("截距:",model.intercept_)
print('R平方值(R^2):', r2_score(y_test, y_pred))

参考链接:
机器学习 | 使用statsmodels和sklearn进行回归分析_sklearn statsmodels_育种数据分析之放飞自我的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/642667.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

近80%企业首选——亚马逊云科技为中国企业出海保驾护航

随着全球数字化进程的不断加速,中国出海“大航海时代”已然到来。从#万企组团出国抢订单#到#苏州赴日包机抢单20亿元#,中国企业对海外市场的优势已经一步步建立了起来。 从卖小商品、卖鞋的“世界工厂”,到现在产业升级后的卖汽车、卖服务、…

抖音seo矩阵系统源码|需求文档编译说明(一)

抖音seo矩阵系统文章目录技术囊括 ①产品原型 ②需求文档 ③产品流程图 ④部署方式说明 ⑤完整源码 ⑥源码编译方式说明 ⑦三方框架和SDK使用情况说明和代码位置 ⑧平台操作文档 ⑨程序架构文档 短视频矩阵系统源码开发锦囊囊括前言一、短视频账号矩阵系统开发者必备能力语言&…

招标投标管理微信小程序解决方案

招投标管理微信小程序是一种基于微信公众平台构建的在线招投标管理平台,适用于各类招投标项目管理,通过小程序内的功能实现投标、查看、评估和管理等各项业务。下面我们来了解一下招投标管理微信小程序的具体功能和应用情况。 招投标管理微信小程序的功能…

App 启动速度优化

前言​​​​​​​ APP打开的一瞬间速度快慢;就好比人的第一印象,快速的打开一个应用往往给人很舒服的体验。app经常性卡顿启动速度很慢,这无疑是对用户的流失。 启动方式介绍 APP启动的方式分为3种:冷启动、热启动、温启动。…

28.vite

目录 1 一些概念 1.1 单页面应用程序SPA 1.2 vite 2 初始化vite项目 3 项目中的文件 1 一些概念 1.1 单页面应用程序SPA 单页面应用程序是只有一个页面的前端,切换页面通过前端路由来切换 特点如下 实现了前后端分离,后端仅出接口&#…

Flink TableAPI window and watermarket

序言 本次主要是弄清楚.批流统一 的处理方式,因为它是使用SQL来操作批流计算的.所以它怎么设置算子并行度?如何设置窗口?如何处理流式数据?等等 有很多疑问. 我还是觉得直接使用流计算的API更好.流批一体API最终也是转换成流式计算,最主要的是使用sql来设置算子或者窗口,并…

python合并多个excel,每个excel中有相同的列,按指定列名将数据列合并到一起。以统计学生多个作业提交情况为例。

一、实现目标 有多个excel文件,每个excel文件是一次学生作业的提交情况,最终统计出所有学生所有作业的提交情况。具体格式和内容如下: excel1.xlsx excel2.xlsx excel3.xlsx: 最后统计出所有学生提交的所有作业的情况: 二、实现思路

C# 自动备份文件

目录 文件目录如下 APBackUpFiles app.config OracleHelper LocalFileMethods LogFile packages.config ReadFile 如何发布 在工作的时候,遇到了需要定时对服务器的文件进行备份的需求,原因是 AP(服务器)上的空间不够了&a…

遗传算法解决TSP旅行商问题(numpy、pandas)

努力是为了不平庸~ 学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。 目录 一、引言 原理: 问题: 二、思路步骤 三、代码编写步骤 A、代码各步骤的方法、目的及意义 1. 导入所需的库&…

测试人,你凭什么脱颖而出?

我们在软件测试面试时,可能经常会碰到HR这样问“与其他竞争者相比,你认为自己的优势在哪里?” 看似简单,但仔细深思可能心理陡然冰冰凉,因为自己难以有信心比他人突出(除了腰间盘),看…

DBA 抓包神器 tshark 测评

想窥探神秘的网络世界的奥秘,tshark 助你一臂之力! 作者:赵黎明 爱可生 MySQL DBA 团队成员,熟悉 Oracle、MySQL 等数据库,擅长数据库性能问题诊断、事务与锁问题的分析等,负责处理客户 MySQL 及我司自研 D…

chatgpt赋能python:Python火了原因分析

Python火了原因分析 Python语言是近年来最热门的编程语言之一,有很多原因可以解释它的成功。本文将介绍三个最重要的原因,以及如何利用这些原因来提高您的Python编程技能。 Python具有易学性和流行的库 Python的设计使它非常容易学习,尤其…

【机器学习】十大算法之一 “朴素贝叶斯”

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?typeblog个…

Visual Studio Code Arduino资源占用和效率对比

Visual Studio Code&Arduino资源占用和效率对比 系统资源占用:编译效率: 这段时间在玩ESP32,闲来无事对比了一下Visual Studio Code后面简称VS和Arduino的效率和资源占用,只是大致的对比,没有斤斤计较。 配置为&am…

springboot集成swagger

文章目录 swagger概述swagger常用注解ApiImplicitParam swagger的集成方式集成swagger2.9集成swagger2.10集成swagger3 swagger概述 swagger是当下比较流行的实时接口文文档生成工具。接口文档是当前前后端分离项目中必不可少的工具,在前后端开发之前,后…

Mysql数据库初体验

Mysql数据库初体验 一、数据库的基本概念1.数据(Data)2.表3.数据库4.数据库管理系统(DBMS)5.数据库系统 二、数据库系统发展史1.第一代数据库2.第二代数据库3.第三代数据库 三、当今主流数据库介绍四、数据库分类1.关系数据库2.关系型 SQL 数…

前端教程:Canvas怎样创建画布和绘制图形?

HTML5提供了一种全新的画布功能,即通过Canvas来让用户在网页中绘制图形、文字、图片等。Canvas表示画布,现实生活中的画布是用来作画的,HTML5中的Canvas与之类似,我们可以称它为“网页中的画布”。默认情况下,Canvas是…

【MySQL高级篇笔记-锁(下) 】

此笔记为尚硅谷MySQL高级篇部分内容 目录 一、概述 二、MySQL并发事务访问相同记录 1、读-读情况 2、写-写情况 3、读-写或写-读情况 4、并发问题的解决方案 三、锁的不同角度分类 1、从数据操作的类型划分:读锁、写锁 2、从数据操作的粒度划分&#xf…

攻防渗透第四章(谷歌语法)

一、常用谷歌黑客语法 制定网站的URL site: 包含特定字符的URL inurl: 网页标题中包含特定字符 intitle: 正文中指定字符 intext: 指定类型文件 filetype 开发语言判断 site:163.com filetype:php site:163.com filetype:jsp site:163.com filetype:asp site:163.com filetype…

工具篇--4 消息中间件-RabbitMq 模型介绍

1 介绍: RabbitMQ 是一个开源的消息中间件,它实现了 AMQP(高级消息队列协议)标准,并且支持多种语言和操作系统,包括 Java、Python、Ruby、PHP、.NET、MacOS、Windows、Linux 等等。RabbitMQ 提供了可靠的消息传递机制…