TensorFlow2实战-系列教程1:回归问题预测

news2025/1/16 13:45:43

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、环境测试

import tensorflow as tf
import numpy as np
tf.__version__

打印结果

‘2.10.0’

x1 =[[1,9],[3,6]]
x2 = tf.constant(x1)
print(x1)
print(x2)

打印结果:

[[1, 9], [3, 6]]
tf.Tensor( [[1 9] [3 6]], shape=(2, 2), dtype=int32)

2、导包读数据

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.keras
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

features = pd.read_csv('temps.csv')
features.head()

打印结果:

在这里插入图片描述

  • year,moth,day,week分别表示的具体的时间
  • temp_2:前天的最高温度值
  • temp_1:昨天的最高温度值
  • average:在历史中,每年这一天的平均最高温度值
  • actual:这就是我们的标签值了,当天的真实最高温度
  • friend:这一列可能是凑热闹的,你的朋友猜测的可能值,咱们不管它就好了

星期是个文本特性,用onehot转换一下:

features = pd.get_dummies(features)
features.head(5)

在这里插入图片描述

3、标签制作与数据预处理

# 标签
labels = np.array(features['actual'])

# 在特征中去掉标签
features= features.drop('actual', axis = 1)

# 名字单独保存一下,以备后患
feature_list = list(features.columns)

# 转换成合适的格式
features = np.array(features)

打印结果:

(348, 14)

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features[0]

打印结果:

array([ 0. , -1.5678393 , -1.65682171, -1.48452388, -1.49443549, -1.3470703 , -1.98891668, 2.44131112, -0.40482045, -0.40961596, -0.40482045, -0.40482045, -0.41913682, -0.40482045])

4、 基于Keras构建网络模型

常用参数:

  • activation:激活函数的选择,一般常用relu
  • kernel_initializer,bias_initializer:权重与偏置参数的初始化方法
  • kernel_regularizer,bias_regularizer:要不要加入正则化
  • inputs:输入,可以自己指定,也可以让网络自动选 units:神经元个数

按顺序构造网络模型:

model = tf.keras.Sequential()
model.add(layers.Dense(16))
model.add(layers.Dense(32))
model.add(layers.Dense(1))
  1. 创建一个执行序列
  2. 添加全连接层,16个神经元
  3. 添加全连接层,32个神经元
  4. 添加全连接层,1个神经元,作为最后的输出

定好优化器和损失函数,然后训练:

model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=10, batch_size=64)
  1. 指定SGD为优化器学习率为0.001,MSE为损失函数
  2. 指定数据和标签然后训练,25%为验证集,10个epochs

打印结果:

Epoch 1/10 5/5 - 0s 33ms/step - loss: 4267.9907 val_loss: 3133.0610
Epoch 2/10 5/5 0s 4ms/step - loss: 1925.8059 - val_loss: 3318.1531
Epoch 3/10 5/5 - 0s 3ms/step - loss: 181.2731 val_loss: 2728.9922
Epoch 4/10 5/5 0s 3ms/step - loss: 104.3410 - val_loss: 2093.8855
Epoch 5/10 5/5 - 0s 3ms/step - loss: 77.6116 - val_loss: 1377.6144
Epoch 6/10 5/5 0s 3ms/step - loss: 73.3877 - val_loss: 1163.6123
Epoch 7/10 5/5 0s 3ms/step - loss: 60.4262 val_loss: 867.4617
Epoch 8/10 5/5 0s 3ms/step - loss: 73.3110 - val_loss: 654.7820
Epoch 9/10 5/5 0s 3ms/step - loss: 36.6109 val_loss: 581.9786
Epoch 10/10 5/5 0s 3ms/step - loss: 56.6764 - val_loss: 383.0244
<keras.callbacks.History at 0x22634a22760>

从打印结果来看,训练集的损失和验证集的损失差距比较大,可能出现过拟合的现象

输入数据:

input_features.shape

打印结果:

(348, 14)

查看网络结构:

model.summary()

打印结果:

Model: “sequential”


Layer (type) Output Shape Param #

dense (Dense) multiple 240


dense_1 (Dense) multiple 544


dense_2 (Dense) multiple 33

Total params: 817
Trainable params: 817
Non-trainable params: 0


5、改初始化方法

model = tf.keras.Sequential()
model.add(layers.Dense(16,kernel_initializer='random_normal'))
model.add(layers.Dense(32,kernel_initializer='random_normal'))
model.add(layers.Dense(1,kernel_initializer='random_normal'))
model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=100, batch_size=64)

部分打印结果:

Epoch 99/100 261/261 0s 42us/sample - loss: 27.9759 - val_loss: 41.2864
Epoch 100/100 261/261 0s 42us/sample - loss: 44.5327 - val_loss: 48.2574

很显然差距消失了

6、加入正则化惩罚项

model = tf.keras.Sequential()
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.add(layers.Dense(32,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.add(layers.Dense(1,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=100, batch_size=64)

部分打印结果:

Epoch 99/100 261/261 0s 42us/sample - loss: 26.2268 - val_loss: 20.5562
Epoch 100/100 261/261 0s 42us/sample - loss: 24.3962 - val_loss: 21.1083

很显然结果更好了

7、展示测试结果

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]

test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)}) 

# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60'); 
plt.legend()

# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1413847.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【开源】基于JAVA的房屋出售出租系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 房屋销售模块2.2 房屋出租模块2.3 预定意向模块2.4 交易订单模块 三、系统展示四、核心代码4.1 查询房屋求租单4.2 查询卖家的房屋求购单4.3 出租意向预定4.4 出租单支付4.5 查询买家房屋销售交易单 五、免责说明 一、摘…

前端怎么监听手机键盘是否弹起

摘要&#xff1a; 开发移动端中&#xff0c;经常会遇到一些交互需要通过判断手机键盘是否被唤起来做的&#xff0c;说到判断手机键盘弹起和收起&#xff0c;应该都知道&#xff0c;安卓和ios判断手机键盘是否弹起的写法是有所不同的&#xff0c;下面讨论总结一下两端的区别以及…

专业120+总分400+海南大学838信号与系统考研高分经验海大电子信息与通信

今年专业838信号与系统120&#xff0c;总分400&#xff0c;顺利上岸海南大学&#xff0c;这一年的复习起起伏伏&#xff0c;但是最后还是坚持下来的&#xff0c;吃过的苦都是值得&#xff0c;总结一下自己的复习经历&#xff0c;希望对大家复习有帮助。首先我想先强调一下专业课…

嵌入式学习第十一天

1.数组和指针的关系: 1.一维数组和指针的关系: int a[5] {1, 2, 3, 4, 5}; int *p NULL; p &a[0]; p a; 数组的数组名a是指向数组第一个元素的一个指针常量 a &a[0] a 的类型可以理解为 int * 有两种情况除…

《动手学深度学习(PyTorch版)》笔记4.4

注&#xff1a;书中对代码的讲解并不详细&#xff0c;本文对很多细节做了详细注释。另外&#xff0c;书上的源代码是在Jupyter Notebook上运行的&#xff0c;较为分散&#xff0c;本文将代码集中起来&#xff0c;并加以完善&#xff0c;全部用vscode在python 3.9.18下测试通过。…

SpringBoot自定义全局异常处理器

文章目录 一、介绍二、实现1. 定义全局异常处理器2. 自定义异常类 三、使用四、疑问 一、介绍 Springboot框架提供两个注解帮助我们十分方便实现全局异常处理器以及自定义异常。 ControllerAdvice 或 RestControllerAdvice&#xff08;推荐&#xff09;ExceptionHandler 二、…

软件设计师——计算机网络(四)

&#x1f4d1;前言 本文主要是【计算机网络】——软件设计师——计算机网络的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1…

架构整洁之道-设计原则

4 设计原则 通常来说&#xff0c;要想构建一个好的软件系统&#xff0c;应该从写整洁的代码开始做起。这就是SOLID设计原则所要解决的问题。 SOLID原则的主要作用就是告诉我们如何将数据和函数组织成为类&#xff0c;以及如何将这些类链接起来成为程序。请注意&#xff0c;这里…

chroot: failed to run command ‘/bin/bash’: No such file or directory

1. 问题描述及原因分析 在busybox的环境下&#xff0c;执行 cd rootfs chroot .报错如下&#xff1a; chroot: failed to run command ‘/bin/bash’: No such file or directory根据报错应该rootfs文件系统中缺少/bin/bash&#xff0c;进入查看确实默认是sh&#xff0c;换成…

vertica10.0.0单点安装_ubuntu18.04

ubuntu的软件包格式为deb&#xff0c;而rpm格式的包归属于红帽子Red Hat。 由于项目一直用的vertica-9.3.1-4.x86_64.RHEL6.rpm&#xff0c;未进行其他版本适配&#xff0c;而官网又下载不到vertica-9.3.1-4.x86_64.deb&#xff0c;尝试通过alian命令将rpm转成deb&#xff0c;但…

【GitHub项目推荐--30 天学会XXX】【转载】

30 天学会 React 这个项目是《30 天 React 挑战》&#xff0c;是在 30 天内学习 React 的分步指南。它需要你学习 React 之前具备 HTML、CSS 和 JavaScript 知识储备。 除了 30 天学会 React&#xff0c;开发者还发布过 30 天学会 JavaScript 等项目。 开源地址&#xff1a;…

解读BEVFormer,新一代CV工作的基石

文章出处 BEVFormer这篇文章很有划时代的意义&#xff0c;改变了许多视觉领域工作的pipeline[2203.17270] BEVFormer: Learning Birds-Eye-View Representation from Multi-Camera Images via Spatiotemporal Transformers (arxiv.org)https://arxiv.org/abs/2203.17270 BEV …

数论Leetcode204. 计数质数、Leetcode858. 镜面反射、Leetcode952. 按公因数计算最大组件大小

Leetcode204. 计数质数 题目 给定整数 n &#xff0c;返回 所有小于非负整数 n 的质数的数量 。 代码 class Solution:def countPrimes(self, n: int) -> int:if n < 2:return 0prime_arr [1 for _ in range(n)]prime_arr[0], prime_arr[1] 0, 0ls list()for i in…

链表--102. 二叉树的层序遍历/medium 理解度C

102. 二叉树的层序遍历 1、题目2、题目分析3、复杂度最优解代码示例4、适用场景 1、题目 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 示例 1&#xff1a; 输入&#xff1a;root […

Django开发_20_form表单前后端关联(2)

根据上一篇文章的代码,进一步了解掌握GET,POST的运行机制 一、实例代码 views.py: def show_reverse(request):if request.method "GET":return redirect(reverse("work4:fill"))if request.method "POST":hobby request.POST.get("h…

Android Studio离线开发环境搭建

Android Studio离线开发环境搭建 1.下载离线和解压包2.创建工程3.创建虚拟机tips 1.下载离线和解压包 下载地址 百度网盘&#xff1a;https://pan.baidu.com/s/1XBPESFOB79EMBqOhFTX7eQ?pwdx2ek 天翼网盘&#xff1a;https://cloud.189.cn/web/share?code6BJZf2uUFJ3a&#…

Apache SeaTunnel 数据集成插件开发最新经验总结!

在Apache SeaTunnel的最新插件开发中&#xff0c;connector-v2 maxcompute 连接器实现了基于CatalogTable SaveMode的新版本。 本文主要给大家分享了源端的关键改动包括弃用了过时的方法&#xff0c;改为通过CatalogTable实现数据传递。汇端则增加了对multi-table sink和save…

HTML+JavaScript-04

JavaScript中的循环 for语句 一个for循环会一直执行&#xff0c;直到循环条件为false for(let i0; i<array.length-1; i){//当遍历完数组后结束循环console.log(array[i] "<br/>");//循环语句 }do...while语句 do...while 语句一直重复直到指定的条件求…

深入理解Redis:如何设置缓存数据的过期时间及其背后的机制

目录 Redis 给缓存数据设置过期时间 Redis是如何判断数据是否过期的呢&#xff1f; 过期的数据的删除策略 Redis 内存淘汰机制 Redis 给缓存数据设置过期时间 一般情况下&#xff0c;我们设置保存的缓存数据的时候都会设置一个过期时间。为什么呢&#xff1f; 因为内存是有…

shell编程-7

shell学习第7天 sed的学习1.sed是什么2.sed有两个空间pattern hold3.sed的语法4. sed里单引号和双引号的区别:5.sed的查找方式6.sed的命令sed的标签用法sed的a命令:追加sed的i命令:根据行号插入sed的c命令:整行替换sed的r命令sed的s命令:替换sed的d命令:删除sed中的&符号 7…