【python】基于岭回归算法对学生成绩进行预测

news2024/11/20 22:39:03

前言

在数据分析和机器学习领域,回归分析是一种预测连续数值的监督学习技术。当数据特征与目标变量之间存在线性关系时,线性回归模型尤其有用。然而,当特征数量多于样本数量,或者特征之间存在多重共线性时,普通最小二乘法可能不是最佳选择。这时,岭回归(Ridge Regression)作为一种改进的线性回归方法,通过引入正则化项来防止模型过拟合,从而提高模型的泛化能力。

正文

数据加载与预处理

在本例中,我们使用pandas库加载了一个名为data.csv的数据集。数据集被分为特征集X和目标变量y。为了简化问题,我们只取前两列作为特征,并假设第三列是目标变量。

data = pd.read_csv('data.csv')
X = data.iloc[:, :2]  # 取前两列作为特征
y = data.iloc[:, 2]  # 取第三列作为目标变量

接下来,我们使用train_test_split函数将数据集分为训练集和测试集,其中测试集占20%。这样做的目的是为了在模型训练完成后,能够在未见过的数据上评估模型性能。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

在进行模型训练之前,对特征进行标准化是很重要的。这可以通过StandardScaler实现,它将数据缩放到均值为0,标准差为1。

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

模型选择与超参数优化

岭回归是一种通过引入L2正则化项来防止模型过拟合的线性回归方法。正则化项的强度由超参数alpha控制。为了找到最佳的alpha值,我们使用GridSearchCV进行超参数优化。

alpha_candidates = [1e-15, 1e-10, 1e-5, 1e-2, 1, 5, 10, 20]
grid_search = GridSearchCV(estimator=ridge, param_grid={'alpha': alpha_candidates}, cv=5, scoring='neg_mean_squared_error')
grid_search.fit(X_train_scaled, y_train)

GridSearchCV通过交叉验证的方式在给定的参数网格中寻找最佳的参数组合。我们选择了5折交叉验证,并使用负均方误差作为评分指标,因为GridSearchCV默认寻找评分指标的最大值,而均方误差越小越好。

模型训练与评估

在找到最佳的alpha值后,我们使用这个值来训练最终的岭回归模型,并在测试集上进行预测。

best_alpha = grid_search.best_params_['alpha']
ridge_best = Ridge(alpha=best_alpha)
ridge_best.fit(X_train_scaled, y_train)
y_pred = ridge_best.predict(X_test_scaled)

为了评估模型性能,我们计算了均方误差(MSE),这是一个常用的回归评估指标。

mse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error with best alpha: {mse}')

结果可视化

最后,我们通过绘制实际值与预测值的散点图来可视化模型的预测效果。理想情况下,预测值应该与实际值完全一致,即所有点都落在对角线上。

plt.scatter(y_test, y_pred, alpha=0.5)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Ridge Regression Prediction')
plt.plot(lims, lims, 'k--', alpha=0.75, zorder=0)
plt.grid(True)
plt.show()

通过散点图,我们可以直观地看到模型的预测效果。如果大多数点都集中在对角线附近,那么模型的预测效果就比较好。
在这里插入图片描述

总结

本文介绍了如何使用岭回归模型对数据集进行分析,并展示了如何通过超参数优化来提高模型性能。其中使用了GridSearchCV来寻找最佳的alpha值,并使用均方误差作为评估指标。最后,我们通过可视化手段直观地展示了模型的预测效果。岭回归作为一种有效的正则化方法,在处理特征数量多或存在多重共线性的数据集时,能够提高模型的泛化能力。

整体代码

import pandas as pd
import numpy as np
from sklearn.model_se`在这里插入代码片`lection import train_test_split, GridSearchCV
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# 1. 数据加载
data = pd.read_csv('data.csv')
X = data.iloc[:, :2]  # 取前两列作为特征
y = data.iloc[:, 2]  # 取第三列作为目标变量

# 2. 数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# 3. 使用GridSearchCV来优化alpha值
# 定义alpha值的候选范围
alpha_candidates = [1e-15, 1e-10, 1e-5, 1e-2, 1, 5, 10, 20]

# 创建岭回归模型
ridge = Ridge()

# 创建GridSearchCV对象
grid_search = GridSearchCV(estimator=ridge, param_grid={'alpha': alpha_candidates}, cv=5,
                           scoring='neg_mean_squared_error')

# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)

# 获取最佳alpha值
best_alpha = grid_search.best_params_['alpha']
print(f"Best alpha: {best_alpha}")

# 使用最佳alpha值训练模型
ridge_best = Ridge(alpha=best_alpha)
ridge_best.fit(X_train_scaled, y_train)

# 进行预测
y_pred = ridge_best.predict(X_test_scaled)

# 评估模型
mse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error with best alpha: {mse}')

# 注意:这里使用的是负均方误差作为评分指标,因为GridSearchCV默认寻找最大值,而均方误差越小越好,所以取负值。

# 4. 可视化预测结果
plt.scatter(y_test, y_pred, alpha=0.5)  # 绘制实际值与预测值的散点图
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Ridge Regression Prediction')

# 绘制理想情况的对角线
lims = [
    np.min([y_test.min(), y_pred.min()]),  # x轴最小值
    np.max([y_test.max(), y_pred.max()]),  # x轴最大值
]
plt.plot(lims, lims, 'k--', alpha=0.75, zorder=0)
plt.xlim(lims)
plt.ylim(lims)

# 显示图形
plt.grid(True)
plt.show()

66

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1656960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jsp 实验16 MVC 表白墙

源代码以及执行结果截图&#xff1a; ExpressWish_Bean.java package web; import java.util.HashMap; import java.util.ArrayList; import java.util.Iterator; public class ExpressWish_Bean { public HashMap<String,ExpressWish> wishList; ArrayList&…

【AI知识】Stable diffusion常用提示词分享

模型&#xff08;Model&#xff09; majicmixRealistic_v7 majicmixRealistic&#xff08;麦橘写实&#xff09;是融合了多种展现日常生活人物形象的写实风格模型&#xff0c;人物的外观更加接近现实生活&#xff0c;对于光影、皮肤、人物动态均有较好的表现&#xff0c;非常…

Java | Leetcode Java题解之第70题爬楼梯

题目&#xff1a; 题解&#xff1a; public class Solution {public int climbStairs(int n) {double sqrt5 Math.sqrt(5);double fibn Math.pow((1 sqrt5) / 2, n 1) - Math.pow((1 - sqrt5) / 2, n 1);return (int) Math.round(fibn / sqrt5);} }

版本控制工具之Git的基础使用教程

Git Git是一个分布式版本控制系统&#xff0c;由Linux之父Linus Torvalds 开发。它既可以用来管理和追踪计算机文件的变化&#xff0c;也是开发者协作编写代码的工具。 本文将介绍 Git 的基础原理、用法、操作等内容。 一、基础概念 1.1 版本控制系统 版本控制系统&#x…

如何区分APP页面是H5还是原生页面?

刚刚接触手机测试的同学&#xff0c;或多或少都有过这样的疑问&#xff1a;APP页面哪些是H5页面&#xff1f;哪些是原生页面?单凭肉眼&#xff0c;简直太难区分了&#xff01;我总结了6个小技巧&#xff0c;希望能帮大家答疑解惑。 1、看断网的情况 断开网络&#xff0c;显示…

推荐 3 个 yyds 的开源项目!

本期推荐开源项目目录&#xff1a; 1. AI 搜索引擎 2. 大模型聊天框架 3. 模仿抖音的移动端短视频 01 AI 搜索引擎 Perplexica 是一个开源的、由 AI 驱动的搜索引擎。它深入互联网寻找答案&#xff0c;不仅搜索网络&#xff0c;还理解您的问题。 Perplexica 受到 Perplexity AI…

今天来聊聊Numpy!

numpy&#xff1f;what&#xff5e;什么是numpy&#xff1f; 小编先暂且不提。 ​ 大家先暂且看看这句话&#xff0c;“你给我翻 译翻译&#xff0c;什么他妈的是他妈的惊喜&#xff1f; 这还用翻译&#xff0c;都说了… 惊喜嘛……”。 惊喜这段出自《让子…

Could not find the Qt platform plugin “dxcb“ in ““、 重点:是dxcb

这个重点就在于是dxcb不是xcb&#xff0c;让我一顿好找。 https://bugs.launchpad.net/ubuntu/source/deepin-qt5dxcb-plugin/bug/1826629 这篇文章描述了应该是deepin系统的一个问题&#xff0c;应该已经修复了不知道为什么我还会遇到。 不过知道是dxcb后直接去qtcreater里的系…

ICode国际青少年编程竞赛- Python-2级训练场-坐标与列表遍历

ICode国际青少年编程竞赛- Python-2级训练场-坐标与列表遍历 1、 for i in range(5):Flyer[i].step(Dev.x -Flyer[i].x) Dev.step(Item.y - Dev.y)2、 for i in range(7):Flyer[i].step(Dev.y - Flyer[i].y) Dev.step(Item[2].x - Dev.x)3、 for i in range(5):Flyer[i].…

基于若依框架搭建网站的开发日志(一):若依框架搭建、启动、部署

RuoYi&#xff08;基于SpringBoot开发的轻量级Java快速开发框架&#xff09; 链接&#xff1a;开源地址 若依是一款开源的基于VueSpringCloud的微服务后台管理系统&#xff08;也有SpringBoot版本&#xff09;&#xff0c;集成了用户管理、权限管理、定时任务、前端表单生成等…

MYSQL8.0.20安装教程

一&#xff1a;下载mysql MySQL :: Download MySQL Installer (Archived Versions) 二&#xff1a;选中server only&#xff0c;点击next 三&#xff1a;点击server 选项&#xff0c;点击Execute 弹窗点击安装 四&#xff1a;安装项为绿色后&#xff0c;点击next 五&#xf…

大数据中的HDFS读写流程(namenode,datanode)

HDFS读写流程 读取流程 1、客户端请求上传文件 2、namenode检查是否存在&#xff0c;可以上传&#xff0c; 3、客户端请求第一个block块上传到datanode 4、namenode返回3个datanode节点&#xff0c;d1,d2,d3 5、客户端请求dn1调用数据&#xff0c;d1收到请求会继续调用d2&#…

使用海外云手机为亚马逊店铺引流

在全球经济一体化的背景下&#xff0c;出海企业与B2B外贸企业愈发重视海外市场的深耕&#xff0c;以扩大市场份额。本文旨在探讨海外云手机在助力亚马逊店铺提升引流效果方面的独特作用与优势。 海外云手机&#xff0c;一种基于云端技术的虚拟手机&#xff0c;能够在单一硬件上…

经典分类网络LeNet5和VGG16项目:实现CIFAR10分类

CIFAR10分类 v1&#xff1a;LeNet5&#xff1a;2cnn3fc 可视化结果 精确率 损失 最佳 v2&#xff1a;LeNet5&#xff1a;3cnn2fc 可视化结果 精确率 损失 最佳 v3&#xff1a;LeNet5&#xff1a;2cnnbnres3fc 可视化结果 精确率 损失 最佳 v4&#xff1a;VG…

web API设计笔记

Hello , 我是小恒。今晚就讲讲我在开发维护API后的经验分享&#xff0c;当然我知识有限&#xff0c;暂时也不会写实际操作。GitHub项目仓库有一堆还在前期开发&#xff0c;我的时间很多时间投在了开源上。 推荐书籍 我认为一个好的 API 设计是面向用户的&#xff0c;充分隐藏底…

2024华为数通HCIP-datacom最新题库(变题版)

请注意&#xff0c;华为HCIP-Datacom考试831已变题 请注意&#xff0c;华为HCIP-Datacom考试831已变题 请注意&#xff0c;华为HCIP-Datacom考试831已变题 近期打算考HCIP的朋友注意了&#xff0c;如果你准备去考试&#xff0c;还是用的之前的题库&#xff0c;切记暂缓。 H1…

IDEA切换分支

1、选择要切换分支的module 2、右键&#xff0c;选择git 3、再点击branches 4、可以看到当前module的本地分支&#xff08;local Branches&#xff09;及远程分支&#xff08;Remote Branches&#xff09;列表。点击你要切换到的分支,Checkout即可。

01WPS部分编写实现QT

1、新建项目 -创建wps类 -继承QMainWindow 2、菜单栏设置 3、开始实现操作 设置程序图标&#xff1a; pro文件中添加 RC_ICONS images/wps.ico //后面这个是文件地址哈1、字体选择大小设置 void MainWindow::initMainWindow() {// 初始化字号列表项QFontDatabase fontdb;…

Sarcasm detection论文解析 |# 利用情感语义增强型多层次记忆网络进行讽刺检测

论文地址 论文地址&#xff1a;https://www.sciencedirect.com/science/article/abs/pii/S0925231220304689?via%3Dihub#/ 论文首页 笔记框架 利用情感语义增强型多层次记忆网络进行讽刺检测 &#x1f4c5;出版年份:2020 &#x1f4d6;出版期刊:Neurocomputing &#x1f4c8;影…

预测市场?预测股票?如何让预测有更高的准确率?

我们发现在足球赛中&#xff0c;只要知道一个简单的讯息&#xff08;主队过去的获胜机率超过一半&#xff09;&#xff0c;预测力就会明显好过随便乱猜。如果再加上第二个简单的讯息&#xff08;胜负纪录较佳的队伍会略占优势&#xff09;&#xff0c;可以再进一步提升预测力。…