TensorFlow笔记之多元线性回归

news2025/1/6 19:53:21

文章目录

  • 前言
  • 一、数据处理
  • 二、TensorFlow1.x
    • 1.定义模型
    • 2.训练模型
    • 3.结果可视化
    • 4.模型预测
    • 5.TensorBoard可视化
  • 三、TensorFlow2.x
    • 1.定义模型
    • 2.训练模型
    • 3.结果可视化
    • 4.模型预测
  • 总结


前言

记录使用TensorFlow1.x和TensorFlow2.x完成多元线性回归的过程。


一、数据处理

在此使用波士顿房价数据集,包含506个样本,输入为12个房屋信息特征,输出为房价。
使用pandas库读取csv文件,对数据进行归一化以消除不同维度量级上的差异,进行训练集与测试集的划分以评估训练结果。

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
#数据读取
df = pd.read_csv(r'database/boston.csv', header=0)
x_data = np.array(df.values)[:,:12]
y_data = np.array(df.values)[:,12]
#数据归一化
for i in range(12):
    x_data[:,i] = x_data[:,i] / (np.max(x_data[:,i]) - np.min(x_data[:,i]))
#数据集划分
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.3)
#维度变换
x_train = x_train.reshape((-1, 12))
x_test = x_test.reshape((-1, 12))
y_train = y_train.reshape((-1, 1))
y_test = y_test.reshape((-1, 1))

二、TensorFlow1.x

1.定义模型

import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
tf.disable_eager_execution()
#使用命名空间对节点打包
with tf.name_scope('Model'):
    #创建变量
    x = tf.placeholder(tf.float32, [None, 12], name='X')
    y = tf.placeholder(tf.float32, [None, 1], name='Y')
    w = tf.Variable(tf.random.normal((12,1)), name='w')
    b = tf.Variable(tf.random.normal((1, 1)), name='b')
    def model(x, w, b):
        return tf.matmul(x, w) + b
    #预测节点
    pred = model(x, w, b)

2.训练模型

使用小批量梯度下降进行训练,每轮过后打乱训练集顺序。

#训练参数
train_epoch = 100
learning_rate = 0.01
batch_size = 100
batch_num = (x_train.shape[0] // batch_size)
#损失函数
step = 0
display_step = 5
loss_list_test = []
loss_list_train = []
with tf.name_scope('LossFunction'):
    loss_function = tf.reduce_mean(tf.square(y - pred))
#定义优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate)\
    .minimize(loss_function)
#变量初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

迭代训练

for epoch in range(train_epoch):
    for batch in range(batch_num):
        xi = x_train[batch * batch_size:(batch + 1) * batch_size]
        yi = y_train[batch * batch_size:(batch + 1) * batch_size]
        sess.run(optimizer, feed_dict={x:xi, y:yi})
        step = step + 1
        if step % display_step == 0:
            loss_train = sess.run(loss_function,\
                                  feed_dict={x:x_train, y:y_train})
            loss_test= sess.run(loss_function,\
                                feed_dict={x:x_test, y:y_test})
            loss_list_train.append(loss_train)
            loss_list_test.append(loss_test)
            #print('w=', sess.run(w), '\n', 'b=', sess.run(b))
	#打乱训练集顺序
    x_train, y_train = shuffle(x_train, y_train)

3.结果可视化

plt.plot(loss_list_train, 'b-')
plt.plot(loss_list_test, 'r-')
print('train_epoch=', train_epoch)
print('learning_rate', learning_rate)
print('batch_size=', batch_size)

测试集损失函数比训练集下降得更快

增大训练轮数,减少批次样本量,损失函数进一步下降

将测试集比例设置为0.99,只用5个样本作为训练集,经过数轮训练后,训练集损失函数趋于0,测试集损失函数上升,产生了过拟合。

4.模型预测

从测试集随机抽取样本进行预测

i = np.random.randint(0, 50)
print('第%i个样本:' % i)
print('预测值:', sess.run(pred, feed_dict={x:x_test[i].reshape((1, 12))}))
print('实际值:', y_test[i])
sess.close()

差距有大有小

5.TensorBoard可视化

在变量初始化后加入:

#设置存储目录
tf.reset_default_graph()
log_dir = 'G://log'
#记录损失值
sum_loss_op = tf.summary.scalar('loss', loss_function)
#合并写入
merged = tf.summary.merge_all()
#文件写入器
write = tf.summary.FileWriter(log_dir, sess.graph)

将损失值加入摘要:

loss_train, sum_loss = sess.run([loss_function, sum_loss_op],\
                                  feed_dict={x:x_train, y:y_train})
write.add_summary(sum_loss, epoch)

关闭:

write.close()

在Anaconda Prompt中进入日志目录,运行TensorBoard,访问网址。
可以看到损失值和计算图

三、TensorFlow2.x

1.定义模型

import tensorflow as tf
import matplotlib.pyplot as plt

def model(x, w, b):
    return tf.matmul(x, w) + b

def loss_function(x, y, w, b):
    pred = model(x, w, b)
    loss = tf.reduce_mean(tf.square(y - pred))
    return loss

def grad(x, y, w, b):
    with tf.GradientTape() as tape:
        loss = loss_function(x, y, w, b)
    return tape.gradient(loss, [w, b])

w = tf.Variable(tf.random.normal((12,1)), dtype=tf.float32)
b = tf.Variable(tf.random.normal((1,1)), dtype=tf.float32)

矩阵乘法需要转化为Tensor

x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
y_train = tf.cast(y_train, tf.float32)
y_test = tf.cast(y_test, tf.float32)

2.训练模型

#训练参数
train_epoch = 100
learning_rate = 0.01
batch_size = 100
batch_num = x_train.shape[0] // batch_size
step = 0
display_step = 5
loss_list_train = []
loss_list_test = []
#创建优化器
optimizer = tf.keras.optimizers.SGD(learning_rate)

迭代训练

for epoch in range(train_epoch):
    for batch in range(batch_num):
        xi = x_train[batch*batch_size : (batch+1)*batch_size]
        yi = y_train[batch*batch_size : (batch+1)*batch_size]
        grads = grad(xi, yi, w, b)
        #应用梯度
        optimizer.apply_gradients(zip(grads, [w,b]))
        step = step + 1
        if step % display_step == 0:
            loss_list_train.append(loss_function(x_train, y_train, w, b))
            loss_list_test.append(loss_function(x_test, y_test, w, b))
    #使用tf.random.shuffle打乱Tensor类型的数据集
    train_data = tf.concat([x_train, y_train], axis=1)
    train_data = tf.random.shuffle(train_data)
    x_train = train_data[:, :12]
    y_train = tf.reshape(train_data[:, 12], (-1,1))

3.结果可视化

plt.plot(loss_list_train, 'b-')
plt.plot(loss_list_test, 'r-')
print('train_epoch=', train_epoch)
print('learning_rate', learning_rate)
print('batch_size=', batch_size) 

4.模型预测

i = np.random.randint(0, 50)
pred = model(tf.reshape(x_test[i], (1,12)), w, b).numpy()
print('第%i个样本:' % i)
print('预测值:', pred)
print('实际值:', y_test[i].numpy()) 


总结

TensorFlow1.x与TensorFlow2.x部分语法存在差异,使用TensorBoard进行展示时需要将之前的log文件删除;进行矩阵乘法时,需要注意数据类型与维度;由于数据集较小,采用的是线性回归,模型的准确程度有待优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/126396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dll修复工具哪个比较好?好的修复工具怎么选择

最近有小伙伴咨询小编,问dll修复工具的选择,因为他的电脑经常出现dll缺失,一缺失就打开不了各种软件程序,非常的让他烦恼,所以今天小编就来给大家详细的说说dll修复工具哪个比较好?要怎么去选择。 一.什么…

36 氪发布《研发项目管理软件应用指南》,ONES 入选典型厂商案例

近日,36氪企服点评发布了《研发项目管理软件应用指南》(下称「指南」)。36氪企服点评致力于帮助每个需求企业服务的人做出正确的决策,携手每个企服行业者为大众提供更高的价值与服务。在该指南中,36氪企服点评综合了海…

大数据系列——ClickHouse表引擎与分布式查询

目录 一、ClickHouse的表引擎 1、MergeTree的创建方式与存储结构 2、ReplacingMergeTree 二、数据分片与分布式查询 三、Clickhouse-ETL常见业务使用 一、ClickHouse的表引擎 表引擎体系,包括合并树、外部存储、内存、文件、接口和其他6大类20多种表引擎。而在…

全流量回溯分析为您解决应用性能问题(一)

前言 信息中心老师反应,用户反馈办公系统有访问慢的情况,需要通过流量分析系统来了解系统的运行情况,此报告专门针对系统的性能数据做了分析。 信息中心已部署NetInside流量分析系统,使用流量分析系统提供实时和历史原始流量&am…

【网络安全篇】浅谈web应用程序的安全风险

🏆今日学习目标: 🍀浅谈web应用程序的安全风险 ✅创作者:贤鱼 ⏰预计时间:25分钟 🎉个人主页:贤鱼的个人主页 🔥专栏系列:网络安全 🍁贤鱼的个人社区&#xf…

使用FastJson进行驼峰下划线相互转换写法及误区

PropertyNamingStrategy 有四种序列化方式。 CamelCase策略,Java对象属性:personId,序列化后属性:persionId – 实际只改了首字母 大写变小写 PascalCase策略,Java对象属性:personId,序列化后属…

说透IO多路复用模型

在说IO多路复用模型之前,我们先来大致了解下Linux文件系统。在Linux系统中,不论是你的鼠标,键盘,还是打印机,甚至于连接到本机的socket client端,都是以文件描述符的形式存在于系统中,诸如此类&…

springboot项目打war包 部署到Tomcat

1、SpringBoot项目Pom文件修改 <!-- 打war包配置 --><packaging>war</packaging><!-- 打war包配置 --><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-war-plugin</artifactId><version>…

英美TOP名校对IB的申请要求汇总

英美TOP名校对IB的申请要求汇总 英国大学剑桥大学 IB要求 40-42分&#xff08;满分45&#xff09;&#xff0c;HL要求为776分。 学校可能要求申请者的某些科目成绩为7&#xff0c;视不同专业和学院而定。 对任何要求数学的专业&#xff0c;申请者需选Analysis and Approaches&a…

Google SEM和谷歌SEO的区别

很多人对Google SEM和Google SEO概念很模糊。米贸搜整理如下。看图: Google SEM和SEO的关系 在上图中&#xff0c; 最上面的部分属于Google SEM&#xff0c;即Google Ads广告推广&#xff0c;是一种按效果付费的广告&#xff1b; 底层属于Google SEO&#xff0c;也就是Googl…

前端基础_配置IIS服务器

配置IIS服务器 在应用程序完全离线之前&#xff0c;还需要正确地提供清单文件。清单文件必须有扩展名.manifest和正确的mime-type。 如果使用Apache这样的通用Web服务器&#xff0c;需要找到在AppServ/Apache2.2/conf文件夹中的mine.types文件并向其添加“text/cache-manifes…

React学习02-React面向组件编程

React 开发者工具 推荐使用Chrome或Edge浏览器&#xff0c;安装React Developer Tools&#xff08;Facebook出品&#xff09;。 安装完成后&#xff0c;访问使用React编写的页面时&#xff0c;图标会高亮&#xff08;开发环境为红色有debug标识&#xff0c;生产环境为蓝色&…

如何高效阅读一篇论文

如何阅读一篇论文&#xff08;做好阅读笔记&#xff09;阅读步骤第一遍第二遍第三遍上哪里找论文paperswithcodeconnectedpaperslabml.ai 深度学习论文实现labml.ai 热门研究论文阅读步骤 第一遍 第一次通过的目的是大致了解论文。 阅读作者姓名、标题、摘要、简介、小节标题…

create first django

django-admin startproject first 1. 运行第一个django.py文件 python manage.py runserver 2. 建立第一个app python manage.py startapp first_app 修改settings.py&#xff0c;将first_app加入到下面中 然后修改views.py 然后修改urls.py配置导入view文件 前面是一个正则表达…

一文速学-Pandas处理时间序列数据-时间/日期操作详解

前言 关于Pandas处理时间序列数据我已经有写过两篇处理文章了&#xff1a; 一文速学-Pandas中DataFrame转换为时间格式数据与处理 一文速学-Pandas处理时间序列数据操作详解 日常处理一些数据和业务上需求&#xff0c;其实还是十分常用到时序数据的&#xff0c;一些处理方…

堆排序,建初始堆以及优先队列(priority_queue)

1.堆&#xff1a; 如果有一个关键码的集合K {k0&#xff0c;k1&#xff0c; k2&#xff0c;…&#xff0c;kn-1}&#xff0c;把它的所有元素按完全二叉树的顺序存储方式存储在一个一维数组中&#xff0c;并满足&#xff1a;Ki < K2i1 且 Ki<K2i2 &#xff0c;则称为小堆…

Docker部署jenkins配置公私钥拉取代码

容器内配置公私钥 先进入部署Jenkinns中的容器&#xff0c;在docker容器内生成公私钥 ssh-keygen -t rsajenkins 配置私钥信息 在Dashbord->凭据->系统->全局凭据中新增一个凭据 将公钥配置在gitlab 正常这么配制就可以了&#xff0c;但是在jenkins上发现使用ssh…

如何快速掌握代币经济学

如何研究加密世界里的Token? 先看一组数据&#xff1a;截至2022年&#xff0c;市面上大约有6000种加密货币(或者更多&#xff09;。这对投资者来说当然是一个很大的机会。然而&#xff0c;在2021年&#xff0c;投资者在Crypto项目遇到欺诈&#xff0c;损失的金额120亿美元。因…

2022年河北沃克金属制品有限公司助力河北石家庄电子商务资源对接暨电商直播选品大会圆满落幕!

会议主题&#xff1a;聚合电商直播优势资源 赋能产业发展消费增长 主题活动&#xff1a;2022河北•石家庄电子商务资源对接暨电商直播选品大会 承办日期&#xff1a;2022年12月26日至2022年12月27日 主办单位&#xff1a;石家庄市商务局 指导单位&#xff1a;河北省商务厅 …

基于K8s的DevOps平台实践(一)

文章目录前言1. DevOps介绍&#x1f351; 瀑布式流程&#x1f351; 敏捷开发&#x1f351; DevOps2. Jenkins初体验&#x1f351; K8s环境中部署jenkins&#x1f351; 安装汉化插件3. Jenkins基本使用演示&#x1f351; 演示目标&#x1f351; 演示准备&#x1f351; 演示过程4…