tensorflow基础

news2025/1/15 7:05:03

tensorflow基础

  • (一)编程模型
    • (1)编程模型中的运行机制
    • (2)编写hello world程序
    • (3)使用注入机制进行代码编写
    • (4)保存和载入模型的方法介绍
      • (4.1)保存模型
      • (4.2)载入模型
    • (5)检查点(Checkpoint)
    • (6)模型操作常用函数总结

(一)编程模型

TensorFlow的命名来源于本身的运行原理。Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算。TensorFlow是张量从图像的一端流动到另一端的计算过程,这也是TensorFlow的编程模型。

(1)编程模型中的运行机制

TensorFlow的运行机制属于“定义”与“运行”相分离。从操作层面可以抽象成两种:模型构建和模型运行。
模型构建中的概念
在这里插入图片描述
表中定义的内容都是在一个叫做“图”的容器中完成的。关于“图”,有以下几点需要理解
● 一个“图”代表一个计算任务。
● 在模型运行的环节中,“图”会在会话(session)里被启动。
● session将图的OP分发到如CPU或GPU之类的设备上,同时提供执行OP的方法。这些方法执行后,将产生的tensor返回。在Python语言中,返回的tensor是numpy ndarray对象;在C和C++语言中,返回的tensor是TensorFlow::Tensor实例。
在这里插入图片描述

(2)编写hello world程序

import tensorflow as tf
#1.定义一个常量
hello=tf.constant("hello")
# 2.定义session
sess=tf.Session()
# 3.使用sess进行运行
print(sess.run(hello))
sess.close()

"""
使用with 进行改进
"""
with tf.Session() as sess:
    print(sess.run(hello))

(3)使用注入机制进行代码编写

使用注入机制,将具体的实参注入到相应的placeholder中。feed只在调用它的方法内有效,方法结束后feed就会消失。

a=tf.placeholder(tf.float32)
b=tf.placeholder(tf.float32)
z=tf.multiply(a,b)
with tf.Session() as sess:
    print(sess.run(z,feed_dict={a:3,b:5}))

使用tf.placeholder为这些操作创建占位符,然后使用feed_dict把具体的值放到占位符里。

(4)保存和载入模型的方法介绍

(4.1)保存模型

首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来。代码如下:

    #之前是各种构建模型graph的操作(矩阵相乘,sigmoid等)
    saver = tf.train.Saver()                                 #生成saver
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())     #先对模型初始化
        #然后将数据丢入模型进行训练blablabla
        #训练完以后,使用saver.save来保存
        saver.save(sess, "save_path/file_name")
                                            #file_name如果不存在,会自动创建

(4.2)载入模型

将模型保存好以后,载入也比较方便。在session中通过调用saver的restore()函数,会从指定的路径找到模型文件,并覆盖到相关参数中。代码如下:

    saver = tf.train.Saver()
    with tf.Session() as sess:
        #参数可以进行初始化,也可不进行初始化。即使初始化了,初始化的值也会被restore的
          值给覆盖
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, "save_path/file_name")
                                #会将已经保存的变量值resotre到变量中。

(5)检查点(Checkpoint)

保存模型并不限于在训练之后,在训练之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况。我们自然希望能够将辛苦得到的中间参数保留下来,否则下次又要重新开始。这种在训练中保存模型,习惯上称之为保存检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

#定义生成loss可视化的函数
plotdata = { "batchsize":[], "loss":[] }
def moving_average(a, w=10):
    if len(a) < w:
      return a[:]
    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in
    enumerate(a)]
#生成模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2*train_X + np.random.randn(*train_X.shape)*0.3
                                            # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
tf.reset_default_graph()
# 创建模型
# 占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(X, W)+ b
#反向优化
cost =tf.reduce_mean( tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)                                             #梯度下降
# 初始化所有变量
init = tf.global_variables_initializer()
# 定义学习参数
training_epochs = 20
display_step = 2
"""
tf.train.Saver(max_to_keep=1)代码创建saver时传入的参数max_to_keep=1代表:
在迭代过程中只保存一个文件。这样,在循环训练过程中,新生成的模型就会覆盖以前的模型。
"""
saver = tf.train.Saver(max_to_keep=1)       # 生成saver
savedir = "log/"
# 启动图
with tf.Session() as sess:
    sess.run(init)
    # 向模型中输入数据
    for epoch in range(training_epochs):
      for (x, y) in zip(train_X, train_Y):
          sess.run(optimizer, feed_dict={X: x, Y: y})
      #显示训练中的详细信息
      if epoch % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
            print ("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))
    if not (loss == "NA" ):
      plotdata["batchsize"].append(epoch)
      plotdata["loss"].append(loss)
    saver.save(sess, savedir+"linermodel.cpkt", global_step=epoch)
print (" Finished! ")
print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}),
"W=", sess.run(W), "b=", sess.run(b))
#显示模型
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(train_X, sess.run(W)* train_X + sess.run(b), label='Fitted Wline')
plt.legend()
plt.show()
plotdata["avgloss"] = moving_average(plotdata["loss"])
plt.figure(1)
plt.subplot(211)
plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
plt.xlabel('Minibatch number')
plt.ylabel('Loss')
plt.title('Minibatch run vs. Training loss')
plt.show()

#重启一个session ,载入检查点
load_epoch=18
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))
    print ("x=0.2, z=", sess2.run(z, feed_dict={X: 0.2}))

使用MonitoredTrainingSession函数来自动管理检查点文件

    import tensorflow as tf
    tf.reset_default_graph()
    global_step = tf train.get_or_create_global_step()
    step = tf.assign_add(global_step, 1)
    #设置检查点路径为log/checkpoints
    with  tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',
    save_checkpoint_secs  = 2) as sess:
        print(sess.run([global_step]))
        while not sess.should_stop():        #启用死循环,当sess不结束时就不停止
          i = sess.run( step)
          print( i)

注意:(1)如果不设置save_checkpoint_secs参数,默认的保存时间间隔为10分钟。这种按照时间保存的模式更适用于使用大型数据集来训练复杂模型的情况。(2)使用该方法时,必须要定义global_step变量,否则会报错误。

(6)模型操作常用函数总结

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/468014.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录|day58|单调栈part01 ● 739. 每日温度 ● 496.下一个更大元素 I

739. 每日温度 链接&#xff1a;代码随想录 今天正式开始单调栈&#xff0c;这是单调栈一篇扫盲题目&#xff0c;也是经典题。 大家可以读题&#xff0c;思考暴力的解法&#xff0c;然后在看单调栈的解法。 就能感受出单调栈的巧妙 fvfdvsf fdfd ddf fdd fd fsd 496.下一个更…

轻量级服务器nginix:如何实现Spring项目的负载均衡

这里写目录标题 一 生成war包并给数据库导入数据1.1生成war包1.2 向数据库中导入数据 二 启动Tomcat三 配置负载均衡并启动Nginx1.cent121这台虚拟机上2.检测两个tomcat的运行状态3.配置nginx4.启动4.1 nginx报错4.2 成功启动项目 四 命令总结 一 生成war包并给数据库导入数据 …

docker和k8s基础介绍

一 Docker介绍 1.1 docker是什么 Docker 是一个开源项目&#xff0c; 诞生于 2013 年初&#xff0c;最初是 dotCloud 公司内部的一个业余项目。它基于 Google 公司推出的 Go 语言实现。 项目后来加入了 Linux 基金会&#xff0c;遵从了 Apache 2.0协议&#xff0c; 项目代码在…

SQL——索引

&#x1f4a1; 索引 在关系型数据库中&#xff0c;索引是一种单独的、物理上的对数据库表中的一列或多列的值进行排序的一种存储结构&#xff0c;他是某个表中的一列或着若干列值的集合和相应的指向表中物理标识这些值的数据页的逻辑指针清单&#xff08;类似于图书目录&#x…

RF技术设计的机械数码一体化防盗锁

机械数码一体化防盗锁在传统锁具的基础上增加了一个受控的弹子&#xff0c;只能通过设置过的合法钥匙开启&#xff0c;并且增加了防盗报警功能。本文介绍了基于PIC单片机、RF技术和无线数据传输技术的机械数码一体化防盗锁的设计。 引言 机械锁和数码锁是我们日常生活中最常见的…

CDGA|数据监管越来越严,数据治理发展何去何从?

尽管数据监管越来越严格&#xff0c;但仍然存在许多机会。事实上&#xff0c;数据监管的加强可能会促进金融科技行业更好地运用数据&#xff0c;激发金融科技行业更多的创新和合作,创造更多的价值和机会。 推动金融机构重视数据安全和隐私保护 促使他们采取更严格的安全措施&a…

Ueditor 富文本编辑器 插入 m3u8 和 mp4 视频(PHP)

当前环境&#xff1a;PHP、Ueditor的版本是1.4.3 新的需求是&#xff0c;需要在Ueditor 富文本编辑器中&#xff0c;插入视频播放&#xff0c;并且视频格式有MP4&#xff0c;也有M3U8。 百度编辑器默认的是embed&#xff0c;需要修改下配置。 ueditor.all.js和 ueditor.confi…

C++vector的动态扩容,为何是1.5倍或者是2倍

1. vector如何进行扩容 当插入元素时&#xff0c;如果size capacity&#xff0c;触发扩容机制。 拷贝元素释放旧空间 2. 如何避免扩容导致效率低 在插入前&#xff0c;预估好vector的容量&#xff0c;通常使用reserve。如果没有reserve&#xff0c;边插边扩容&#xff0c;…

ROS学习7:ROS机器人导航仿真

【Autolabor初级教程】ROS机器人入门 1. 概述 在 ROS 中机器人导航 (Navigation) 由多个功能包组合实现&#xff0c;ROS 中又称之为导航功能包集&#xff0c;关于导航模块&#xff0c;官方介绍如下 一个二维导航堆栈&#xff0c;它接收来自里程计、传感器流和目标姿态的信息&a…

汽车智能化「出海」红利

在高阶智能座舱中&#xff0c;车载导航产品作为与用户体验息息相关的模块之一&#xff0c;同样也进入了升级迭代周期。 基于高精度地图渲染、高精度定位算法、AR等技术的车道级导航、AR导航等产品快速上车&#xff0c;但同时随着人机交互多模发展以及3D沉浸式用户体验需求趋势下…

DataX-在Windows上实现postgresql同步数据到mysql

场景 DataX-阿里开源离线同步工具在Windows上实现Sqlserver到Mysql全量同步和增量同步: DataX-阿里开源离线同步工具在Windows上实现Sqlserver到Mysql全量同步和增量同步_霸道流氓气质的博客-CSDN博客 在上面实现sqlserver到mysql的数据同步之后&#xff0c;如果要实现postg…

FPGA/Verilog HDL/AC620零基础入门学习——8*8同步FIFO实验

实验要求 该项目主要实现一个深度为8、位宽为8bit的同步FIFO存储单元。模块功能应包括读控制、写控制、同时读写控制、FIFO满状态、FIFO空状态等逻辑部分。 该项目由一个功能模块和一个testbench组成。其中功能模块的端口信号如下表所示。 提示&#xff1a; &#xff08;1&a…

Sqoop 从入门到精通

Sqoop Sqoop 架构解析 概述 Sqoop是Hadoop和关系数据库服务器之间传送数据的一种工具。它是用来从关系数据库如&#xff1a;MySQL&#xff0c;Oracle到Hadoop的HDFS&#xff0c;并从Hadoop的文件系统导出数据到关系数据库。 传统的应用管理系统&#xff0c;也就是与关系型数…

【微服务笔记24】微服务组件之Sleuth + Zipkin实现服务调用链路追踪功能

这篇文章&#xff0c;主要介绍微服务组件之Sleuth Zipkin实现服务调用链路追踪功能。 目录 一、Sleuth链路追踪 1.1、什么是Sleuth 1.2、Sleuth专业术语 &#xff08;1&#xff09;Span &#xff08;2&#xff09;Trace &#xff08;3&#xff09;工作原理 1.3、Sleuth…

月薪15K必会技术,如何从0到1学习性能测试,5个操作安排的明明白白

目录 【开幕】武林秘籍惊现江湖 【第一幕】该不该预测一个初始值&#xff1f; 【第二幕】从单线程开始 【第三幕】用命令行形式跑性能测试&#xff0c;然后观察机器性能。 【第四幕】控制吞吐&#xff01;控制吞吐&#xff01;控制吞吐&#xff01; 【第五幕】武林秘籍重…

10 dubbo源码学习_线程池

1. 线程模型&线程池介绍1.1 线程池1.2 线程模型 2. 线程池源码分析2.1 FixedThreadPool2.2 CachedThreadPool2.3 LimitedThreadPool 3. 线程模型源码3.1 AllDispatcher3.2 DirectDispatcher3.3 MessageOnlyDispatcher3.4 ExecutionDispatcher3.5 ConnectionOrderedDispatch…

Visual Studio C# WinForm开发入门(6):TreeView 控件使用

TreeView控件用树显示节点层次。 例如&#xff1a;顶级目录是根(C:)&#xff0c;C盘下的每个子目录都是子节点&#xff0c;而每个子目录又都有自己的子节点 TreeView属性和方法&#xff1a; 属性说明CheckBoxes表示节点旁边是否出现复选框ImageList指定一个包含节点图标的Imag…

Spring Cloud Gateway 服务网关的部署与使用详细介绍

为什么需要服务网关 传统的单体架构中只需要开放一个服务给客户端调用&#xff0c;但是微服务架构中是将一个系统拆分成多个微服务&#xff0c;如果没有网关&#xff0c;客户端只能在本地记录每个微服务的调用地址&#xff0c;当需要调用的微服务数量很多时&#xff0c;它需要…

【音视频第20天】wireshark+tcpdump

tcpdump抓 wireshark分析 目录 tcpdumpwireshark tcpdump tcpdump参数详解 网上一搜一大堆。最全的不是用tcpdump -h而是man tcpdump来查询手册。 tcpdump -i eth0 -p udp -xx -Xs 0 -w /root/test2.cap -i 针对eth0网卡的&#xff0c;ifconfig是查看有几个网卡 -i eth0 表示…

海睿思分享 | 终于有人把指标体系和标签体系说清楚了

当前&#xff0c;随着企业数字化转型如火如荼地开展&#xff0c;在企业经营管理数字化的数据建设过程中&#xff0c;经常会遇到指标和标签的使用场景。 指标体系到底是什么&#xff1f;标签体系又是什么&#xff1f;这些疑问导致在数据分析过程中效率低下、科学性不高&#xf…