tensorflow入门(四)如何用tensorflow训练神经网络

news2024/12/24 0:24:27

参考    如何用tensorflow训练神经网络 - 云+社区 - 腾讯云

在使用神经网络解决实际的分类或回归问题时需要设置好参数取值。下面介绍使用监督学习的方式来合理地设置参数取值,同时也将给出tensorflow程序来完成这个过程。设置神经网络参数的过程就是神经网络的训练过程。只有经过有效训练的神经网络模型才可以真正地解决分类或者回归问题。

用监督学习的方式设置神经网络参数需要有一个标注好的训练数据集。监督学习最重要的思想是,在一直答案的标注数据集上,模拟给出预测结果要尽量逼近真实的答案。通过调整神经网络中地参数对训练数据进行拟合,可以使得模块对未知的样本提供预测的能力。在神经网络优化算法中,最常用的方法是反向传播算法(backpropagation)。反向传播算法的具体工作原理如下图

                     

从上图可以看出,反向传播算法实现了一个迭代过程。在每次迭代的开始,首先需要选取一部分训练数据,这一小部分数据叫做一个batch。然后,这个batch的样例通过前向传播算法得到神经网络模型的预测结果。因为训练数据都是有正确答案标注的,所以可以计算出当前神经网络模型的预测答案与真实答案之间的差距。最后,基于预测值和真实值之间的差距,反向传播算法会相应更新神经网络参数的取值,使得在这个batch上的神经网络模型的预测结果和真实结果更加接近。

通过tensorflow实现反向传播算法的第一步是使用tensorflow表达一个batch的数据。例如使用常量来表达过一个样例:

x = tf.constant([0.7, 0.9])

但如果每轮迭代中选取的数据都要通过常量来表示,那么tensorflow的计算图将会太大。因为每生成一个常量,tensorflow都会在计算图中增加一个节点。一般来说,一个神经网络的训练过程会需要几百万甚至几亿轮的迭代,这样计算图就会非常大,而且利用率很低。

为了避免这个问题,tensorflow提供了placeholder机制用于提供输入数据。

placeholder相当于定义了一个位置,这个位置中的数据在程序运行时再指定。

这样在程序中就不需要生成大量常数来提供输入数据,而只需要将数据通过placeholder传入tensorflow计算图。

在placeholder定义时,这个位置上的数据类型是需要指定的。

和其他张量一样,placeholder的类型也是不可以改变的。

placeholder中的数据的维度信息可以根据提供的数据推导得出,所以不一定要给出。

下面给出了通过placeholder实现前向传播算法的代码:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal([2, ,3], stddev = 1, seed = 1))
w2 = tf.Variable(tf.random_normal([3, ,1], stddev = 1, seed = 1))

# 定义placeholder作为存放数据的地方。这里维度也不一定要定义。
# 但如果维度是确定的,那么给出维度可以降低出错的概率。
x = tf.placeholder(tf.float32, shape=(1,2), name = "input")
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# 下面一行将报错: InvalidArgumentError: You must feed a value for placeholder
# tensor 'input_1' with dtype float and shape [1, 2]
print(sess.run(y))

print(sess.run(y , feed_dict = {x: {[[0.7, 0.9]]}))

在这段程序中替换了原来通过常量定义的输入x。在新的程序中计算前向传播结果时,需要提供一个feed_dict来指定x的取值。feed_dict是一个字典(map),在字典中需要给出每个用到的placeholder的取值。如果某个需要的placeholder没有被指定取值,那么程序在运行时将会报错。

以上程序只计算了一个样例的前向传播结果,但如上图所示,在训练神经网络时需要每次提供一个batch的训练样例。对于这样的需求,placeholder也可以很好的支持。在以上程序中,如果将输入的1*2矩阵改为n*2的矩阵,那么就可以得到n个样例的前向传播结果了。其中n*2的矩阵的每一行为一个样例数据。这样前向传播的结果为n*1的矩阵,这个矩阵的每一行就代表了一个样例的前向传播结果。以下代码给出了一个示例。

x = tf.placeholder(tf.float32, shape=(3, 2, name = "input"))
... # 中间部分和上面的样例程序一样

# 因为x在定义是制定了n为3,所以在运行前向传播过程时需要提供3个样例数据
print(sess.run(y, feed_dict={x: [[0.7, 0.9],[0.1, 0.4],[0.5, 0.8]]}))


'''
输出结果为:
[[3.95757794]
[1.15376544]
[3.16749191]]

以上样例展示了一次性计算多个样例的前向传播结果。在运行时,需要将三个样例[0.7, 0.9]、[0.1, 0.4]和[0.5, 0.8]组成一个3*2的矩阵传入placeholder。计算得到的结果为3*1的矩阵。其中第一行3.95757794为样例[0.7, 0.9]的前向传播结果:1.15376544为样例[0.1, 0.4]的前向传播结果;3.16749191为样例[0.5, 0.8]的前向传播结果。

在得到一个batch的前向传播结果以后,需要定义一个损失函数来刻画当前的预测值和真实答案之间的差距。然后通过反向传播算法来调整神经网络参数的取值是的差距可以被缩小。下面代码定义了一个简单的损失函数,并通过tensorflow定义了反向传播算法。

# 使用sigmoid函数将y转换为0~1之间的数值。转换后y代表预测是正样本的概率,1-y代表
# 预测是负样本的概率
y = tf.sigmoid(y)


# 定义损失函数来刻画预测值与真实值得差距
cross_entropy = - tf.reduce_mean(
      y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
      + (1 - y_)*tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))


# 定义学习率
learning_rate = 0.001


# 定义反向传播算法来优化神经网络中的参数
train_step = \
    tf.train.AdaOptimizer(learning_rate).minimize(cross_entropy)

在以上代码中,cross_entropy定义了真实值和预测值之间的交叉熵(cross_entropy),这是一个分类问题中一个常用的损失函数。第二行train_step定义了反向传播的优化方法。目前,tensorflow支持10种不同的优化器,常用的优化器有三种:

tf.train.GradientDescentOptimizer

tf.train.AdamOptimizer

tf.train.MomentumOptimizer

在定义了反向传播算法之后,通过运行sess.run(train_step)就可以对所有在GraphKeys.TRAINBLE_VARIABLES集合中的变量进行优化,使得在当前batch下损失函数最小。

下面给出了一个完整的程序来训练神经网络解决二分类问题。

import tensorflow as tf

# NumPy是一个科学计算的工具包,这里通过NumPy工具包生成模拟数据集。
from numpy.random import RandomState

# 定义训练数据batch的大小。
batch_size = 8

# 定义神经网络的参数。
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# 在shape的一个维度上使用None可以方便使用不同的batch大小。在训练时需要把数据分
# 成比较小的batch,但是在测试时,可以一次性使用全部的数据。当数据集比较小时这样
# 比较方便测试,但数据集比较大时,将大量数据放入一个batch可能会导致内存溢出。

x  = tf.placeholder(tf.float32, shape=[None, 2], name='x-input')
y_ = tf.placeholder(tf.float32, shape=[None, 1], name='y-input')

# 定义神经网络前向传播的过程。
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)



# 定义损失函数和反向传播的算法.
y = tf.sigmoid(y)

cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
    + (1 - y_) * tf.log(tf.clip_by_value(1 - y, 1e-10, 1.0)))

train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)



# 通过随机数生成一个模拟数据集。
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
# 定义规则来给出样本的标签,在这里所有x1+x2<1的样例都被认为是正样本,
# 而其他为负样本
# 在这里使用0来表示负样本,1来表示正样本。大部分解决分类问题的神经网络都会采用
# 0和1的表示方法。
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # 设定训练的轮数。
    STEPS = 5000
    for i in range(STEPS):
       # 每次选取batch_size个样本进行训练。
       start = (i * batch_size) % dataset_size
       end = min(start + batch_size, dataset_size)

       # 通过选取的样本训练神经网络并更新参数。
       sess.run(train_step,
             feed_dict={x: X[start:end], y_: Y[start:end]})

       if i % 1000 == 0:
          # 每隔一段时间计算在所有数据上的交叉熵并输出。
          total_cross_entropy = sess.run(
            cross_entropy, feed_dict={x: X, y_: Y})
          print("After %d training step(s), cross entropy on all data is %g" %
           (i, total_cross_entropy))
    print(sess.run(w1))
    print(sess.run(w2))

以上程序实现了训练神经网络的全部过程。从这段程序可以总结出训练神经网络的过程可以分为三个步骤:

  1. 定义神经网络的结构和前向传播的结果
  2. 定义损失函数以及选择反向传播优化的算法
  3. 生成会话(tf.Session)并且在训练数据上反复运行反向传播优化算法

无论神经网络的结构如何变化,这三个步骤是不变的。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/78312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于JDBC的MySQL数据库编程

✨博客主页: 荣 ✨系列专栏: MySQL ✨一句短话: 难在坚持,贵在坚持,成在坚持! 文章目录一. JDBC概述二. JDBC前置工作1. 准备好MySQL驱动包2. 创建项目三. JDBC的使用步骤1. 创建数据源DataSourece2. 连接数据库3. 构造并执行sql语句4. 释放资源5. sql语句不要写死(以插入为例)…

C++入门基础07:函数定义与声明、函数传参(传值、传地址、传引用)、函数重载

C入门基础07&#xff1a;函数定义与声明、函数传参&#xff08;传值、传地址、传引用&#xff09;、函数重载 1、函数定义与声明 函数是一起执行一个任务的一组语句。每个程序&#xff08;C/C&#xff09;都有一个主函数 main() &#xff0c; 所有简单的程序都可以定义其他额…

1563_AURIX_TC275_EVR的控制寄存器

全部学习汇总&#xff1a; GreyZhang/g_TC275: happy hacking for TC275! (github.com) 1. 连续的写入必须确保是解锁状态&#xff0c;否则的话可能会导致所有的总线阻塞。 2. 多核系统中&#xff0c;尽量写入之后再读取一下确认写入的状态。 这里是过压以及欠压的几个监控结果…

if、for、while结构的用法

分支与循环的流程控制一、分支流程控制1. if分支语句结构1). 单分支2). 双分支3). 三元运算符&#xff08;双分支的简化&#xff09;4). 多分支二. 循环流程控制1. while循环语句语法结构1.while循环用法2.while 的死循环3. while else的用法2. for循环语句语法结构1. for和ran…

嵌入式开发学习之--串口通讯(上)

提示&#xff1a;本篇开始学习各种通讯方式&#xff0c;重中之重。 文章目录前言一、 串口通讯协议简介1.1 物理层1.2 协议层1.2.1 基本组成。1.2.2 波特率1.2.3 起始和终止位1.2.4 有效数据1.2.5 数据校验二、USART结构体详解总结前言 作为一个嵌入式的开发者&#xff0c;解析…

网页木马挂马的实现与防范

一、网页挂马是什么 不少用户都碰到过这样的现象&#xff1a;打开一个网站&#xff0c;结果页面还没显示&#xff0c;杀毒软件就开始报警&#xff0c;提示检测到木马病毒。有经验的朋友会知道这是网页恶意代码&#xff0c;这就是典型的网页挂马现象。那么是什么原因导致了这种…

数据库概论之MySQL表的增删改查 - 进阶版本1

MySQL表的增删改查 - 进阶1、数据库约束1.1 约束类型1.2 NULL约束1.3 UNIQUE约束1.4 DEFAULT约束1.5 PRIMARY约束1.6 FOREIGN KEY外键约束1.6.1 语法1.6.2 工作原理2、表的设计2.1 一对一2.2 一对多2.3 多对多大家好&#xff0c;已经好久没更新了 , 学校的学业有点忙 , 没有额外…

[论文解析] Denoising Diffusion Probabilistic Models

文章目录OverviewsWhat problem is addressed in the paper?What is the key to the solution?What is the main contribution?Contents扩散概率模型背景算法实验结论Overviews What problem is addressed in the paper? We present high quality image synthesis result…

【Java面试】说说类加载机制(流程)

文章目录加载流程装载(Load)链接(Link)验证(Verify)准备(Prepare)解析(Resolve)初始化(Initialize)卸载(Unload)加载流程 类的加载流程如下&#xff1a; 转载(Load)->链接(Link)->初始化(Initialize)->使用(Use)->卸载(Unload) 其中链接又包含验证(Verify)&#x…

rabbitMQ延时队列——TTL和DLX

一. 场景&#xff1a;“订单下单成功后&#xff0c;15分钟未支付自动取消” 1.传统处理超时订单 采取定时任务轮训数据库订单&#xff0c;并且批量处理。其弊端也是显而易见的&#xff1b;对服务器、数据库性会有很大的要求&#xff0c; 并且当处理大量订单起来会很力不从…

flask前后端项目--实例-前端部分:-4-vue-Element Plus

flask前后端项目--实例-前端部分&#xff1a;-4-vue-Element Plus组件添加事项 一、实验测试步骤 1.Element Plus添加 1.先备份App.VUE&#xff0c;然后修改app.vue的内容&#xff0c;数据来源资Element Plus的表格table 2. 数据来源资Element Plus的表格table 3. 运行服务&…

023_SSS_Neural 3D Video Synthesis from Multi-view Video(CVPR2022)

Neural 3D Video Synthesis from Multi-view Video(CVPR2022) 本文提出了一种新的3D视频生成方法&#xff0c;这种方法能够以紧凑但富有表现力的表示形式表示动态真实世界场景的多视图视频记录&#xff0c;从而实现高质量的视图合成和运动插值。 1. Introduction 本文的主要…

百度地图 ( 一 ) 显示地图

1.百度地图 百度地图开放平台 https://lbsyun.baidu.com/ 使用百度地图时导入JavaScript包 <script type"text/javascript" src"http://api.map.baidu.com/api?v2.0&ak您的密钥"></script>1.1.如何申请 ak 密钥 在 开发平台 找 控制…

ChatGPT注册流程

1.访问官网点击 Sign up https://chat.openai.com/auth/login 2.输入你的邮箱 3.点击Continue下一步: 4.输入密码继续下一步&#xff1a; 5.然后你的邮箱会受到一封邮件&#xff08;如果没收到请检查垃圾邮箱&#xff09;&#xff1a; 6.点击验证邮箱按钮&#xff0c;会跳到…

MySQL 5.7中文乱码与远程链接问题

MySQL 5.7中文乱码与远程链接问题1. MySQL 5.7中文乱码2. 远程链接问题3. 不区分表大小写4. 超过最大连接数5. 时区问题5. GROUP BY 问题配置集合重启MySQL1. MySQL 5.7中文乱码 当我们直接在数据库里面输入中文时&#xff0c;保存后出现&#xff1a; Incorrect string value&…

LeetCode 第 244 场周赛题解

前言 这是 2021-06-06 的一场 LeetCode 周赛&#xff0c;本场周赛的题目相较而以往而言比较简单&#xff0c;基本上想到点上就可以做出来&#xff0c;主要涉及到矩阵的旋转、贪心、滑动窗口、前缀和、二分查找等知识点。 第 244 场周赛链接&#xff1a;https://leetcode-cn.c…

SpringBoot+Vue实现前后端分离的学校快递站点管理系统

文末获取源码 开发语言&#xff1a;Java 使用框架&#xff1a;spring boot 前端技术&#xff1a;JavaScript、Vue.js 、css3 开发工具&#xff1a;IDEA/MyEclipse/Eclipse、Visual Studio Code 数据库&#xff1a;MySQL 5.7/8.0 数据库管理工具&#xff1a;phpstudy/Navicat JD…

机器人开发--雷达lidar

机器人开发--雷达lidar1 介绍2 分类2.1 整体分类2.2 机械式&#xff08;三角&ToF&#xff09;三角测距激光雷达ToF测距激光雷达3 机械式单线ToF激光雷达3.1 扫描原理3.2 不同材料反射率3.3 核心参数参考1 介绍 激光雷达&#xff08;英文&#xff1a;Laser Radar &#xff…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java制造类企业erp23725

面对老师五花八门的设计要求&#xff0c;首先自己要明确好自己的题目方向&#xff0c;并且与老师多多沟通&#xff0c;用什么编程语言&#xff0c;使用到什么数据库&#xff0c;确定好了&#xff0c;在开始着手毕业设计。 1&#xff1a;选择课题的第一选择就是尽量选择指导老师…

第十四周周报

学习目标&#xff1a; 一、论文“Vector Quantized Diffusion Model for Text-to-Image Synthesis”的Code 二、猫狗识别、人脸识别模型 学习内容&#xff1a; Code 学习时间&#xff1a; 12.4-12.9 学习产出&#xff1a; 一、论文Code 正向过程&#xff1a; 先通过Tam…