tensorflow入门(三)tensorflow下神经网络参数的设置

news2025/1/14 2:29:52

参考   Tensorflow入门 - 云+社区 - 腾讯云

神经网络中的参数是神经网络实现分类或回归问题中重要的部分。在tensorflow中,变量(tf.Variable)的作用就是保存和更新神经网络中的参数的。在tensorflow中,变量(tf.Variable)的作用就是保存和更新神经网络的参数。和其他编程语言类似,tensorflow中的变量也需要指定初始值。因为在神经网络中,给参数赋予随机初始值最为常见,所以一般也使用随机数给tensorflow中的变量初始化。下面一段代码给出了一种在tensorflow中声明一个2*3矩阵变量的方法:

weights = tf.Variable(tf.random_normal([2, 3], stddev = 2))

这段代码调用了tensorflow变量的声明函数tf.Variable。在变量声明函数中给出了初始化这个变量的方法。tensorflow中变量的初始值可以设置成随机数、常数或者是通过其他变量的初始值计算得到。在上面的样例中,tf.random_normal([2 ,3], stddev=2)会产生一个2x3的矩阵,矩阵中的元素是均值为0,标准差为2的随机数。tf.random_normal函数可以通过参数mean来指定平均值,在没有指定时默认为0。通过满足正太分布的随机数来初始化神经网络中的参数是一个非常有用的方法。除了正太分布的随机数,tensorflow还提供了一些其他的随机数发生器,下表列出了tensorflow目前支持的所有随机数发生器。
 

函数名称随机数分布主要参数
tf.random_normal正太分布平均值、标准差、取值类型
tf.truncated_normal正太分布,但如果随机出来的值超过两个标准差,那么这个数将会被重新随机平均值、标准差、取值类型
tf.random_uniform均匀分布最小取值、最大取值、取值类型
tf.random_gammaGamma分布形状参数alpha、尺度参数beta、取值类型
 

tensorflow也支持通过常量来初始化一个变量,下表给出了tensorflow中常用的变量声明方法。

函数名称功能样例
tf.zeros产生全0的数组tf.zeros([2, 3], int32) -> [[0,0,0],[0,0,0]]
 
tf.ones产生全1的数组tf.ones([2, 3], int32) -> [[1,1,1],[1,1,1]]
 
tf.fill产生一个全部为给定数字的数组tf.fill([2, 3], 9) -> [[9, 9, 9],[9, 9, 9]]
 
tf.constant产生一个给定值的常数

tf.constant([1, 2, 3]) -> [1,2,3]

在神经网络中,偏置值(bias)通常会使用常数来设置初始值。以下代码给出了一个样例:

biases = tf.Variable(tf.zeros([3]))

以上代码将会生成一个初始值全部为0且长度为3的变量。除了使用随机数或常数,tensorflow也支持通过其他变量的初始值来初始化新的变量。以下代码给出了具体的方法。

w2 = tf.Variable(weights.initialized_value())
w3 = tf.Variable(weights.initialized_value() * 2.0)

以上代码中,w2的初始值被设置成了与weights变量相同。w3的初始值则是weights初始值的两倍。在tensorflow中,一个变量的值在被使用之前,这个变量的初始化过程需要被明确地调用。以下样例介绍了如何通过变量实现神经网络的参数并实现前向传播过程。

import tensorflow as tf

# 声明w1、w2两个变量。这里还通过seed参数设定了随机种子。
# 这样可以保证每次运行得到的结果是一样的。
w1 = tf.Variable(tf.random_normal((2,3), stddev = 1, seed = 1 ))
w2 = tf.Variable(tf.random_normal((3,1), stddev = 1, seed = 1 ))

# 暂时将输入的随机向量定义为一个常量。注意这里x是一个1*2的矩阵。
x = tf.constant([0.7, 0.9])

# 描述前向传播算法获得神经网络的输出
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)


sess = tf.Session()

# 这里不能直接通过sess.run(y)来获取y的取值
# 因为w1和w2都还没有运行初始化过程。以下两行分别初始化了w1和w2两个变量。
sess.run(w1.initializer) #初始化w1。
sess.run(w2.initializer) #初始化w2。

#输出[[3.95757794]]
print(sess.run(y))
sess.close()

以上程序实现了神经网络的前向传播过程。从这段代码可以看出,当声明了变量w1、w2之后,可以通过w1和w2来定义神经网络的前向传播过程并得到中间结果a和最后答案y。

在tensorflow程序的第二步回声明(session),并通过会话计算结果。在上面的样例中,当会话定义完成之后就可以真正运行定义好的计算了。但在计算y之前,需要将所有用到的变量初始化。也就是说,虽然在变量定义时给出了变量初始化的方法,但这个方法并没有被真正运行。但在计算y之前,需要通过运行w1.initializer和w2.initializer来给变量赋值。虽然直接调用每个变量的初始化过程是一个可行的方案,但是当变量数目增多,或者变量之间存在依赖关系时,耽搁调用的方案就比较麻烦了。为了解决这个问题,tensorflow提供了一种更加便捷的方式来完成变量初始化过程。以下程序展示了通过tf.global_variables_initializer函数实现初始化所有变量的过程。

init_op = tf.global_variables_initializer()
sess.run(init_op)

通过tf.global_variables_initializer函数,就不需要将变量一个一个初始化了。这个函数也会自动处理变量之间的依赖关系。

变量和张量的关系:

tensorflow的核心概念是张量(tensor),所有的数据都是通过张量的形式来组织的,那么变量和张量之间的关系时什么呢?在tensorflow中,变量的声明函数tf.Variable是一个运算。这个运算的输出结果是一个张量,这个张量也就是变量,所以变量是一种特殊的张量。

下面进一步介绍tf.Variable操作在tensorflow中底层是如何实现的。下图给出了神经网络前向传播样例程序的tensorflow计算图的一个部分,这个部分显示了和变量w1相关的操作。

           

上图中黑色的椭圆代表了变量w1,可以看到w1是一个Variable运算。在这张图的下方可以看到w1通过一个read操作将值直接提供给了一个乘法运算,这个乘法运算就是tf.matmul(x, w1)。初始化变量w1的操作是通过Assign操作完成的。从上图可以看到Assign这个节点的输入为随机数生成函数的输出,而且输出赋给了变量w1。这样就完成了变量初始化的过程。

所有变量都被自动地加入到GraphKeys.VARIBALES这个集合中。通过tf.global_variable()函数可以拿到当前计算图上所有的变量。拿到计算图上所有的变量有助于持久化这个计算图的运行状态。当构建机器学习模型时,比如神经网络,可以通过变量声明函数中的trainable参数来区分需要优化的参数(比如神经网络中的参数)和其他参数(比如迭代的轮数)。如果声明变量时参数trainable为True,那么这个变量将会被自动加入到GraphKeys.TRAINABLE_VARIABLES集合。tensorflow中提供的神经网络优化算法会将GraphKeys.TRAINABLE_VARIABLES集合中的变量作为默认的优化对象。

类似张量,维度(shape)和类型(type)也是变量最重要的两个属性。和大部分程序语言类似,变量的类型是不可改变的。一个变量在构建之后,它的类型就不能再改变了。比如在上面给出的前向传播样例中,w1的类型为random_normal结果的默认类型为tf.float32,那么它将不能被赋予其他类型的值。以下代码将会报出类型不匹配的错误。

w1 = tf.Variable(tf.random_normal([2, 3], stddev=1), name= "w1")
w2 = tf.Variable(tf.random_normal([2, 3], dtype=tf.float64, stddev = 1), name= "w2")

w1.assign(w2)


'''
程序将报错:
TypeError:Input 'value' of 'Assign' Op has type float64 that does not match type float32 of argument 'ref'
'''

维度是变量另一个重要的属性。和类型不大一样的是,维度在程序运行中是有可能改变的,但是需要通过设置参数validate_shape=False,下面给出了一段示范代码。

w1 = tf.Variable(tf.random_normal([2 ,3], stddev=1), name="w1")
w2 = tf.Variable(tf.random_normal([2 ,2], stddev=1), name="w2")

# 下面这句话会报维度不匹配的错误:
# ValueError: Dimension 1 in both shapes must be equal, but are 3 and 2
# for 'Assign_1' (op: 'Assign') with input shapes: [2, 3], [2, 2].
tf.assign(w1 ,w2)
#这句话可以被成功执行
tf.assign(w1, w2, validate_shape=False)

虽然tensorflow支持更改变量的维度,但是这种用法在实践中比较罕见。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/79231.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Educational Codeforces Round 123 (Rated for Div. 2) D. Cross Coloring

Problem - D - Codeforces 翻译: 有一张纸,可以用大小为𝑛𝑚:𝑛行和𝑚列的单元格表示。所有的细胞最初都是白色的。 𝑞操作已应用到工作表。他们的𝑖-th可以描述如下: &#x1d4…

前端工程化项目的思考

这是一篇个人使用前端工程开发项目的思考,希望可以帮助到你。完全是一篇综合概念应该是很多东西,我也不清楚会有多少字,估计会对刚刚开始的人看起来比较迷,但也是没有办法的事情 1.前端脚本语言开发的作者我想应该也想不到js会发展…

Spark

1 Spark作业提交流程 2 Spark提交作业参数 1)在提交任务时的几个重要参数 executor-cores —— 每个executor使用的内核数,默认为1,官方建议2-5个 num-executors —— 启动executors的数量,默认为2 executor-memory —— executor…

【OpenCV学习】第9课:形态学操作的应用-提取水平线丶垂直线

仅自学做笔记用,后续有错误会更改 理论 图像在进行形态学操作的时候, 可以通过自定义的结构元素实现结构元素对输入图像的一些对象敏感丶对另外一些对象不敏感, 这样就会让敏感的对象改变而不敏感的对象保留输出。 通过使用两个最基本的形态学操作 - 膨…

华为云服务-运维篇-弹性负载均衡

文章目录一、什么是负载均衡二、我们为什么需要负载均衡1、生活中需要它的类似场景2、生活场景中协调者(负载均衡)作用3、协调者(负载均衡)引入后的变化三、华为云平台-如何做负载均衡弹性负载均衡-ELB四、总结一、什么是负载均衡 负载均衡构建在原有网…

【数据挖掘】薪酬分段对应工作经验/学历画柱状图【招聘网站的职位招聘数据预处理】

文章目录一.需求背景1.1 需求分析二.数据处理(对给定职位,汇总薪酬分段对应工作经验要求数据,画柱状图;)2.1 事前准备2,1 处理开始三.数据处理(对给定职位,汇总薪酬分段对应学历要求数据,画柱状图;)四.附源…

吉林大学 超星慕课 高级语言程序设计 实验08 结构化程序设计(2022级)

本人能力有限,发出只为帮助有需要的人。 建议同学们自己写完后再进行讨论。 其中的代码均没能在oj上进行测试,因此可能有误,请谅解。 除此以外部分题目设计深度优先搜索,因此可以分别用递归和堆栈实现,堆栈方法为了…

JavaScript进阶教程——异步编程、封装Ajax

异步编程 什么是同步与异步: 同步:一件事没做完,只能等待,完成之后再去做另一件事 异步: 两件事可以同时进行 前端开发中最常见的两种异步情况: ajax: 向后台请求数据计时器: setInterval se…

Python学习基础笔记四十一——sys模块

sys模块是与Python解释器交互的一个接口。 sys.argv 命令行参数List,第一个元素是程序本身路径 sys.exit(n) 退出程序,正常退出时exit(0),错误退出sys.exit(1) sys.version 获取Python解释程序的版本信息 sys.path 返…

ARM Cortex M3处理器概述

Cortex-M3概述 2004年ARM发布作为新型Corex处理器内核系列首款的Cortex-M3处理器。 STM32系列基于专为高性能、低成本、低功耗的嵌入式应用专门设计的ARM Cortex-M内核。 STM32命名规则 STMF103xx系统结构 1.使用高性能的ARM Cortex-M3 32位RISC内核 2.工作频率为72MHZ 3.内…

shell脚本监控文件夹文件实现自动上传数据到hive表

sh createtb.sh “tablename;field1,field2,field3,field4,field5,field6,field7;partition1,partition2” 数据库名:observation (脚本里写死了) 表名:tablename 指定名:field1,field2,field3,field4,field5,field6,f…

分别使用Alpine、Docker制作jdk镜像

目录 制作 jdk 1.0 镜像 ——Docker 1.创建文件夹上传jdk的安装包,和在同级目录下编写Dockerfile文件 2.编写 Dockerfile 文件 3.执行Dockerfile文件,初次依赖镜像的时候会下载相应镜像 优化制作jdk镜像(缩小内存大小)——使用alpine …

【致敬世界杯】球迷(我)和足球的故事

目录 一、第一次接触足球 二、回味无穷的2018世界杯 三、致敬世界杯 3.1 源代码 3.2 思路 3.3 关于图片 一、第一次接触足球 踢足球是一项优秀的运动,它可以锻炼身体,增强团队合作精神,并为人们带来快乐和满足感。回忆起小学时候第一次…

OpenCV和RTSP的综合研究

一、RTSP是什么?用来干什么? RTSP(Real Time Streaming Protocol),RFC2326,实时流传输协议,是TCP/IP协议体系中的一个应用层协议,由哥伦比亚大学、网景和RealNetworks公司提交的IET…

四旋翼无人机学习第14节--PCB Editor简单绘制封装-1

文章目录1 前言1.1 网络获取1.2 封装软件生成1.3 立创商城封装转化1 前言 在之前的博客中,我们绘制了封装所需的焊盘,有了焊盘我们就可以绘制封装啦。当然封装的获取有很多途径,下面我来总结一下。 1.1 网络获取 (有需要的可以下载哦&…

华为eNSP模拟器配置MSTP多实例生成树

传统的stp、rstp有其必然的缺陷 1.统一局域网内所有的vlan共享一个生成树,无法在vlan间实现数据流量的负载均衡。 2.链路利用率低,被阻塞的冗余链路不承载任何流量,造成了带宽的浪费,还可能造成部分vlan报文无法转发。MSTP在它们…

计算机毕业设计springboot+vue基本微信小程序的学习资料共享小程序

项目介绍 前台为用户使用的,包括下面一些功能: ① 资料发布:用户可以将想要共享的资料发布到小程序,供他人购买。 ②搜索 :分为按名称搜索和分类搜索,用户可选择其中一种方式,检索自己所需要的资料。 ③ 查看资料详情:用户可以…

学委必备小工具——筛选未提交人数【python小工具】

问题描述 作为一个学委,通常的任务就是收取班级作业,然后向老师报告当前未交人员的名单 JS版本:实现以一个表格数据查询另一个表格【JS】 之前我已经尝试通过用JS实现了,本质上差别其实也不是很大,只是对于JS来说&…

Java基础之《netty(11)—netty模型》

一、简单说明 1、工作原理示意图 netty主要基于主从Reactors多线程模型做了一定的改进,其中主从Reactor多线程模型有多个Reactor。 2、说明 (1)BossGroup线程维护selector,只关注Accept事件。 (2)当接收到…

[附源码]Node.js计算机毕业设计出版社样书申请管理系统Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…