Scikit-LearnTensorFlow机器学习实用指南(三):一个完整的机器学习项目【下】

news2024/9/26 20:16:17

机器学习实用指南(三):一个完整的机器学习项目【下】

作者:LeonG
本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow 机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

本文全部代码和数据集保存在我的github-----LeonG的github

1.回顾

在上一节,我们从网络上获取了数据:housing

然后将数据分为训练集strat_train_set和测试集strat_test_set

将训练标签也就是房价单独分离housing_labels

最后分析了训练集的一些规律,针对这个数据集制作了一个数据整理工具full_pipeline,将训练集strat_train_set转为housing_prepared

经过这些步骤,我们的训练模型只需要输入训练数据housing_prepared和训练标签housing_labels,就可以得到训练好的模型了。

2.训练模型

接下来我们要尝试几种机器学习的算法模型:线性回归模型、决策树模型、随机森林模型。

提示:这些算法模型的具体原理和细节在以后的章节会详细解析,本章只是简单的使用,不用担心不看懂。

2.1线性回归模型

我们先来训练一个线性回归模型,借助sklearn中的LinearRegression类来实现:

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
#输入训练数据进行训练
lin_reg.fit(housing_prepared,housing_labels)

只需要这样简单的三行操作就能训练完一个线性回归模型。

现在使用训练集中的前五行来验证:

#取前五行数据
some_data = housing.iloc[:5]
some_labels = housing_labels.iloc[:5]
#对这些数据进行预测(代入到训练好的模型中计算出预测房价)
some_data_prepared = full_pipeline.transform(some_data)
#模型的预测值
print("Predictions:\t", lin_reg.predict(some_data_prepared))
#数据集的标签值
print("Labels:\t\t", list(some_labels))
Predictions: [210644.60 317768.80 210956.43 59218.98 189747.55]
Labels:      [286600.0, 340600.0, 196900.0, 46300.0, 254500.0]

可以看出数据之间差距还是比较大的,我们计算一下这个回归模型的RMSE

RMSE是均方根误差:
\sqrt{\sum_{i=1}^{n}\frac{1}{n}{(f(x_i)-y_i)}^{2}}
均方根误差的意义大概可以理解为:预测值和实际值的平均差距

当然,我们也不需要手动写公式,直接让sklearn来帮我们算:

from sklearn.metrics import mean_squared_error
#计算出训练集的所有预测值
housing_predictions = lin_reg.predict(housing_prepared)
#计算线性模型的预测值和实际值的均方误差
lin_mse = mean_squared_error(housing_labels, housing_predictions)
#均方根误差=均方误差开方
lin_rmse = np.sqrt(lin_mse)
lin_rmse
68628.19819848923

计算得出线性回归模型的均方根误差很大,这样肯定不行。

结论:对于该数据集来说,线性回归模型是一个欠拟合模型。

2.2决策树模型

修复欠拟合的主要方法是选择一个更强大的模型,接下来试试决策树模型。

决策树模型可以发现数据中复杂的非线性关系。借助sklearn中的DecisionTreeRegressor类来实现:

from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor()
#输入训练数据进行训练
tree_reg.fit(housing_prepared, housing_labels)

这次不取前五行测试了,直接计算这个模型的RMSE:

#计算出训练集的所有预测值
housing_predictions = tree_reg.predict(housing_prepared)
#计算决策树模型的预测值和实际值的均方误差
tree_mse = mean_squared_error(housing_labels, housing_predictions)
#均方根误差=均方误差开方
tree_rmse = np.sqrt(tree_mse)
tree_rmse
0.0

咦,没有误差?这个模型是绝对完美的吗?不对,这是因为模型严重的拟合数据,任何一条训练数据都能得到对应的训练标签。

结论:对于该数据集来说,决策树模型是一个过拟合模型。

3.交叉验证

如何验证模型的真实水平呢?在确定模型之前,我们都不要碰测试集,所以需要用训练集的部分数据来做训练,接下来使用交叉验证法。

3.1K折交叉验证法

K折交叉验证法:将数据集分为K份,称为折,每次用其中一个折作为测试集来计算误差,经过K次计算后求出一组长度为K的误差值,这组误差的平均值就是交叉验证得出的误差值。

我们将K设置为10:

借助sklearn可以很简单的实现验证:

from sklearn.model_selection import cross_val_score
#总共有五个参数,第一个是模型,2和3是数据,scoring指定了计算方式,cv是K值
scores = cross_val_score(tree_reg, housing_prepared, housing_labels,
                         scoring="neg_mean_squared_error", cv=10)
#score是效用函数计算得出的,实际上和均方误差相反,所以要加上负号
tree_rmse_scores = np.sqrt(-scores)

设置一个输出函数来查看具体情况:

def display_scores(scores):
    #均方误差,一共十个
    print("Scores:", scores) 
    #平均的均方误差
    print("Mean:", scores.mean())
    #均方误差的标准差
    print("Standard deviation:", scores.std())
display_scores(tree_rmse_scores)
Scores: [67649.82 67698.67 71079.28 69445.09 71808.23 
         73827.59 71111.46 71243.31 75630.03 70498.20]
Mean: 70999.17217565424
Standard deviation: 2344.261017051602

可以看出,决策树模型并没有那么好用,甚至比线性回归模型还糟糕。

交叉验证不仅可以得到模型性能的评估,还能测量评估的准确性(标准差)。决策树的误差大概是71000,波动幅度±2300。

3.2随机森林模型

上述两个模型误差都很大,现在使用随机森林模型。

这个模型的名字很有意思,如果将决策树看做是一棵树的话,随机森林就是随机组合一些属性来训练许多决策树,在其他多个模型之上建立模型成为集成学习。借助sklearn中的RandomForestRegressor类来实现:

from sklearn.ensemble import RandomForestRegressor

forest_reg = RandomForestRegressor()
#输入训练数据进行训练
forest_reg.fit(housing_prepared,housing_labels)
#计算交叉验证误差
scores = cross_val_score(forest_reg, housing_prepared, housing_labels,
                         scoring="neg_mean_squared_error", cv=10)
#score是效用函数计算得出的,实际上和均方误差相反,所以要加上负号
forest_rmse_scores = np.sqrt(-scores)
display_scores(forest_rmse_scores)
Scores: [51066.82 50166.73 52755.29 55534.63 51963.35 
         54194.03 52341.03 50770.49 54823.32 52582.53]
Mean: 52619.825856487405
Standard deviation: 1679.5101421709217

看起来效果比上面两个模型都要好,实际上我们应该多测试几个模型,比如不同核心的支持向量机、神经网络等等,目标是列出可以使用模型的列表。做完之后就是对模型的微调了。

4.模型微调

调整什么呢,调整超参数,超参数是什么,超参数就是不能通过学习来自动调整的参数。比如学习率,神经网络的层数等等。

在机器学习中,超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据。通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果。

超参数的一些示例:

  • 树的数量或树的深度
  • 矩阵分解中潜在因素的数量
  • 学习率(多种模式)
  • 深层神经网络隐藏层数
  • k均值聚类中的簇数

4.1网格搜索

网格搜索的意思很简单,为几个超参数设定一个范围取值,逐一搜索最佳组合的方式就是网格搜索。

我们借助SKlearn中的GridSearchCV来自动完成搜索工作。

你所需要做的是告诉 GridSearchCV 要试验有哪些超参数,要试验什么值, GridSearchCV 就能用交叉验证试验所有可能超参数值的组合。例如,下面的代码搜索了随机森林模型超参数值的最佳组合:

from sklearn.model_selection import GridSearchCV
param_grid = [
#字典1:尝试3×4=12种组合
{'n_estimators': [3, 10, 30], 'max_features': [2, 4, 6, 8]},
#字典2:尝试1×2×3=6种组合
{'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]
forest_reg = RandomForestRegressor()
#定义一个网格搜索,采用5折交叉验证法,判断标准是均方误差
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,scoring='neg_mean_squared_error')
#使用这些组合来训练随机森林
grid_search.fit(housing_prepared, housing_labels)
#输出最佳超参数组合
grid_search.best_params_
#输出最佳模型
grid_search.best_estimator_
{'max_features': 8, 'n_estimators': 30}
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
                      max_features=8, max_leaf_nodes=None,
                      min_impurity_decrease=0.0, min_impurity_split=None,
                      min_samples_leaf=1, min_samples_split=2,
                      min_weight_fraction_leaf=0.0, n_estimators=30,
                      n_jobs=None, oob_score=False, random_state=None,
                      verbose=0, warm_start=False)

两个搜索组合,第一个组合有12种情况,第二个组合有6种情况,一共18种情况,因为是五折交叉验证,所以,一共要进行18×5=90轮训练!完成后,就能返回超参数的最佳组合best_params_最佳模型best_estimator_

我们还可以查看网格搜索中每一个属性组合的得分情况:

cvres = grid_search.cv_results_
#输出超参数组合对应的得分情况
for mean_score, params in zip(cvres["mean_test_score"], cvres["params"]):
    print(np.sqrt(-mean_score), params)
65092.02010981559 {'max_features': 2, 'n_estimators': 3}
55359.4496463199  {'max_features': 2, 'n_estimators': 10}
52556.54068544674 {'max_features': 2, 'n_estimators': 30}
59605.3135515846  {'max_features': 4, 'n_estimators': 3}
53164.81127753654 {'max_features': 4, 'n_estimators': 10}
50676.9400303366  {'max_features': 4, 'n_estimators': 30}
59477.454981964   {'max_features': 6, 'n_estimators': 3}
52493.72219849825 {'max_features': 6, 'n_estimators': 10}
50129.58037046355 {'max_features': 6, 'n_estimators': 30}
58975.88486428257 {'max_features': 8, 'n_estimators': 3}
51999.19103564533 {'max_features': 8, 'n_estimators': 10}
49948.7230892116  {'max_features': 8, 'n_estimators': 30}
61188.21752051931 {'bootstrap': False, 'max_features': 2, 'n_estimators': 3}
54352.52617644768 {'bootstrap': False, 'max_features': 2, 'n_estimators': 10}
60598.83975833867 {'bootstrap': False, 'max_features': 3, 'n_estimators': 3}
52874.9435481527  {'bootstrap': False, 'max_features': 3, 'n_estimators': 10}
59428.43938347719 {'bootstrap': False, 'max_features': 4, 'n_estimators': 3}
52007.58232013096 {'bootstrap': False, 'max_features': 4, 'n_estimators': 10}

可以看出{'max_features': 8, 'n_estimators': 30}这个超参数组合得分最高,我们成功的使用网格搜索法调整了超参数值,将误差值从52619降低到49948

4.2其他方法

网格搜索看起来就是穷举法,穷举的方式在组合数少的情况下还能用,组合多的话最好使用随机搜索RandomizedSearchCV,虽然不会尝试所有的组合,但是能抽取更多完全不同的组合情况,还能方便的设定搜索次数,控制计算量。

还有一种方法是集成法,将几个不同的最佳模型组合起来使用,这个方法在后面的章节深入讲解。

5.测试模型

现在,我们可以测试一下调整好的模型了。

测试集也要进行处理,类比本章第一节的操作:

  1. 将测试数据分为两个测试数据X_test和测试标签y_test

  2. 将测试数据进行数据整理将X_test转为X_test_prepared,注意,这里使用tranform函数而不是fit_transform函数。

最后对测试集进行预测,计算预测值和实际值的均方根误差就能得到最终的误差效果。

#获得最佳的模型
final_model = grid_search.best_estimator_
#分割测试数据和测试标签
X_test = strat_test_set.drop("median_house_value", axis=1)
y_test = strat_test_set["median_house_value"].copy()
#将测试数据进行整理(使用transform函数)
X_test_prepared = full_pipeline.transform(X_test)
#计算预测值
final_predictions = final_model.predict(X_test_prepared)
#计算预测值和测试标签的均方误差
final_mse = mean_squared_error(y_test, final_predictions)
#计算最终的均方根误差
final_rmse = np.sqrt(final_mse)
48154.525254070046

得出该模型的均方根误差为48154。

至此,我们机器学习项目的开发阶段就算告一段落了。

最后,就是项目的预上线了,我们需要向万达集团展示具体实施方案,然后给自己倒上一杯卡布奇诺。

希望这一章能告诉你机器学习项目是什么样的,你能用学到的工具训练一个好系统。

你已经看到,大部分的工作是数据准备步骤、搭建监测工具、建立人为评估的流水线和自动化定期模型训练。

当然,最好能了解整个过程、熟悉三或四种算法,而不是在探索高级算法上浪费全部时间,导致在全局上的时间不够。 因此,如果你还没做,现在最好拿起台电脑,选择一个感兴趣的数据集,将整个流程从头到尾完成一遍。

讲实话,能坚持学到这里的朋友,真的很优秀。学完了这一章,你的机器学习之路已经成功了一半。

接下来我们会对机器学习的各种算法进行具体的学习和实践。


欢迎来我的博客留言讨论,我的博客主页:LeonG的博客

本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

最后编辑于:2024-08-25 10:33:47


喜欢的朋友记得点赞、收藏、关注哦!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TypeError: load() missing 1 required positional argument: ‘Loader‘

标题TypeError: load() missing 1 required positional argument: ‘Loader’ 源码: 处理后: 顺利通过,由于yaml版本导致的问题

Alertmanager 路由匹配

Alertmanager主要负责对Prometheus产生的告警进行统一处理,因此在Alertmanager配置中一般会包含以下几个主要部分: 全局配置(global):用于定义一些全局的公共参数,如全局的SMTP配置,Slack配置等…

day-61 外观数列

思路 每次对字符串进行遍历即可,用一个Integer统计相邻的相同字符个数,如果当前字符与后面邻接的字符相同,num;如果不同,则将""nums.charAt(j)拼接到字符串中 解题过程 当n1时,可以直接返回,不为…

【机器学习导引】ch3-线性模型

线性回归 梯度 在数学中,对于函数 f ( x 1 , … , x m ) f(x_1, \ldots, x_m) f(x1​,…,xm​) 在点 a ( a 1 , … , a m ) a (a_1, \ldots, a_m) a(a1​,…,am​) 处的梯度被定义为: ∇ f ( a ) ( ∂ f ∂ x 1 ( a ) , … , ∂ f ∂ x m ( a ) )…

排序题目:对角线遍历 II

文章目录 题目标题和出处难度题目描述要求示例数据范围 解法思路和算法代码复杂度分析 题目 标题和出处 标题:对角线遍历 II 出处:1424. 对角线遍历 II 难度 6 级 题目描述 要求 给定一个二维整数数组 nums \texttt{nums} nums,将 …

vue3.0 + element plus 全局自定义指令:select滚动分页

需求:项目里面下拉框数据较多 ,一次性请求数据,体验差,效果就是滚动进行分页。 看到这个需求的时候,我第一反应就是封装成自定义指令,这样回头用的时候,直接调用就可以了。 第一步 第二步&…

eHR软件的价格一般是多少?

在人力资源数字化转型的大潮中,越来越多的企业开始关注eHR(电子人力资源管理)软件的采购问题。eHR软件价格并不是一个简单的数字,而是受多种因素影响,具有较大波动性。那么,eHR软件的价格一般是多少呢&…

侧边菜单的展开和折叠

通过按钮控制侧边栏的展开和折叠通过窗口宽度的变化控制侧边栏的展开和折叠&#xff08;小于768px,自动折叠&#xff09; 通过按钮控制展开 通过按钮控制折叠 切换到手机模式自动折叠 环境准备&#xff1a;Vue3Element-UI Plus <script setup> import {onMounted, r…

基于SpringBoot + Vue的Gucci进销存系统

文章目录 前言一、详细操作演示视频二、具体实现截图三、技术栈1.前端-Vue.js2.后端-SpringBoot3.数据库-MySQL4.系统架构-B/S 四、系统测试1.系统测试概述2.系统功能测试3.系统测试结论 五、项目代码参考六、数据库代码参考七、项目论文示例结语 前言 &#x1f49b;博主介绍&a…

001. OBS (obs-studio)

1. 下载 https://obsproject.com/download windows c 插件下载 https://obsproject.com/visual-studio-2022-runtimes 2. 操作步骤 https://renwen.shnu.edu.cn/_s40/9a/2c/c28309a760364/page.psp https://zhuanlan.zhihu.com/p/597231652

【Java 问题】基础——IO

接上文 IO 42.Java 中 IO 流分为几种?Java IO体系中的装饰器模式抽象组件&#xff08;Component&#xff09;具体组件&#xff08;Concrete Component&#xff09;抽象装饰器&#xff08;Decorator&#xff09;具体装饰器&#xff08;Concrete Decorator&#xff09;使用装饰器…

喜讯 | 宝兰德「应用服务器软件 V9.5」荣获“2024年度优秀软件产品”殊荣

近日&#xff0c;中国软件行业协会公布了“2024年度推广优秀软件产品”名单。经过专家委员会的评议及最终审核&#xff0c;宝兰德凭借领先的技术能力和丰富的经验积累&#xff0c;中间件核心产品「应用服务器软件 V9.5」获评“2024年度优秀软件产品”。 本次评选活动由中国软件…

基于SpringBoot的在线考试系统设计与实现

1.1 研究背景 21世纪&#xff0c;我国早在上世纪就已普及互联网信息&#xff0c;互联网对人们生活中带来了无限的便利。像大部分的企事业单位都有自己的系统&#xff0c;由从今传统的管理模式向互联网发展&#xff0c;如今开发自己的系统是理所当然的。那么开发在线考试系统意…

vscode【实用插件】Project Manager 项目管理

安装 在 vscode 插件市场的搜索 Project Manager点 安装 安装成功后&#xff0c;vscode 左侧栏会出现 使用 将项目添加到项目列表中 用 vscode 打开项目&#xff0c;点保存即可 将项目移出项目列表 切换项目 单击项目列表中的项目&#xff0c;即可切换到目标项目 新窗口打开…

道一云·七巧和金蝶云星空单据接口对接

道一云七巧和金蝶云星空单据接口对接 对接系统金蝶云星空 金蝶K/3Cloud结合当今先进管理理论和数十万家国内客户最佳应用实践&#xff0c;面向事业部制、多地点、多工厂等运营协同与管控型企业及集团公司&#xff0c;提供一个通用的ERP服务平台。K/3Cloud支持的协同应用包括但不…

淘宝霸屏必备工具:淘宝商品评论电商API接口

淘宝商品评论电商API接口是指用于获取淘宝商品评论信息的一种接口&#xff0c;通过该接口可以获取淘宝网上商品的评价内容、评价等级、评价数量等信息。通过了解并使用该接口&#xff0c;能够帮助电商了解消费者对商品的评价情况&#xff0c;做好商品的推广和销售工作。 接口使…

电脑提速秘籍:6款不可不知的Windows实用软件

6款Windows系统上不可或缺的高效工具&#xff0c;每一款都是小巧而强大的存在&#xff0c;让你的电脑使用更加流畅&#xff01; 1.unlocker 当你遇到那些顽固的文件&#xff0c;需要管理员权限或者重启电脑才能删除时&#xff0c;这款只有1.02MB的轻量级工具可以帮你轻松解决问…

《黑神话悟空》战斗流派与技能加点指南及录屏技巧

在深入探索《黑神话悟空》的战斗艺术之前&#xff0c;让我们先来了解一些基本的战斗流派和技能加点策略&#xff0c;这将为你的西游之旅增添无限可能。不仅如此&#xff0c;我们还将介绍一款实用的录屏工具&#xff0c;让你能够轻松记录并分享你的冒险经历。现在&#xff0c;就…

水下生物检测系统源码分享

水下生物检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

SAP B1 流程实操 - 营销单据采购部分(上)

背景 在 SAP B1 中&#xff0c;除开【销售】外超常用的模块就是【采购】&#xff0c;企业可能不涉及生产和库存&#xff08;贸易公司&#xff09;&#xff0c;甚至不涉及采购&#xff08;服务业&#xff09;&#xff0c;但是一定会有基本的 销售。本文中我们讲解 销售 模块的基…