深度学习基础知识-tf.keras实例: 加州房价预测

news2025/1/11 6:57:02

参考书籍:《Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition (Aurelien Geron [Géron, Aurélien])》

代码有修改,已测通。


简单顺序结构

这次得数据集比之前得简单,只包含数字型特征,没有ocean_proximity,也没有缺失值。

如果 sklearn.datasets.fetch_california_housing 报错 urllib.error.HTTPError: HTTP Error 403: Forbidden,那么下载文件cal_housing_py3.pkz放到 sklearn.datasets.get_data_home()下,这里是 C:\Users\用户名\scikit_learn_data
参考:https://blog.csdn.net/qq_44644355/article/details/107054585

from sklearn.datasets import fetch_california_housing, get_data_home
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def load_housing_data():
    housing = fetch_california_housing()
    # 默认划分是3:1,aka 75%train, 25%test
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        housing.data, housing.target)
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train_full, y_train_full)
    scaler = StandardScaler()
    # fit_transform会计算数据的均值和方差,transform不会
    # 而且前者一般只在训练集进行,后者则是在训练集和测试集都可以用
    X_train = scaler.fit_transform(X_train)
    X_valid = scaler.transform(X_valid)
    X_test = scaler.transform(X_test)
    return X_train, X_valid, X_test, y_train, y_valid, y_test

#print(get_data_home())

X_train, X_valid, X_test, y_train, y_valid, y_test = load_housing_data()

model = keras.models.Sequential([
    keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]),
    keras.layers.Dense(1)
])
model.compile(loss="mean_squared_error", optimizer="sgd")
history = model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))
mse_test = model.evaluate(X_test, y_test)
print(mse_test)
X_new = X_test[:3]
y_pred = model.predict(X_new)
print(y_pred)

复杂结构

单输入

在这里插入图片描述

input_ = keras.layers.Input(shape=X_train.shape[1:])
hidden1 = keras.layers.Dense(30, activation="relu")(input_)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate(inputs=[input_, hidden2])
output = keras.layers.Dense(1)(concat)
model = keras.Model(inputs=[input_], outputs=[output])
#model.compile(loss="mse", optimizer=keras.optimizers.SGD(lr=1e-3))

多输入

在这里插入图片描述

input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="output")(concat)
model = keras.Model(inputs=[input_A, input_B], outputs=[output])
model.compile(loss="mse", optimizer=keras.optimizers.SGD(lr=1e-3))

# 划分数据。0-4是输入wide_input的,2-最后是输入deep_input的
X_train_A, X_train_B = X_train[:, :5], X_train[:, 2:]
X_valid_A, X_valid_B = X_valid[:, :5], X_valid[:, 2:]
X_test_A, X_test_B = X_test[:, :5], X_test[:, 2:]
X_new_A, X_new_B = X_test_A[:3], X_test_B[:3]
# 训练
history = model.fit((X_train_A, X_train_B), y_train, epochs=30,
                    validation_data=((X_valid_A, X_valid_B), y_valid))
mse_test = model.evaluate((X_test_A, X_test_B), y_test)
print(mse_test)
y_pred = model.predict((X_new_A, X_new_B))
print(y_pred)

如果需要增加一个输出,如下图所示,可以这样改:
在这里插入图片描述

# 其他保持不变
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.Model(inputs=[input_A, input_B], outputs=[output, aux_output])
model.compile(loss=["mse","mse"], loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))
# 假设aux_output预测的也是同样的东西
history = model.fit([X_train_A, X_train_B], [y_train, y_train], epochs=30,
                    validation_data=([X_valid_A, X_valid_B], [y_valid, y_valid]))
# 此时有总loss和每个输出的loss
total_loss, main_loss, aux_loss = model.evaluate([X_test_A, X_test_B], [y_test, y_test])
print((total_loss, main_loss, aux_loss))
# 预测结果也会有多个
y_pred_main, y_pred_aux = model.predict([X_new_A, X_new_B])
print((y_pred_main, y_pred_aux))

如果需要动态调整网络,比如在某些情况下需要进入循环或者分支,那可以写一个新的类,如下面所示。这样的好处是网络组织更加灵活,而且summary()只能打印层的列表,不能传递层之间的连接方式;缺点是不能clone或者保存(不能用hdf5格式,只能save_weights和load_weights勉强保存一下),有时候也可能出错。

class WideAndDeepModel(keras.Model):
    def __init__(self, units=30, activation="relu", **kwargs):
        super().__init__(**kwargs) # handles standard args (e.g., name)
        self.hidden1 = keras.layers.Dense(units, activation=activation)
        self.hidden2 = keras.layers.Dense(units, activation=activation)
        self.main_output = keras.layers.Dense(1)
        self.aux_output = keras.layers.Dense(1)

    # 这里可以写loop if什么的
    def call(self, inputs):
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = keras.layers.concatenate([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

model = WideAndDeepModel()

保存模型可以用pickle或者joblib的dump,也可以直接:

model.save("xxx.h5")

这里使用HDF5格式保存架构、超参数和每层的参数,也会保存optimizer。等再载入到时候可以用:

model = keras.models.load_model("xxx.h5")

有时候需要早点停止,可以加Callbacks。比如ModelCheckpoint就保存了某些时间点模型的checkpoints。默认是每个epoch结束时。

# build and compile the model
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model.h5", save_best_only=True)
history = model.fit(X_train, y_train, epochs=10,
                    alidation_data=(X_valid, y_valid), callbacks=[checkpoint_cb])
model = keras.models.load_model("my_keras_model.h5") # roll back to best model

此时model只保存了最好的模型。

另外,也可以使用EarlyStopping,指如果在验证集上,若干个(patience参数)epoch没有进步了,就停止训练。也可以同时使用checkpoints和earlystopping。

early_stopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

history = model.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid),
	callbacks=[checkpoint_cb, early_stopping_cb])

也可以写自定义的callbacks。可以选的时间点有:on_train_begin(), on_train_end(), on_epoch_begin(), on_epoch_end(), on_batch_begin(), and on_batch_end()。还可以在test和predict种插入callbacks,前者是evaluate()调用的,后者是predict()调用的。

class PrintValTrainRatioCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        print("\nval/train: {:.2f}".format(logs["val_loss"] / logs["loss"]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/616322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode98. 验证二叉搜索树(java)

验证二叉搜索树 leetcode98. 验证二叉搜索树题目描述 递归法解题思路代码演示 中序遍历解法解题思路代码演示 二叉树专题 leetcode98. 验证二叉搜索树 leetcode 98.验证二叉搜索树 来源:力扣(LeetCode) 链接:https://leetcode.cn/…

Linux开发中的辅助工具

文章目录 前言一、add2line二、strip三、ar四、nm五、objdump六、size七、strings总结 前言 本篇文章我们来介绍一些Linux开发中的辅助工具,有了这些辅助工具将会让我们的开发变的更加轻松。 一、add2line addr2line是一个GNU调试工具,用于将程序计数…

priority_queue的模拟实现和仿函数

priority_queue模拟 首先查看源代码,源代码就在queue剩下的部分中 push_heap是STL库中的堆算法,STL库中包装有支持堆的算法,在algorithm.h中: 只要不断用堆的形式插入数据,就会形成堆。 priority_queue模拟——初版 pr…

自定义组件中,使用onLoad,onShow生命周期失效问题

的解决方法 自定义组件中,使用onLoad,onShow生命周期失效问题 自定义组件中,使用onLoad,onShow生命周期失效问题 官方文档可查阅到: 页面生命周期仅在page中的vue页面有效,而单独封装的组件中【页面周期无效】,但是Vu…

Pytorch入门(四)使用VGG16网络训练CIFAR10数据集

本文使用PytorchVGG16官方CIFAR10数据集完成图像分类。识别效果如下: 文章目录 一、VGG16 神经网络结构二、VGG16 模型训练三、预测CIFAR10中的是个类别 一、VGG16 神经网络结构 VGG,又叫VGG-16,顾名思义就是有16层,包括13个卷…

地震勘探基础(十)之地震速度关系

地震速度 地震勘探中引入了多种速度的概念,如下图所示。 层速度、平均速度和均方根速度之间的关系 层速度指的是某一套地层垂向上,由于地质条件相对稳定,地层顶底厚度比上地震波的传播时间为层速度,用 v n v_n vn​ 表示。 如下…

一文看懂软件架构4+1视图

目录 一、概述 二、各视图详解 1. 场景视图 2. 逻辑视图 3. 开发视图 4. 处理视图 5. 物理视图 葵花宝典:一看就懂的理解方式 一、概述 41视图包括: 场景视图(也叫用例视图):黑盒视图。从外部视角&#xff…

chatgpt赋能python:Python如何分段数据的平均数

Python如何分段数据的平均数 Python是一门极其流行的编程语言,广泛应用于数据分析与科学计算领域。在数据分析中,计算各个数据段的平均数是一项常见的任务。本文将介绍如何使用Python分段计算数据的平均数,以及如何优化这一过程以使速度更快…

Linux中的lrzsz

一、介绍 lrzsz是一款在Linux里可代替ftp上传和下载的程序,也就是一款软件。它是开发者常用的一款工具,这个工具用于windows机器和远端的Linux机器通过XShell传输文件。 二、lrzsz的安装 在安装之前,我们可以使用下述命令先查看yum仓库中是否存在我们要安装的软件: yum…

CentOS7使用Docker快速安装Davinci

环境信息 操作系统:CentOS7Docker : 23.0.6 (已配置阿里云镜像加速) 安装步骤 安装docker-compose-plugin 官方的例子使用的是docker-compose,但是由于yum能够安装的最新斑斑是1.x,而且官方的docker-compose要求最低版本为2.2以…

首个区块链技术领域国家标准出台 ,中创助力打造区块链技术和应用创新高地

区块链作为数字中国的重要技术底座,正在深刻改变着我国社会生产方式。何谓区块链,对大众来说,也许尚陌生,殊不知,这一产业已稳稳起跑在我国高质量发展的“赛道”上。 近日,《区块链和分布式记账技术参考架…

【JavaScript】超全基础万字大总结

目录 一、初识 JavaScript 1.1 JavaScript 是什么? 1.2 发展历史 1.3 JavaScript 和 HTML 和 CSS 之间的关系 1.4 JavaScript 运行过程 1.5 JavaScript 的组成 二、前置知识 2.1 第一个程序 2.2 JavaScript 的书写形式 2.3 输入输出 三、语法概览 3.1 变…

Linux(CentOS 7) 安装 Mysql8 、Java 以及 mycat2 详细流程

目录 一、Mysql8 安装 1.下载mysql8 2. 解压Mysql 压缩包 3.重名命mysql 文件 4.创建data文件夹 储存文件 5.创建用户组以及用户 6.授权用户 将mysql文件夹的所有者和所有组都改为mysql 7.mysql初始化进入bin目录执行mysqld文件进行初始化 8.编辑my.cnf 9.添加mysqld…

有哪些虚拟化和容器化工具推荐? - 易智编译EaseEditing

以下是几个常用的虚拟化和容器化工具推荐: VMware vSphere: VMware vSphere 是一套完整的虚拟化平台,包括虚拟化服务器、虚拟化存储和虚拟化网络。 它提供了高性能的虚拟机管理和资源调度功能,适用于企业级的虚拟化部署。 Docke…

IT知识百科:什么是跨站脚本(XSS)攻击?

跨站脚本(Cross-Site Scripting,XSS)是一种常见的网络安全漏洞,攻击者利用该漏洞在受害者的网页中插入恶意脚本,从而能够获取用户的敏感信息、劫持会话或进行其他恶意活动。本文将详细介绍跨站脚本攻击的原理、类型、常…

vue props传值层级多,子级孙子级怎么修改传参

vue props传值层级多了,子级孙子级怎么修改传参 1.出现背景2.怎么在孙组件里改变传过来的值呢2.1这样改是不行的2.2可行的方法2.2.1 引用对象只改变单属性2.2.2 provide和inject 1.出现背景 本来自己写页面的话是直接全部写在一个vue文件里,一个vue文件…

【solidworks】此文档 templates\gba0.drwdot 使用字体长仿宋体,而该字体不可用

一、问题背景 在SolidWorks中绘制工程图纸时,新建一个图纸,打开后就弹出字体错误 此文档 templates\gba0.drwdot 使用字体长仿宋体,而该字体不可用。 二、解决办法 点击选择新的字体,拖到最下面选择汉仪长仿宋体。 上面之所…

41 管理虚拟机可维护性-虚拟机NMI Watchdog

文章目录 41 管理虚拟机可维护性-虚拟机NMI Watchdog41.1 概述41.2 注意事项41.3 操作步骤 41 管理虚拟机可维护性-虚拟机NMI Watchdog 41.1 概述 NMI Watchdog是一种用来检测Linux出现hardlockup(硬死锁)的机制。通过产生NMI不可屏蔽中断,…

win10+tf2.x+cuda+cudnn踩坑记录( Loaded cuDNN version 8400)

项目场景: 项目用到了tensorflow2.x: 想要用GPU跑算法win10系统下需要安装cuda和cudnn配置带有tenserflow-gpu的环境 问题描述 jyputer运行错误提示:Loaded cuDNN version 8400 Could not locate zlibwapi.dll. Please make sure it is in…

智安网络|保护企业网络空间资产安全的重要性

在数字化时代,企业网络空间资产的安全和保护变得越来越重要,并且拥有安全性能优越、系统完整的企业网络系统,是企业发展的必要条件。但想要实现网络空间安全首先需要关注网络漏洞问题。 保护企业网络空间资产的重要性 网络空间资产安全是企…