机器学习入门实例-加州房价预测-3(选择与训练模型+调参)

news2024/12/24 3:26:58

选择与训练模型

使用线性回归

	from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error

    lin_reg = LinearRegression()
    lin_reg.fit(housing_prepared, housing_labels)
    housing_predictions = lin_reg.predict(housing_prepared)
    lin_mse = mean_squared_error(housing_labels, housing_predictions)
    lin_rmse = np.sqrt(lin_mse)
    # 太高的表示Underfit
    print(lin_rmse) # 68627.87390018745

在这里插入图片描述

使用决策树

    from sklearn.metrics import mean_squared_error
    from sklearn.tree import DecisionTreeRegressor

    tree_reg = DecisionTreeRegressor()
    tree_reg.fit(housing_prepared, housing_labels)
    housing_predictions = tree_reg.predict(housing_prepared)
    tree_mse = mean_squared_error(housing_labels, housing_predictions)
    tree_rmse = np.sqrt(tree_mse)
    # 太低的表示overfit
    print(tree_rmse) # 0.0

k-fold cross-validation

fold指不同的数据子集。k-fold cross-validation就是随机产生k个fold,每次选一个fold来评估效果,其他k-1个fold用来训练。这样会得到k个评估分数。
由于scikit-learn的cross-validation用的是utility function(越大越好)而非cost function(越小越好),所以这里选用了负数的mse。

def display_scores(scores):
    print("Scores: ", scores)
    print("Mean: ", scores.mean())
    print("Standard deviation: ", scores.std())

...
	from sklearn.tree import DecisionTreeRegressor
	from sklearn.linear_model import LinearRegression
	from sklearn.model_selection import cross_val_score
	
    tree_reg = DecisionTreeRegressor()
    scores = cross_val_score(tree_reg, housing_prepared, housing_labels,
                             scoring="neg_mean_squared_error", cv=10)
    tree_rmse_scores = np.sqrt(-scores)
    print(display_scores(tree_rmse_scores))

    scores2 = cross_val_score(LinearRegression(), housing_prepared, housing_labels,
                              scoring="neg_mean_squared_error", cv=10)
    lin_rmse_scores = np.sqrt(-scores2)
    print(display_scores(lin_rmse_scores))

    from sklearn.ensemble import RandomForestRegressor
    forest_reg = RandomForestRegressor()
    scores3 = cross_val_score(forest_reg, housing_prepared, housing_labels,
                              scoring="neg_mean_squared_error", cv=10)
    forest_rmse_scores = np.sqrt(-scores3)
    print(display_scores(forest_rmse_scores))

Scores:  [72747.0548719  69975.38915452 68277.88913022 71825.88436612
 69330.68736727 76568.19953125 72764.63740705 72219.04604848
 69031.97373315 70374.45146712]
Mean:  71311.52130770871
Standard deviation:  2321.1947983407713

Scores:  [71762.76364394 64114.99166359 67771.17124356 68635.19072082
 66846.14089488 72528.03725385 73997.08050233 68802.33629334
 66443.28836884 70139.79923956]
Mean:  69104.07998247063
Standard deviation:  2880.3282098180657

Scores:  [51627.53671267 49369.27783646 46517.4011464  51846.1271283
 47431.17962375 51357.65695338 52565.55993554 49896.31387898
 48580.76091798 53639.76973388]
Mean:  50283.15838673507
Standard deviation:  2192.6847023341056

可以看到,随机森林的效果好一点点。

保存模型可以用joblib包:

    import joblib
    joblib.dump(forest_reg, "forest_regression.pkl")

使用时:

import joblib
my_model_loaded = joblib.load("forest_regression.pkl")
y_pred = my_model_loaded.predict(X_test)

Fine-tune(微调)

Grid Search

    from sklearn.ensemble import RandomForestRegressor
    from sklearn.model_selection import GridSearchCV
    forest_reg = RandomForestRegressor()
    param_grid = [
        {'n_estimators': [3, 10, 30, 50], 'max_features': [2, 4, 6, 8, None]},
        {'bootstrap': [False], 'n_estimators': [3, 10, 30], 'max_features': [2, 3, 4, 8]}
    ]
    grid_search = GridSearchCV(forest_reg, param_grid, cv=5,
                               scoring="neg_mean_squared_error",
                               return_train_score=True)
    grid_search.fit(housing_prepared, housing_labels)
    print(grid_search.best_params_)
    print(grid_search.best_estimator_)
    print(np.sqrt(-grid_search.best_score_))
    # 输出训练结果
    cvres = grid_search.cv_results_
    for mean_score, params in zip(cvres["mean_test_score"], cvres["params"]):
        print(np.sqrt(-mean_score), params)


{'bootstrap': False, 'max_features': 4, 'n_estimators': 30}
RandomForestRegressor(bootstrap=False, max_features=4, n_estimators=30)
49184.012820612865
63296.60697816219 {'max_features': 2, 'n_estimators': 3}
55506.98954834032 {'max_features': 2, 'n_estimators': 10}
52588.29130799578 {'max_features': 2, 'n_estimators': 30}
52209.351970524105 {'max_features': 2, 'n_estimators': 50}
61124.66160352609 {'max_features': 4, 'n_estimators': 3}
52392.16874785449 {'max_features': 4, 'n_estimators': 10}
50273.44009993824 {'max_features': 4, 'n_estimators': 30}
50073.39902683726 {'max_features': 4, 'n_estimators': 50}
58727.56972424551 {'max_features': 6, 'n_estimators': 3}
51782.31672170441 {'max_features': 6, 'n_estimators': 10}
49980.86986567165 {'max_features': 6, 'n_estimators': 30}
49621.90325596429 {'max_features': 6, 'n_estimators': 50}
59108.772714035185 {'max_features': 8, 'n_estimators': 3}
52103.26690162754 {'max_features': 8, 'n_estimators': 10}
50132.82702279907 {'max_features': 8, 'n_estimators': 30}
49781.109586154096 {'max_features': 8, 'n_estimators': 50}
59531.39144640905 {'max_features': None, 'n_estimators': 3}
53194.75190489223 {'max_features': None, 'n_estimators': 10}
51329.73377108047 {'max_features': None, 'n_estimators': 30}
50826.59885373242 {'max_features': None, 'n_estimators': 50}
61849.029391012125 {'bootstrap': False, 'max_features': 2, 'n_estimators': 3}
54126.80686429009 {'bootstrap': False, 'max_features': 2, 'n_estimators': 10}
51848.68218742629 {'bootstrap': False, 'max_features': 2, 'n_estimators': 30}
60098.910109841345 {'bootstrap': False, 'max_features': 3, 'n_estimators': 3}
52334.13150446245 {'bootstrap': False, 'max_features': 3, 'n_estimators': 10}
50341.87869180483 {'bootstrap': False, 'max_features': 3, 'n_estimators': 30}
57650.86470796738 {'bootstrap': False, 'max_features': 4, 'n_estimators': 3}
51682.89592908345 {'bootstrap': False, 'max_features': 4, 'n_estimators': 10}
49184.012820612865 {'bootstrap': False, 'max_features': 4, 'n_estimators': 30}
57308.84929909555 {'bootstrap': False, 'max_features': 8, 'n_estimators': 3}
51285.68101404841 {'bootstrap': False, 'max_features': 8, 'n_estimators': 10}
49469.879650849194 {'bootstrap': False, 'max_features': 8, 'n_estimators': 30}

解释:param_grid表示要测试两组参数,第一组是n_estimators和max_features的组合,所以共有4 x 5 = 20种;第二组是 3 x 4 = 12种,所以一共测试了20 + 12 = 32种参数设置方式。每一种训练5次。训练结束后可以通过best_params_获取参数组合。

  • n_estimators:随机森林中树的数量,默认为 100
  • max_features:选择特征的方式,可以是整数、浮点数、字符串或 None。默认为 “auto”,即使用所有特征。如果是整数,则表示每次分裂只考虑部分特征;如果是浮点数,则表示每次分裂只考虑部分特征的比例;如果是字符串 “sqrt” 或 “log2”,则表示每次分裂只考虑特征数的平方根或对数;如果是 None,则表示每次分裂都使用所有特征
  • bootstrap 参数是指是否使用自助法(bootstrap)从原始样本中进行有放回的抽样来生成新的训练集。默认值为 True,即使用自助法。如果将其设为 False,则表示不使用自助法,即不进行有放回的抽样。

gs也可以用自定义的参数。

from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# 定义自定义的 Transformer
class MyTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, param1=1, param2=1):
        self.param1 = param1
        self.param2 = param2

    def fit(self, X, y=None):
        # ...
        return self

    def transform(self, X, y=None):
        # ...
        return X_transformed

# 定义 Pipeline 和 param_grid
pipeline = Pipeline([
    ('my_transformer', MyTransformer()),
    ('rf', RandomForestRegressor())
])

param_grid = {
    'my_transformer__param1': [1, 2, 3],
    'my_transformer__param2': [0.1, 0.2, 0.3],
    'rf__n_estimators': [10, 50, 100],
    'rf__max_depth': [None, 5, 10],
}
# 进行网格搜索
grid_search = GridSearchCV(pipeline, param_grid=param_grid, cv=5, n_jobs=-1)
grid_search.fit(X, y)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/439403.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

直播预告 | TDengine Apache SeaTunnel 联合应用最佳实践

TDengine 自诞生之日起,除产品层面的技术创新和实力提升外,也在大力完善自身产品生态,以此进一步满足用户的业务需求、提升使用体验。 近日,TDengine 与 Apache SeaTunnel 展开集成合作,双方将于 4 月 18 日 19:00 联…

二、Java 并发编程(5)

本章概要 线程上下文切换 线程上下文切换的流程导致线程上下文切换的原因 Java中的阻塞队列 阻塞队列的主要操作Java中阻塞队列的实现 2.7 线程上下文切换 CPU 利用时间片轮询来为每个任务都服务一定的时间,然后把当前任务的状态保存下来,继续服务下…

IO多路转接—select,poll,epoll

目录 select 函数介绍 select基本工作流程 select的优缺点及适用场景 poll poll的优缺点 epoll epoll的相关系统调用 epoll_create epoll_ctl epoll_wait epoll工作原理 epoll服务器编写 epoll的优点 epoll工作方式 select 函数介绍 系统提供select函数来实现多路复…

Spring核心设计思想

目录 前言: Spring是什么 什么是IoC 传统开发思想 IoC开发思想 Spring IoC 什么是DI 小结: 前言: 官网中提出:Spring makes programming Java quicker, easier, and safer for everybody. Spring’s focus on speed, simp…

YOLOv7+单目测距(python)

YOLOv7单目测距(python) 1. 相关配置2. 测距原理3. 相机标定3.1:标定方法13.2:标定方法2 4. 相机测距4.1 测距添加4.2 主代码 5. 实验效果 相关链接 1. YOLOV5 单目测距(python) 2. YOLOV5 双目测距&…

基于springboot的招聘信息管理系统源码数据库论文

目 录 1 绪 论 1.1 课题背景与意义 1.2 系统实现的功能 1.3 课题研究现状 2系统相关技术 2.1 Java语言介绍 2.2 B/S架构 2.3 MySQL 数据库介绍 2.4 MySQL环境配置 2.5 SpringBoot框架 3系统需求分析 3.1系统功能 3.2可行性研究 3.2.1 经济可行性 …

力扣sql中等篇练习(六)

力扣sql中等篇练习(六) 1 购买了产品A和产品B却没有购买产品C的顾客 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 # 先求出既有的,然后再去筛选掉没有的 # 去重用不了内连接 SELECT t1.customer_id,c.customer_name FROM ( SELECT distinct cust…

《Spring MVC》 第二章 第一个程序

前言 Spring MVC 是 Spring 框架提供的一款基于 MVC 模式的轻量级 Web 开发框架。 Spring MVC 本质是对 Servlet 的进一步封装,其最核心的组件是DispatcherServlet,它是 Spring MVC 的前端控制器,主要负责对请求和响应的统一地处理和分发。C…

C++ auto 内联函数 指针空值

本博客基于 上一篇博客的 序章,主要对 C 当中对C语言的缺陷 做的优化处理。 上一篇博客:C 命名空间 输入输出 缺省参数 引用 函数重载_chihiro1122的博客-CSDN博客 auto关键字 auto作为一个新的类型指示符来指示编译器,auto声明的变量必须由…

uni-app使用时遇到的坑

一.uni-app开发规范 1.微信小程序request请求需要https 小程序端: 在本地运行时,可以使用http 但是预览或者上传时,使用http无法请求 APP端: 一般APP可以使用http访问 高版本的APP可能需要用https访问 二. uni-app项目 配置App升…

Java语言请求示例,电商商品详情接口,接口封装

Java具有大部分编程语言所共有的一些特征,被特意设计用于互联网的分布式环境。Java具有类似于C语言的形式和感觉,但它要比C语言更易于使用,而且在编程时彻底采用了一种以对象为导向的方式。 使用Java编写的应用程序,既可以在一台…

如何更好的进行数据管理?10 条建议给到你

这个时代数据量的快速增长和数据复杂性的大幅度提高,让企业迫切的寻找更加智能的方式管理数据,从而有效提高 IT 效率。 管理数据库不是单一的目标,而是多个目标并行,如数据存储优化、效率、性能、安全。只有管理好数据从创建到删除…

newman结合jenkins实现自动化测试

一、背景 为了更好的保障产品质量和提升工作效率,使用自动化技术来执行测试用例。 二、技术实现 三、工具安装 3.1 安装newman npm install -g newman查看newman版本安装是否成功,打开命令行,输入newman -v,出现 版本信息即安…

浅述 国产仪器 6362D光谱分析仪

6362D光谱分析仪(简称:光谱仪)是一款高分辨、大动态高速高性能光谱分析仪,适用于600~1700nm光谱范围的DWDM、光放大器等光系统测试; LED、FP-LD、DFB-LD、光收发器等光有源器件测试;光纤、光纤光…

C语言基础应用(五)循环结构

引言 如果要求123…100,你会怎么求解呢? 如果按照常规代码 int main() {int sum 0;sum 1;sum 2;sum 3;...sum 100;printf("The value of sum is %d\n",sum);return 0; }就会特别麻烦,并且代码过于冗长。下面将引入循环的概念…

硬件知识的基础学习

GPIO、继电器、三极管、PWM、MOS管 的 输入与输出。 本人没有系统的学习过专业的硬件知识,只有在实践过程中向前辈简单的学习,若有问题,还请大佬指正。 目录 一、GPIO 1.1 输入与输出的区别 1.2 输入 1.2.1 电流流向和电阻区分上拉输入…

动力节点老杜Vue笔记——Vue程序初体验

目录 一、Vue程序初体验 1.1 下载并安装vue.js 1.2 第一个Vue程序 1.3 Vue的data配置项 1.4 Vue的template配置项 一、Vue程序初体验 可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,对我们编写Vue程序起不到太大的作用,…

计算机网络 实验六

⭐计网实验专栏,欢迎订阅与关注! ★观前提示:本篇内容为计算机网络实验。内容可能会不符合每个人实验的要求,因此以下内容建议仅做思路参考。 一、实验目的 掌握以太网帧的格式及各字段的含义掌握IP包的组成格式及各字段的含义掌…

java中HashMap的使用

HashMap 键值对关系,值可以重复,可以实现多对一,可以查找重复元素 记录: 做算法遇到好多次了,就总结一下大概用法。 例如今天遇到的这个题: 寻找出现一次的数,那就使用哈希表来存储&#xf…

X射线吸收光谱知识点

1) 什么是XAS XAS是X-ray Absorbtion Spectra的缩写,全称为X射线吸收光谱。X射线透过样品后,其强度发生衰减且其衰减程度与材料结构、组成有关。这种研究透射强度I与入射X射线强度Io之间的关系,称为X射线吸收光谱;由于其透射光强与元素、原子…