《Keras深度学习:入门、实战与进阶》之回归问题实例:波士顿房价预测

news2024/9/20 17:54:39

本文摘自《Keras深度学习:入门、实战与进阶》。
在这里插入图片描述

本节将要预测20世纪70年代中期波士顿郊区房屋价格的中位数。这个数据是1978年统计收集的,数据集中的每一行数据都是对波士顿周边或城镇房价的描述,包含以下14个特征和506条数据。
 CRIM:城镇人均犯罪率。
 ZN:住宅用地所占比例。
 INDUS:城镇中非住宅用地所占比例。
 CHAS:虚拟变量,用于回归分析。
 NOX:环保指数。
 RM:每栋住宅的房间数。
 AGE:1940年以前建成的自住单位的比例。
 DIS:距离5个波士顿的就业中心的加权距离。
 RAD:距离高速公路的便利指数。
 TAX:每一万美元的不动产税率。
 PTRATIO:城镇中教师和学生的比例。
 B:城镇中黑人的比例。
 LSTAT:地区中有多少房东属于低收入人群。
 MEDV:自住房屋房价中位数。
通过Keras API把波士顿房屋价格数据集导入到R中,并查看测试集和训练集的样本数量。

> library(keras)
> # 1. 导入数据
> boston_housing <- dataset_boston_housing()
> c(train_data, train_labels) %<-% boston_housing$train
> c(test_data, test_labels) %<-% boston_housing$test
> cat('训练样本数量:',length(train_labels),'\n',
+     '测试样本数量:',length(test_labels))
训练样本数量: 404 
测试样本数量: 102

这个数据集比MNIST数据集小很多:它一共有506个样本,分为404个训练样本和102个测试样本。
在对数据集做缺失值插补和特征预处理之前,让我们观察训练集中各列数据的描述统计分析。skimr软件包提供了一个很好的解决方案,可以显示每列的关键描述统计信息。skim()函数会生成包含每一列的描述统计的数据框,并包含一个直方图,可以直观查看数值变量的数据分布情况。

> # 添加列名称
> library(tibble)
> column_names <- c('CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 
+                   'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT')
> train_df <- as_tibble(train_data)
> colnames(train_df) <- column_names
> # 对数据进行描述统计分析
> if(!require(skimr)) install.packages("skimr")
> skimmed <- skim(train_df)
> skimmed

在这里插入图片描述
train_data一共有404行13列,13列均为数值变量,最后一部分是各数值变量的描述统计分析。n_missing是统计各变量的样本缺失数量,此数据各变量均为0,说明无数据缺失;complete_rate是数据完整度,此数据各变量均为1;mean、sd、p0、p25、p50、p75、p100依次为均值、标准差、最小值、第一四分位数、中位数、第三四分位数、最大值统计指标,可见各变量间存在尺度不一致情况,需要在建模前进行数据标准化处理;最后一列是数据分布的直方图可视化展示。
从描述统计分析结果可知,输入变量中的各列数据范围差异比较大。在建模前,需先对数据集进行标准化处理。 此案例使用scale()函数进行Z-Score标准化,处理后训练集中各列数据符合标准正态分布,即均值为0,标准差为1。

> # 对train_data进行标准化
> train_data <- scale(train_data)

接着使用对训练集标准化后得到的各列均值和标准差对测试集数据进行数据处理。

> col_means_train <- attr(train_data, "scaled:center") 
> col_stddevs_train <- attr(train_data, "scaled:scale")
> test_data <- scale(test_data, center = col_means_train, scale = col_stddevs_train)

最后,将标准化后的因变量和自变量进行合并,形成包含标签的训练及测试数据集。

> all_train_data=cbind(train_data,train_labels)
> all_test_data=cbind(test_data,test_labels)
> all_train_data=as.data.frame(all_train_data)
> all_test_data=as.data.frame(all_test_data)
> colnames(all_train_data) <- c(column_names,'MEDV')
> colnames(all_test_data) <- c(column_names,'MEDV')

经过上一小结的数据预处理,训练集和测试集已经达到深度学习建模要求。本小节将利用全连接神经网络进行模型构建、模型训练及模型预测等工作。
让我们建立一个序贯模型,该模型具有两个全连接的隐藏层,神经元数量均为64,采用ReLU激活函数。因为这是一个回归问题,不需要将预测结果进行分类转换,所以输出层不设置激活函数,直接输出数值。
模型定义完成后,需要对模型进行编译,编译模型是为了使模型能够有效地使用Keras封装的数值计算。Keras可以根据后端自动选择最佳方式来训练模型,并进行预测。编译时,必须指定训练模型时所需的一些属性。训练一个神经网络模型,意味着找到最好的权重集来对这个问题作出预测。
编译模型时,必须指定用于评估一组权重的损失函数(loss)、用于搜索网络不同权重的优化器(optimizer),以及希望在模型训练期间收集和报告的可选指标。在这个例子中,采用Adam优化器,均方误差(MSE)作为损失函数。同时采用平均绝对误差(MAE)来评估模型的性能,值越小代表模型的性能越好。

> # 构建模型函数
> build_model <- function() {
+   
+   model <- keras_model_sequential() %>%
+     layer_dense(units = 64, activation = "relu",
+                 input_shape = dim(train_data)[2]) %>%
+     layer_dense(units = 64, activation = "relu") %>%
+     layer_dense(units = 1)
+   
+   model %>% compile(
+     loss = "mse",
+     optimizer = optimizer_rmsprop(),
+     metrics = list("mean_absolute_error")
+   )
+   
+   model
+ }

模型编译完成后,就可以用于计算了。在使用模型预测新数据前,需要先对模型进行训练。模型通过调用fit方法来实现。训练过程将采用epochs参数,对数据集进行固定次数的迭代,因此必须指定epochs参数大小进行模型训练,还需要设置在执行神经网络中的权重更新的每个批次中所用实例的个数(batch_size,默认为32)。
在这个例子中,将运行一个较小的epochs参数150,batch_size使用默认的32,且我们设置一个回调函数,如果经过20次训练周期后验证集的损失函数没有明显改善,将自动停止训练。

> # 设置回调函数的停止条件
> early_stop <- callback_early_stopping(monitor = "val_loss", patience = 20)
> # 训练模型
> mlp_model <- build_model()
> history <- mlp_model %>% fit(
+   train_data,
+   train_labels,
+   epochs = 150, 
+   validation_split = 0.2,
+   verbose = 0,
+   callbacks = early_stop
+ )
> plot(history)

在这里插入图片描述
模型在经过50多次训练周期后停止了训练。通过以下命令可以查看模型的训练周期。

> cat('模型训练周期的次数为:','\n',length(history$metrics$val_loss))
模型训练周期的次数为: 
 55

利用callback_early_stopping回调函数后,训练模型在55次训练周期后停止了训练。因为模型的验证集的损失函数值在最后20次的训练周期均没有再改善,所以停止训练。
因为回调函数参数min_delta默认值为0,所以当出现训练周期为35时的验证集的损失函数值均大于后面20次训练周期的验证集的损失函数值,模型停止训练。以下代码查看当训练周期为35时的验证集损失函数值,并计算最后20次训练周期与其的差值。

> cat('epoch为35时的验证集损失函数值:','\n',
+     history$metrics$val_loss[35])
epoch为35时的验证集损失函数值: 
 13.57665
> diff <- history$metrics$val_loss[36:55] - history$metrics$val_loss[35]
> round(diff,1)
 [1] 0.1 0.4 0.0 1.0 0.1 0.7 0.6 1.0 0.8 1.4 0.9 0.0 0.7 0.1 0.6 0.4 1.8 2.2 0.7 1.1

因为回调函数参数restore_best_weights默认值为FALSE, 则模型将会使用在训练的最后一步获得的权重值。
最后,利用训练好的模型对测试样本的房价进行预测,并计算与实际值的平均绝对误差值(MAE)。

> # 对测试样本进行预测
> test_predictions <- mlp_model %>% 
+   predict(test_data)
> # 查看平均绝对误差
> mae <- mean(abs(all_test_data$MEDV-test_predictions))
> paste0('测试集上的平均绝对误差: $',
+        sprintf("%.f", mae * 1000))
[1] "测试集上的平均绝对误差: $3043"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/197065.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

verilog图像算法实现和仿真(代码与实践)

【声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 这里的代码指的是verilog代码,而不是之前的python代码。因为verilog处理的是数据,所以之前我们也谈到过,如果需要用verilog处理图像数据,需要先用python把图像变成文本文件,等到…

菜鸟的进阶--手写一个微型Spring

前言想干嘛深入了解spring原理&#xff0c;特别是IOC容器是如何实现的&#xff1f;AOP是如何实现的&#xff1f;手写一个spring迷你版框架&#xff0c;实现容器和AOP机制。我为什么想这么做spring是整个java体系中最重要的框架&#xff0c;它整合第三方技术&#xff0c;将所有的…

交联剂134272-64-3,Maleimide-NH2 HCl,2-马来酰亚胺乙胺盐酸盐

【中文名称】N-(2-氨乙基)马来酰亚胺盐酸盐&#xff0c;2-马来酰亚胺乙胺盐酸盐【英文名称】 MAL-NH2 HCl&#xff0c;Maleimide-NH2 HCl&#xff0c;MAL NH2 HCl&#xff0c;Maleimide-amine HCl&#xff0c;MAL-amine HCl&#xff0c;N-(2-AMinoethyl)MaleiMide Hydrochlorid…

5年老测试员,面试被刷,别人说他不懂自动化测试.....

圈内认识的朋友最近跳槽了&#xff0c;之前在一家小公司干了5年测试&#xff0c;本来以为很容易跳一个高待遇的工作&#xff0c;结果却比想象的难&#xff0c;因为他不会自动化测试… 最近也看了很多人的简历&#xff0c;写的都是3年工作经验&#xff0c;但面试中&#xff0c…

对数据库几个范式的理解

数据库关系理论 这部分主要是几个概念很抽象&#xff0c;大家开始学可能学不明白。最近在准备复试&#xff0c;复习了一下相关的内容&#xff0c;顺便做一下总结。 先说几个名词&#xff1a; 候选码&#xff1a;能够唯一确定一个元组的属性集合称为候选码。注意是集合&#…

每日学术速递2.3

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.Cv、cs.LG 1.Compositional Prompt Tuning with Motion Cues for Open-vocabulary Video Relation Detection(ICLR 2023) 标题&#xff1a;通过基于错误的隐性神经表征的上下文修剪实现高…

Java基础学习笔记(十五)—— 集合(3)

集合1 HashMap 类1.1 HashMap 类概述1.2 HashMap 案例2 TreeMap 类2.1 TreeMap 类概述2.2 TreeMap 案例3 Properties集合3.1 Properties集合概述3.2 Properties基本使用3.3 Properties特有方法3.4 Properties和IO流相结合的方法4 可变参数与不可变集合4.1 可变参数4.2 不可变集…

2023.1.26

0、任务 今明两天任务&#xff0c;回答以下问题&#xff1a; 1、网络传输延迟有哪些&#xff1f;如何区分传输延迟和排队延迟&#xff1f; 2、如何理解路由器存储转发的过程&#xff1f; 3、拥塞是什么&#xff0c;为什么会发生拥塞&#xff0c;发生拥塞的表现是什么&#xff…

网络资源下载方式:http/https、ftp/sftp、BT种子、磁力下载、ed2k下载等的区别

文章目录参考资料序言中心化下载http/https下载ftp/sftp下载http与ftp下载方式的不同中心化下载的缺点中心化下载BT种子下载磁力下载ed2k下载推荐的下载器IDM下载器安装步骤IDM如何下载种子文件参考资料 一文读懂Bt种子、磁力链接、直链、p2p这些下载的区别 常说的BT下载、磁力…

【数据结构基础】图 - 基础和Overview

图(Graph)是由顶点和连接顶点的边构成的离散结构。在计算机科学中&#xff0c;图是最灵活的数据结构之一&#xff0c;很多问题都可以使用图模型进行建模求解。例如: 生态环境中不同物种的相互竞争、人与人之间的社交与关系网络、化学上用图区分结构不同但分子式相同的同分异构体…

情人节该送女友什么?分享四款适合送女生的数码好物

情人节快到了&#xff0c;对于有伴侣的人来说&#xff0c;这是一个浪漫的日子。在这个浪漫的日子&#xff0c;一些生活仪式感是必不可少的。最近看到不少人问&#xff0c;适合女生的数码好物有哪些&#xff1f;下面&#xff0c;我来给大家推荐几款适合送女生的数码好物&#xf…

动态规划DP与记忆化搜索DFS 题单刷题(c++实现+AC代码)

文章目录数字三角形滑雪挖地雷最大食物链计数采药疯狂的采药5倍经验值过河卒洛谷动态规划入门题单&#xff1a; 提单传送门 数字三角形 观察下面的数字金字塔。写一个程序来查找从最高点到底部任意处结束的路径&#xff0c;使路径经过数字的和最大。每一步可以走到左下方的点也…

“深度学习”学习日记。卷积神经网络--卷积层

2023.2.3 CNN中出现一些新的概念&#xff1a;填充、步幅 等&#xff0c;此外各层中传递的数据是有形状的&#xff0c;与之前的全连接层神经网络完全不同&#xff1b; 一、全连接层存在的问题&#xff1a; 全连接层神经网络使用了Affine层&#xff0c;在相邻的神经元全部连接…

php7.3.4 pdo方式连接sqlserver 设置方法

我这边用的php是7.3.4版本的&#xff0c;大家设置的时候看一下。一、首先要开启php的sqlsrv扩展1.下载SQLSRV58.EXE,我的php版本是7.3.4https://docs.microsoft.com/en-us/sql/connect/php/release-notes-php-sql-driver?viewsql-server-2017#previous-releases拷贝到浏览器打…

内网渗透(二)之基础知识-工作组介绍

系列文章 内网渗透(一)之基础知识-内网渗透介绍和概述 注&#xff1a;阅读本编文章前&#xff0c;请先阅读系列文章&#xff0c;以免造成看不懂的情况&#xff01;&#xff01; 工作组介绍 1、工作组的介绍 在一个大型单位里,可能有成百上千台计算机互相连接组成局域网,它…

Rancher 部署 MongoDB

文章目录前置部署创建 Headless开始部署测试前置 背景&#xff1a;在 K8S 集群用 bitnami 部署 MongoDB 有一定的学习成本&#xff0c;有兴趣可以参考 k8s 部署 mongodb 三种模式&#xff0c;且部署后发现 MongoDB 会随着时间推移占用越来越多的内存&#xff0c;暂没找到原有&…

计算机如何在本地硬盘安装WinPE系统

环境&#xff1a; 联想E14 Win 10专业版 U盘魔术师V6 30G硬盘分区 双硬盘&#xff1a;128G固&#xff0b;1T机 DiskGenius UltraISO 问题描述&#xff1a; 如何在本地硬盘安装WinPE系统 解决方案&#xff1a; 一、使用软件制作硬盘PE系统 1.机械磁盘先分区分一个30G分区 …

Java 中的Type类型及其实现【学习记录】

概述 在JDK1.5之前只有原始类型&#xff0c;此时所有的原始类型都通过字节码文件类Class进行抽象。Class类的一个具体对象就代表一个指定的原始类型。 JDK1.5加入了泛型类&#xff0c;扩充了数据类型&#xff0c;从只有原始类型基础上扩充了参数化类型、类型变量类型、通配符…

OpenStack使用Skyline Dashboard面板替换默认Horizon面板

书接上回 OpenStack Yoga安装使用kolla-ansible 忘记提示了。如果截止发稿今天&#xff0c;使用最新zed版本&#xff0c;在最后一步部署阶段会报错&#xff0c;好像是rabbitMQ重启失败。所以建议使用最新版再退一个版本 官方文档 skyline-apiserver/README-zh_CN.md at maste…

一文入门图像分类

文章目录一、卷积网络1.1 卷积的参数量1.2 卷积的计算量1.3 降低模型参数量和计算量的方法1.3.1 GoogLeNet 使用不同大小的卷积核1.3.2 ResNet 使用11卷积压缩通道数1.3.3 可分离卷积二、Transformer2.1 注意力机制 Attention Mechanism2.2 多头注意力 Multi-head (Self-)Atten…