30天从入门到精通TensorFlow1.x 第二天,变量 tf.Variable()

news2024/11/28 23:54:29

文章目录

  • 一,接前一天
    • (1).内容前先弄清楚 `sess.run()` 函数
      • a. 该函数干嘛的
      • b. 该函数有哪些参数
      • c. 该函数的使用
    • (2).由库函数创建张量
    • (3).由库函数创建张量
  • 二、变量`tf.Variable()`
    • (1).啥是变量
    • (2).什么情况下会用该变量函数
    • (3).通过变量来创建张量
  • 三、变量 `tf.Variable()` 与 `tf.placeholder()` 的区别
    • 1.初始化
    • 2. 数据类型和形状
    • 3. 可训练性
    • 4. 更新方式
    • 总结:
  • 四、`tf.get_variable()`获取变量

一,接前一天

(1).内容前先弄清楚 sess.run() 函数

a. 该函数干嘛的

sess.run() 函数是 TensorFlow 中最核心的执行函数之一,用于在会话中执行计算图上的操作张量,并返回相应的结果。

看到了:该函数是用来创建一个会话对象。在这个会话中,我们可以执行模型中定义好的操作张量

b. 该函数有哪些参数

sess.run()
#来看下源代码
run(self, feed_dict=None, session=None)
'''
fetches(必选)也就是:需要执行的操作或张量,可以是单个张量、操作,也可以是由它们组成的列表、元组或字典等形式。
feed_dict(可选):用于传递数据的占位符字典。该参数默认为空字典,如果图中存在占位符,则 feed_dict 必须提供相应的数值,否则会抛出异常。

session=None:指定我们的回话对象,在进行初始化的时候,这里就要显式的说明
例如:
# 创建一个会话
test = tf.Session()
tf.global_variables_initializer().run(session=test)

'''

下面分别对这两个参数进行详细解释:

  1. fetches
    fetches 参数用于指定需要在会话中执行的操作或者张量,可以传入单个操作或者张量,也可以传入一个列表元组字典等形式。fetches 参数有以下几种常见的情况:
    –传入单个张量或操作,返回该张量或操作的计算结果。
    传入多个张量或操作的列表或元组,返回所有张量或操作的计算结果。
    –传入字典,键为字符串,值为张量或操作,返回字典中所有张量或操作的计算结果。
    例如,我们可以通过 sess.run() 函数来获取两个张量的值:
import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
c = a + b
with tf.Session() as sess:
    print(sess.run([a, b]))
  1. feed_dict
    feed_dict 参数用于传递占位符的值,以便在会话中执行计算。在 TensorFlow 中,占位符是一种特殊的张量,它没有具体的数值,但在使用时必须提供相应的数据。feed_dict 参数可以是一个字典,其中占位符张量需要填充的数值
    例如,我们可以通过 feed_dict 参数来设置占位符变量的值:
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[None])
y = 2 * x
with tf.Session() as sess:
    result = sess.run(y, feed_dict={x: [1, 2, 3]})
    print(result)

c. 该函数的使用

我们来使用一下:

import tensorflow as tf

#创建占位符
x = tf.placeholder(tf.float32, shape=[None])
y = 2 * x

with tf.Session() as sess: 
#创建一个会话对象,赋值给sess
    result = sess.run(y, feed_dict={x: [1, 2, 3]})
    #在该会话中执行定义好的操作 和 张量
    print(result)

(2).由库函数创建张量

test3 = tf.zeros((2,3))
print('tensor test3:',test3)
print('run test3:',test.run(test3))

在这里插入图片描述

(3).由库函数创建张量

二、变量tf.Variable()

(1).啥是变量

在 TensorFlow 1.x 版本中,tf.Variable() 函数用于创建一个可训练的张量,也就是变量。变量在模型的训练过程中会被反复更新,用于存储更新模型参数与常规张量不同的是,变量创建时需要初始化并且可以持久化到磁盘上。通过变量,可以方便地定义和管理模型参数,从而加快模型训练的速度和效果。

(2).什么情况下会用该变量函数

神经网络模型中,我们通常需要定义一些可训练的参数,例如:权重偏置项等。这些参数会在模型的训练过程中被反复更新以优化模型性能。使用 tf.Variable() 函数可以方便地创建这些参数。

举个例子,假设我们要实现一个简单的线性回归模型 y = wx + b,其中 w 和 b 是待学习的参数。我们可以使用 tf.Variable() 函数来创建这两个变量:

import tensorflow as tf

# create w and b init 0.0
w = tf.Variable(0.0, name='weight')
b = tf.Variable(0.0, name='bias')

# create input and out
x = tf.placeholder(dtype=tf.float32, shape=[None])
out = tf.placeholder(dtype=tf.float32, shape=[None])

# create loss and opt
y = w * x + b
loss = tf.reduce_mean(tf.square(y - out))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.005)
train_op = optimizer.minimize(loss)

# train model
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    for i in range(1000):
        _, loss_val, w_val, b_val = sess.run(
            [train_op, loss, w, b],
            feed_dict={x: [1, 23, 4, 5, 7, 5, 7], out: [3, 5, 7, 9, 11, 13, 15]}
        ) #注意 输入的数据 形状要一致,避免输出与预测值得形状不一致问题
        if i % 100 == 0:
            print('Step {}: loss = {}, w = {}, b = {}'.format(i, loss_val, w_val, b_val))

在这里插入图片描述

(3).通过变量来创建张量

import tensorflow as tf
import numpy as np
from pprint import pprint

#TODO 创建一个会话
test = tf.Session()
w = tf.Variable([5.0],tf.float32)
b = tf.Variable([6.0],tf.float32)
x = tf.placeholder(tf.float32)

y = w * x + b

print('w:',w)
print('b:',b)
print('x:',x)
print('y:',y)
tf.global_variables_initializer().run(session = test)
print('run y:',test.run(y,feed_dict={x:[1,2,3,4]}))

在这里插入图片描述
使用tf.Variable()创建变量时,需要注意以下几点:

  1. 变量的初始值需要在创建时指定,可以是常量,也可以是随机数等
  2. 变量的类型需要与初始值相匹配,比如:创建一个整型变量时,初始值必须为整数。
  3. 在使用变量前,需要通过调用tf.global_variables_initializer()初始化所有变量。
  4. 变量在计算图中有自己的作用域,需要注意命名空间的规范,便于管理和调试。
  5. 变量的生命周期需要谨慎管理,避免内存泄漏或者过早释放。

三、变量 tf.Variable()tf.placeholder() 的区别

在 TensorFlow 中,tf.Variable() 和 tf.placeholder() 都是非常重要的变量类型,它们有以下几个不同点:

1.初始化

tf.Variable() 变量在创建时需要进行初始化,并且可以持久化到磁盘上。我们通常使用 tf.global_variables_initializer() 函数来初始化所有变量。例如:

import tensorflow as tf

# 创建变量 w 和 b
w = tf.Variable(0.0, name='weight')
b = tf.Variable(0.0, name='bias')

with tf.Session() as sess:
	# 初始化所有变量
    sess.run(tf.global_variables_initializer())

tf.placeholder() 变量没有具体的值只是一个占位符,因此也不需要进行初始化。它通常用于定义模型的输入和输出等信息。

2. 数据类型和形状

tf.Variable()tf.placeholder() 变量都可以指定数据类型和形状。在使用 tf.Variable() 创建变量时,我们需要明确指定变量的形状数据类型。例如:

import tensorflow as tf

# 创建一个形状为 [2, 3],数据类型为 float32 的变量
x = tf.Variable(tf.zeros([2, 3]), dtype=tf.float32)

而在使用 tf.placeholder() 时,我们可以通过 shape 参数来指定占位符的形状,数据类型则可以通过传入的数值来自动推断。例如:

import tensorflow as tf

# 定义一个形状为 [None, 3] 的占位符,数据类型自动推断
x = tf.placeholder(dtype=tf.float32, shape=[None, 3])

'''
自动推断的意思是:当我们使用占位符创建了指定数据类型,在给他传参的时候,无需一致,该函数会自动识别为 我们指定的数据类型
'''

3. 可训练性

tf.Variable() 变量是可训练的,它们在模型的训练过程中会被反复更新优化模型性能。而 tf.placeholder() 变量没有训练参数只是一个占位符,用于定义模型输入输出等信息。

4. 更新方式

对于 tf.Variable() 变量,我们可以使用 assignassign_add 等方法来更新变量的值。例如:

import tensorflow as tf

# 创建变量 w 和 b,并初始化为0.0
w = tf.Variable(0.0, name='weight')
b = tf.Variable(0.0, name='bias')

# 使用 assign 方法更新 w 的值
sess.run(w.assign(2.0))

在这里插入图片描述

而对于 tf.placeholder() 变量,其值不能直接更新,需要通过传递新的数值重新计算相应的张量操作

总结:

  1. 在 TensorFlow 中,tf.placeholder() 是一个占位符变量,它不是具体的数值只是一个形状和数据类型都已经确定的张量。在计算图中使用 tf.placeholder() 可以定义模型的输入和输出等信息,然后通过 feed_dict 参数传递具体的数值.
    例如:
import tensorflow as tf

# 创建一个形状为 [None, 3] 的占位符,数据类型为 float32
x = tf.placeholder(dtype=tf.float32, shape=[None, 3])

# 定义一个操作 y,将 x 向量与常量向量相加
y = x + tf.constant([1.0, 2.0, 3.0])

with tf.Session() as sess:
    # 将 [4, 5, 6] 作为 x 的值,计算 y 的结果
    result = sess.run(y, feed_dict={x: [[1, 2, 3], [4, 5, 6]]})
    print(result)
  1. tf.Variable() 则是一个具体的数值,它可以被更新持久化到磁盘上。在使用 tf.Variable() 创建变量时,我们需要指定变量的初始值数据类型
import tensorflow as tf

# 创建一个训练变量 w,初始值为 0.0
w = tf.Variable(0.0, dtype=tf.float32)

# 定义一个操作 loss,将 w 的平方与常量 2 相加
loss = tf.square(w) + 2

with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())

    # 计算 loss 的结果
    result = sess.run(loss)
    print(result)
  1. 在使用 tf.Variable() 变量时,需要给它传递一个初始值,并在计算前进行初始化才能保证正常计算。

注意:在 TensorFlow 中,使用 tf.Variable() 创建变量时需要为其指定一个初始值。如果在创建变量时给定了初始值,那么在计算前进行初始化之后,变量的值就会被更新。
那么为什么既然会被初始化更新掉还要给与初始值呢?

在 TensorFlow 中,给变量指定一个初始值的目的是为了在计算图中建立变量节点,并确定它的数据类型、形状和初值等属性。这样做有以下几个好处

原因二:确定变量的数据类型、形状和初值等属性,使得我们能够更方便地使用变量进行计算。

原因三:在计算图中明确地标记变量节点,使得我们能够更方便地对其进行操作和管理。

原因四:变量的初值可以作为一种默认值,在初始化时如果没有手动赋值,就会自动使用默认值进行初始化。

原因五:在实际的工作中,我们通常会需要手动给变量赋初值,并在计算前对变量进行初始化。这样做主要是为了保证变量的初值和计算结果符合预期。

四、tf.get_variable()获取变量

  1. tf.get_variable()的好处:
    一般来说:如果我们定义的变量名称在之前已经定义过,再次定义的时候那么TensorFlow就会报错。若果此时我们使用tf.get_variable()函数来替代tf.variable()函数,就会避免这种情况。使用tf.get_variable()后,如果变量已经定义过,该函数就会直接返回变量,若果变量之前未被定义,则该函数就会从新定义。
    例如:
import tensorflow as tf

# 定义一个变量 x
x = tf.Variable(tf.random_normal([10]), name='x')

# 使用 x 的名称来创建一个新变量 y,并共享 x 的值
y = tf.get_variable(name='x', shape=[10], dtype=tf.float32, initializer=tf.constant_initializer(0.0))

with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())

    # 输出变量 x 和 y 的数值
    print('x:', sess.run(x))
    print('y:', sess.run(y))

注意:由于我们使用了 tf.get_variable() 方法,因此在创建变量时需要指定变量的名称形状数据类型以及初始化器等参数

另外:
如果要通过 tf.get_variable() 方法继承已经定义的变量,那么必须要在创建计算图时指定 reuse=True 参数,以告知 TensorFlow 允许共享变量。如果没有指定 reuse=True,那么 TensorFlow 将会默认禁止变量共享,从而导致错误。

在分布式TensorFlow中, tf.get_variable() 获取得到全局变量,若要得到局部变量,则使用:tf.get_local_variable()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/597624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

git在一台电脑上配置不同的仓库(多个gitee仓库、或者不同的github仓库)

前言 在开发过程中,我们工作的时候难免会使用到两个不同的仓库,但是正常来说一电脑默认一个参考,直接会用全局命令 git config --global user.name "yourName"但是这样只能配置一个仓库了,本文介绍在一台电脑上配置多…

跟着LearnOpenGL学习5--纹理

文章目录 一、前言二、纹理映射三、纹理环绕方式四、纹理过滤五、多级渐远纹理六、加载与创建纹理七、生成纹理八、应用纹理九、完整代码9.1、工程结构9.2、纹理图片9.3、stb_image.cpp9.4、顶点着色器9.5、片段着色器9.6、main.cpp 十、纹理颜色与顶点颜色混合十一、纹理单元 …

Python系列模块之pymysql操作MySQL 数据库

目录 一、安装pymysql 二、连接数据库 三、数据库操作 3.1 查询 3.2 更新 3.3 使用循环批量更新 Python 系列文章学习记录: Python系列之Windows环境安装配置_开着拖拉机回家的博客-CSDN博客 Python系列之变量和运算符_开着拖拉机回家的博客-CSDN博客 Pyt…

JavaWeb13(ajax01)

目录 一. 什么是ajax 二.为什么需要AJAX? 三. 同步和异步的区别. 四.基于jQuery实现AJAX语法 4.1 语法1-$.ajax(url,[settings]). 4.1 语法2-$.get/post(url, [data], [callback], [type]). 五 .案例 5.1 ajax实现登录 ①html代码 用户登录 用户名: 密码…

Linux安装VNC(Linux桌面版远程)

一、Linux安装VNC服务 适用于CentOS7 #检查系统没有装GUI界面 rpm -qa | grep gnomeyum -y groupinstall "X Window System" yum -y groupinstall "GNOME Desktop"#安装VNC yum install tigervnc tigervnc-server -y检查开机运行级别 systemctl get-defa…

Shell脚本攻略:数组

目录 一、理论 1.数组概述 2.定义数组 3.数组打印 4.数组的数据类型及处理 5.数组赋值 6.数组遍历 7.数组切片 8.数组替换 9.删除数组 10.追加数组中的元素 11.数组排序算法 二、实验 1.实验一 2.实验二 3.实验三 一、理论 1.数组概述 数组是Shell的一种特殊变…

MySQL InnoDB缓存池

缓存池的作用 缓存表数据与索引数据,把磁盘上的数据加载到缓冲池,避免每次访问都进行磁盘IO,起到加速访问的作用。 为什么不把所有数据放到缓冲池中 凡事都具备两面性,抛开数据易失性不说,访问快速的反面是存储容量…

MongoDB 基本概念

MongoDB 部署模型 在生产环境中,MongoDB 经常会部署成一个三节点的复制集,或者一个分片集群。 我们先来看左边,当 MongoDB 部署为一个复制集时,应用程序通过驱动,直接请求复制集中的主节点,完成读写操作。另…

前端学习(DAY51)面试1

组件中的 data 为什么是一个函数? 如果 data 是对象的话,当数据改动时就会影响到所有的实例,可能会造成一些数据的冲突。 HTTP http:以安全为目标的http通道,HTTPs是以安全为目标的https通道(使用SSL进…

Linux--ServerProgramming--(3)详解高性能服务器程序框架

1. 服务器框架详解 1.1 服务器模型 1.1.1 C/S 模型 此模型很简单,就是服务器和客户端。 此模型 非常适合资源相对集中的场合。 缺点:因为服务器是通信的中心,当访问量过大时,可能所有的客户都将得到很慢的响应。此缺点可由 P2P…

利用PHP导出MySQL数据表结构和SQL文件

目录 一、获取数据库所有的数据表 方法一:TP5 方法二:原生PHP 二、导出指定数据表的数据结构 三、 导出SQL文件 四、生成SQL语句 五、完整代码 前端 后端 语言:PHP 数据库:MySQL 功能:分为四部分,① 查出数…

智大数据比赛的总结

强国杯个人赛一定要报 hive 和hadoop基础环境配置 开启单节点集群环境 (0 / 10 分) 本次使用环境为单节点集群,对应主机名为hadoop000,使用工具连接对应主机并进行相关操作。 环境中已经安装java、Hadoop、Hive、Mysql并配置对应环境变量,安装路径为/root/software/,对应…

通过python采集关键字搜索1688工厂数据接口,1688工厂数据接口,1688API接口

1688是一个行业网站,主要提供中小型批发和生产商的信息,是中国供应商向全球采购商展示其产品的平台。在1688上,可以找到许多工厂和制造商的信息,包括公司名称、地址、联系人、联系方式、主要产品等。 采集1688工厂数据可以帮助采…

MySQL数据库 2.启动与停止

目录 ​编辑 🤔 启动与停止: 🙂1.WIN加R调用windows命令行,输入:services.msc 🙂2.可以在cmd(管理员模式)中输入以下指令: 🤔 启动MySQ后的操作步骤&…

linux安装tomcat8

1.tomcat8下载 https://tomcat.apache.org/download-80.cgi 2.tomcat8安装 (1)将tomcat jar上传到usr/local目录 (2)解压tomcat压缩包 [rootiZ2ze7vthdl3oh0n0hzlu7Z local]# tar -zxvf apache-tomcat-8.5.58.tar.gz&#x…

开发小程序过程中的兼容难题,应当何去何从?

如今小程序开发已经成为了互联网行业发展的主流,而小程序开发过程中的兼容难题也让许多开发者感到头疼。那么小程序开发过程中兼容问题究竟有哪些,该如何解决?下面我们就针对这个问题展开一下分析。 什么是小程序? 小程序是一种无…

为什么魂斗罗只有 128KB 却可以实现那么长的剧情

经常看到有同学在抱怨现在的游戏、APP占用非常大的空间,基本都是 10G 起步。 这让我想到初中时玩过的一款游戏魂斗罗,为什么它只有 128KB 却可以实现那么长的剧情呢?这篇文章将会给大家讲讲这里面的奥秘~ 正文 现代程序员 A 和 1980 年代游戏…

小程序安装Vant Weapp详细步骤,下载和npm安装版

小程序安装Vant Weapp详细步骤 使用npm下载1、新建项目并初始化项目2、下载Vant Weapp3、修改 app.json4、构建 npm 包5、引入组件 下载方式1. npm下载或者下载[官方示例](https://github.com/youzan/vant-weapp)2. 把里面的dist文件夹复制出来,放到项目的根目录&am…

MKS SimpleFOC ESP32 例程7 双电机电流控制

Makerbase ESP32 FOC 例程7 双电机电流控制 第一部分 硬件介绍 1.1 硬件清单 序号品名数量1ESP32 FOC V1.0 主板12ARDUINO UNO主板23MKS SF2804电机1412V电源适配器15USB 线1 注意:YT2804是改装的云台无刷电机,带有AS5600编码器,可实现360连续运转。…

恒流间歇滴定法(GITT)测试锂离子电池的实验流程

恒流间歇滴定法(GITT)测试锂离子电池的实验流程 锂电池作为现代电子设备中最常用的电源之一,其性能和安全性对于设备的正常运行至关重要。恒电流间歇滴定法是一种常用的测试方法,用于评估锂电池的容量、循环寿命和内阻等关键参数。…