常用函数
①tf.summary.scalar
用于汇总标量数据,共有四个参数,格式如下:
tf.summary.scalar(tags,values,collections = None,name = None)
例如:tf.summary.scalar('test',test)
以标量的形式显示变量test的变化。该函数一般用于表示损失值、准确率的变化情况。
②tf.summary.histogram
用于显示直方图信息,共有四个参数,格式如下:
tf.summary.histogram(tags,values,collections = None,name = None)
例如:tf.summary.histogram('histogram',w)
以直方图的形式记录变量w的变化。该函数一般用于表示训练过程中变量的分布情况。
③tf.summary.image
用于输出带有图像的序列化数据,共有五个参数,格式如下:
tf.summary.image(tag,tensor,max_images = 3,collections = None,name = None)
其中第二个参数的类型为tensor,即一个多维数组,第三个参数用于设置输出图像的数量。
④tf.summary.audio
用于展示训练过程中记录的音频,共有六个参数,格式如下:
tf.summary.audio(name,tensor,sample_rate,max_outputs = 3,collections = None,family = None)
其中第二个参数的类型为tensor,即一个多维数组,第三个参数用于设置输出音频的数量。
⑤tf.summary.distribution
用于展示分布图,一般用于显示学习参数如weights、bias的分布。
⑥tf.summary.text
用于将文本类型的数据转换为tensor写入summary中,代码示例如下:
text = "test_test_test"
summary_op = tf.summary.text('text',tf.convert_to_tensor(text))
⑦tf.summary.merge_all
该函数可以将所有汇总数据全部保存到磁盘,以便tensorboard显示。如果没有特殊要求,一般用这一句代码就可以把训练时的各种信息都展示出来。格式如下:
tf.summary.merge_all(key = 'summaries')
需要注意的是,在使用tf.summary.merge_all()之前,你需要确保已经定义了至少一个tf.summary操作,例如tf.summary.scalar()、tf.summary.histogram()等。
⑧tf.summary.merge
常和tf.get_collection()函数配合使用,可以选择将需要保存的汇总数据保存到磁盘,以便tensorboard显示。格式如下:
tf.summary.merge(inputs,collections = None,name = None)
⑨tf.summary.FileWriter
用于指定一个文件用来保存图。格式如下:
tf.summary.FileWritter(path,sess.graph)
使用时,可以调用add_summary()方法将训练过程的数据保存在filewriter指定的文件中。
代码示例如下:
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 25 20:07:18 2023
@author: ASUS
"""
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt
import os
tf.compat.v1.disable_eager_execution()#这个函数用于禁用 TensorFlow 2 中的即时执行模式,以便能够使用 TensorFlow 1.x 的计算图执行方式。
#1.准备数据
train_X = np.linspace(-1, 1,100)#train_X 是一个从 -1 到 1 的等间距数组,用作输入特征。
train_Y = 5 * train_X + np.random.randn(*train_X.shape) * 0.7#train_Y 是根据 train_X 生成的目标值,在真实值的基础上加上了一些噪声。
#2.搭建模型
#通过占位符定义
X = tf.placeholder("float")#X 和 Y 是 TensorFlow 的占位符(Placeholder),用于在执行时提供输入和标签数据。
Y = tf.placeholder("float")
#定义学习参数的变量
W = tf.Variable(tf.compat.v1.random_normal([1]),name="weight")#W 和 b 是学习参数的变量,可以被模型训练调整。
b = tf.Variable(tf.zeros([1]),name="bias")
#定义运算
z = tf.multiply(X,W) + b#z 是通过将输入特征 X 与权重 W 相乘并加上偏差 b 得到的预测值。
#定义损失函数
cost = tf.reduce_mean(tf.square(Y - z))#cost 是损失函数,计算预测值与真实值之间的平方差的平均值。
#定义学习率
learning_rate = 0.01#learning_rate 是学习率,用来控制优化算法在每次迭代中更新参数的步长。
#设置优化函数
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)#optimizer 是梯度下降优化器,用于最小化损失函数。
#3.迭代训练
#初始化所有变量
init = tf.global_variables_initializer()
#定义迭代参数
training_epochs = 20#training_epochs 是迭代训练的轮数。
display_step = 2#display_step 是控制训练过程中打印输出的步长。
#定义保存路径
savedir = "log4/"
#启动Session
with tf.Session() as sess:#with tf.Session() as sess: 创建一个会话,在该会话中执行计算图操作。
sess.run(init)#sess.run(init) 运行初始化操作,初始化所有变量。
tf.summary.scalar("loss", cost)
#合并所有的summary
merged_summary_op = tf.summary.merge_all()
#创建summary_write用于写文件
summary_writer = tf.summary.FileWriter(os.path.join(savedir,'summary_log'),sess.graph)
for epoch in range(training_epochs):
for(x,y) in zip(train_X,train_Y):
sess.run(optimizer,feed_dict={X:x,Y:y})#sess.run(optimizer,feed_dict={X:x,Y:y}) 执行一次优化器操作,将当前的输入特征 x 和标签值 y 传入模型。
summary_str = sess.run(merged_summary_op,feed_dict = {X:x,Y:y})
summary_writer.add_summary(summary_str,epoch)
if epoch % display_step == 0:#每隔 display_step 轮迭代打印一次损失值和当前的参数值。
loss=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
#测试模型
print("Epoch:",epoch+1,"cost=",loss,"W=",sess.run(W),"b=",sess.run(b))
print("Finished!")
#使用 matplotlib 库绘制训练数据点和拟合直线。
plt.plot(train_X,train_Y,'ro',label='Original data')#绘制原始数据点。
plt.plot(train_X,sess.run(W)*train_X+sess.run(b),'--',label='Fittedline')#绘制拟合的直线。
plt.legend()#添加图例。
plt.show()#显示图形。
#4.利用模型
print("x=0.2,z=",sess.run(z,feed_dict={X:0.2}))#使用训练好的模型,传入输入特征 0.2 来计算预测值 z。
运行代码后,可以看到生成的文件
常用类
①Class FileWriter
该类提供了一种在给定目录下创建事件文件并向事件文件中添加汇总数据和事件信息的机制。它采用异步更新文件内容的方式。因此,训练中的程序可以在训练循环中直接调用methods将数据添加到文件,而不用减速训练。
②Class FileWriterCache
该类用于缓存filewriter,且每个目录拥有一个。
③Class SummaryWriter
该类用于将生成的序列化数据写入指定的文件。
除此之外,还有Event类、Summary类、SummaryDescription类等