零基础学人工智能:TensorFlow 入门例子

news2024/10/1 23:44:22

识别手写图片

因为这个例子是 TensorFlow 官方的例子,不会说的太详细,会加入了一点个人的理解,因为TensorFlow提供了各种工具和库,帮助开发人员构建和训练基于神经网络的模型。TensorFlow 中最重要的概念是张量(Tensor),它代表了多维数组或矩阵,因此 TensorFlow 支持各种不同类型的计算,如线性回归、逻辑回归、卷积神经网络、循环神经网络等。所以帮我们极大减少了对数学与算法基础的要求。

准备数据

这里用来识别的手写图片大致是这样的,为了降低复杂度,每个图片是 28*28 大小。

 

但是直接丢图片给我们的模型,模型是不认识的,所以必须要对图片进行一些处理。

如果了解线性代数,大概知道图片的每个像素点其实可以表示为一个二维的矩阵,对图片做各种变换,比如翻转啊什么的就是对这个矩阵进行运算,于是我们的手写图片大概可以看成是这样的:

 

这个矩阵展开成一个向量,长度是 28*28=784。我们还需要另一个东西用来告诉模型我们认为这个图片是几,也就是给图片打个 label。这个 label 也不是随便打的,这里用一个类似有 10 个元素的数组,其中只有一个是 1,其它都是 0,哪位为 1 表示对应的图片是几,例如表示数字 8 的标签值就是 ([0,0,0,0,0,0,0,0,1,0])。

这些就是单张图片的数据处理,实际上为了高效的训练模型,会把图片数据和 label 数据分别打包到一起,也就是 MNIST 数据集了。

MNIST数据集

MNIST 数据集是一个入门级的计算机视觉数据集,官网是Yann LeCun's website。 我们不需要手动去下载这个数据集, 1.0 的 TensorFlow 会自动下载。

这个训练数据集有 55000 个图片数据,用张量的方式组织的,形状如 [55000,784],如下图:

 

还记得为啥是 784 吗,因为 28*28 的图片。
label 也是如此,[55000,10]:

 

这个数据集里面除了有训练用的数据之外,还有 10000 个测试模型准确度的数据。

整个数据集大概是这样的:

 

现在数据有了,来看下我们的模型。

Softmax 回归模型

Softmax 中文名叫归一化指数函数,这个模型可以用来给不同的对象分配概率。比如判断

 

的时候可能认为有 80% 是 9,有 5% 认为可能是 8,因为上面都有个圈。

我们对图片像素值进行加权求和。比如某个像素具有很强的证据说明这个图不是 1,则这个像素相应的权值为负数,相反,如果这个像素特别有利,则权值为正数。

如下图,红色区域代表负数权值,蓝色代表正数权值。

 

同时,还有一个偏置量(bias) 用来减小一些无关的干扰量。

Softmax 回归模型的原理大概就是这样,更多的推导过程,可以查阅一下官方文档,有比较详细的内容。

说了那么久,终于可以上代码了。

训练模型

具体引入的一些包这里就不一一列出来,主要是两个,一个是 tensorflow 本身,另一个是官方例子里面用来输入数据用的方法。

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

之后就可以建立我们的模型。

# Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

这里的代码都是类似占位符,要填了数据才有用。

  • x 是从图片数据文件里面读来的,理解为一个常量,一个输入值,因为是 28*28 的图片,所以这里是 784;
  • W 代表权重,因为有 784 个点,然后有 10 个数字的权重,所以是 [784, 10],模型运算过程中会不断调整这个值,可以理解为一个变量;
  • b 代表偏置量,每个数字的偏置量都不同,所以这里是 10,模型运算过程中也会不断调整这个值,也是一个变量;
  • y 基于前面的数据矩阵乘积计算。

tf.zeros 表示初始化为 0。

我们会需要一个东西来接受正确的输入,也就是放训练时准确的 label。

 # Define loss and optimizer
  y_ = tf.placeholder(tf.float32, [None, 10])

我们会用一个叫交叉熵的东西来衡量我们的预测的「惊讶」程度。

关于交叉熵,举个例子,我们平常写代码的时候,一按编译,一切顺利,程序跑起来了,我们就没那么「惊讶」,因为我们的代码是那么的优秀;而如果一按编译,整个就 Crash 了,我们很「惊讶」,一脸蒙逼的想,这怎么可能。

交叉熵感性的认识就是表达这个的,当输出的值和我们的期望是一致的时候,我们就「惊讶」值就比较低,当输出值不是我们期望的时候,「惊讶」值就比较高。

  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

这里就用了 TensorFlow 实现的 softmax 模型来计算交叉熵。
交叉熵,就是我们想要尽量优化的值,让结果符合预期,不要让我们太「惊讶」。

TensorFlow 会自动使用反向传播算法(backpropagation algorithm) 来有效的确定变量是如何影响你想最小化的交叉熵。然后 TensorFlow 会用你选择的优化算法来不断地修改变量以降低交叉熵。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

这里用了梯度下降算法(gradient descent algorithm)来优化交叉熵,这里是以 0.5 的速度来一点点的优化交叉熵。

之后就是初始化变量,以及启动 Session

sess = tf.InteractiveSession() tf.global_variables_initializer().run()

启动之后,开始训练!

 
# Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

这里训练 1000 次,每次随机找 100 个数据来练习,这里的 feed_dict={x: batch_xs, y_: batch_ys},就是我们前面那设置的两个留着占位的输入值。

到这里基本训练就完成了。

评估模型

训练完之后,我们来评估一下模型的准确度。

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))

tf.argmax 给出某个tensor对象在某一维上的其数据最大值所在的索引值。因为我们的 label 只有一个 1,所以 tf.argmax(_y, 1) 就是 label 的索引,也就是表示图片是几。把计算值和预测值 equal 一下就可以得出模型算的是否准确。
下面的 accuracy 计算的是一个整体的精确度。

这里填入的数据不是训练数据,是测试数据和测试 label。

最终结果,我的是 0.9151,91.51% 的准确度。官方说这个不太好,如果用一些更好的模型,比如多层卷积网络等,这个识别率可以到 99% 以上。

官方的例子到这里就结束了,虽然说识别了几万张图片,但是我一张像样的图片都没看到,于是我决定想办法拿这个模型真正找几个图片测试一下。

用模型测试

看下上面的例子,重点就是放测试数据进去这里,如果我们要拿图片测,需要先把图片变成相应格式的数据。

sess.run(accuracy, feed_dict={x: mnist.test.images, 
                            y_: mnist.test.labels})

看下这里 mnist 是从 tensorflow.examples.tutorials.mnist 中的 input_data 的 read_data_sets 方法中来的。

from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

Python 就是好,有啥不懂看下源码。源码的在线地址在这里

打开找 read_data_sets 方法,发现:

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

在这个文件里面。

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
                   ...
                   ...
                   ...
    train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
    validation = DataSet(validation_images,
                        validation_labels,
                        dtype=dtype,
                        reshape=reshape)
    test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)

    return base.Datasets(train=train, validation=validation, test=test)

可以看到,最后返回的都是是一个对象,而我们用的 feeddict={x: mnist.test.images, y: mnist.test.labels} 就是从这来的,是一个 DataSet 对象。这个对象也在这个文件里面。

class DataSet(object):

  def __init__(self,
               images,
               labels,
               fake_data=False,
               one_hot=False,
               dtype=dtypes.float32,
               reshape=True):
    """Construct a DataSet.
    one_hot arg is used only if fake_data is true.  `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    ...
    ...

这个对象很长,我就只挑重点了,主要看构造方法。一定要传入的有 images 和 labels。其实这里已经比较明朗了,我们只要把单张图片弄成 mnist 格式,分别传入到这个 DataSet 里面,就可以得到我们要的数据。

网上查了下还真有,代码地址,对应的文章:www.jianshu.com/p/419557758…,文章讲的有点不清楚,需要针对 TensorFlow 1.0 版本以及实际目录情况做点修改。

直接上我修改后的代码:

from PIL import Image
from numpy import *

def GetImage(filelist):
    width=28
    height=28
    value=zeros([1,width,height,1])
    value[0,0,0,0]=-1
    label=zeros([1,10])
    label[0,0]=-1

    for filename in filelist:
        img=array(Image.open(filename).convert("L"))
        width,height=shape(img);
        index=0
        tmp_value=zeros([1,width,height,1])
        for i in range(width):
            for j in range(height):
                tmp_value[0,i,j,0]=img[i,j]
                index+=1

        if(value[0,0,0,0]==-1):
            value=tmp_value
        else:
            value=concatenate((value,tmp_value))

        tmp_label=zeros([1,10])
        index=int(filename.strip().split('/')[2][0])
        print "input:",index
        tmp_label[0,index]=1
        if(label[0,0]==-1):
            label=tmp_label
        else:
            label=concatenate((label,tmp_label))

    return array(value),array(label)

这里读取图片依赖 PIL 这个库,由于 PIL 比较少维护了,可以用它的一个分支 Pillow 来代替。另外依赖 numpy 这个科学计算库,没装的要装一下。

这里就是把图片读取,并按 mnist 格式化,label 是取图片文件名的第一个字,所以图片要用数字开头命名。

如果懒得 PS 画或者手写的画,可以把测试数据集的数据给转回图片,实测成功,参考这篇文章:如何用python解析mnist图片

新建一个文件夹叫 test_num,里面图片如下,这里命名就是 label 值,可以看到 label 和图片是对应的:

 

开始测试:

  print("Start Test Images")

  dir_name = "./test_num"
  files = glob2.glob(dir_name + "/*.png")
  cnt = len(files)
  for i in range(cnt):
    print(files[i])
    test_img, test_label = GetImage([files[i]])

    testDataSet = DataSet(test_img, test_label, dtype=tf.float32)

    res = accuracy.eval({x: testDataSet.images, y_: testDataSet.labels})

    print("output: ",  res)
    print("----------"

这里用了 glob2 这个库来遍历以及过滤文件,需要安装,常规的遍历会把 Mac 上的 .DS_Store 文件也会遍历进去。

 

可以看到我们打的 label 和模型算出来的是相符的。

然后我们可以打乱文件名,把 9 说成 8,把 0 也说成 8:

 

可以看到,我们的 label 和模型算出来的是不相符的。

 

恭喜,到着你就完成了一次简单的人工智能之旅。

总结

从这个例子中我们可以大致知道 TensorFlow 的运行模式:

 

例子中是每次都要走一遍训练流程,实际上是可以用 tf.train.Saver() 来保存训练好的模型的。这个入门例子完成之后能对 TensorFlow 有个感性认识。

TensorFlow 没有那么神秘,没有我们想的那么复杂,也没有我们想的那么简单,并且还有很多数学知识要补充呢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1332706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文解读】3D视觉标定的显式文本解耦和密集对齐(CVPR 2023)

来源:投稿 作者:橡皮 编辑:学姐 论文链接:https://arxiv.org/abs/2209.14941 开源代码:https://github.com/yanmin-wu/EDA 图1所示。文本解耦,密集对齐的3D视觉标定。文本中的不同颜色对应不同的解耦分量。…

Pinely Round 3 (Div. 1 + Div. 2) E. Multiple Lamps(思维题 暴力 补写法)

题目 n(n<2e5)盏灯&#xff0c;编号1到n&#xff0c;一开始全是熄灭的 n个开关&#xff0c;第i个控制着所有i的倍数&#xff0c;按下i的时候&#xff0c;i、2i、...这些灯的状态会被翻转&#xff0c; m(m<2e5)个限制&#xff0c;第j条限制形如uj,vj&#xff0c;表示uj…

智能优化算法应用:基于向量加权平均算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于向量加权平均算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于向量加权平均算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.向量加权平均算法4.实验参数设定…

springboot整合JPA 多表关联 :一对多 多对多

补充一下自定义SQL 这是连表查询&#xff0c;可以任意查出字符&#xff0c;用Map接收 Testvoid test3() {JPAQueryFactory jpaQueryFactory new JPAQueryFactory(em);QStudent student QStudent.student;QMessage message QMessage.message;//constructor(StuMesDto.class, …

【pynput】鼠标行为追踪并模拟

文章目录 前言基本思路安装依赖包实时鼠标捕获捕获鼠标位置捕获鼠标事件记录点击内容 效果图 利用本文内容从事的任何犯法行为和开发与本人无关&#xff0c;请理性利用技术服务大家&#xff0c;创建美好和谐的社会&#xff0c;让人们生活从繁琐中变得更加具有创造性&#xff01…

7. 结构型模式 - 代理模式

亦称&#xff1a; Proxy 意图 代理模式是一种结构型设计模式&#xff0c; 让你能够提供对象的替代品或其占位符。 代理控制着对于原对象的访问&#xff0c; 并允许在将请求提交给对象前后进行一些处理。 问题 为什么要控制对于某个对象的访问呢&#xff1f; 举个例子&#xff…

Oracle研学-查询

学自B站黑马程序员 1.单表查询 //查询水表编号为 30408 的业主记录 select * from T_OWNERS where watermeter30408 //查询业主名称包含“刘”的业主记录 select * from t_owners where name like %刘% //查询业主名称包含“刘”的并且门牌号包含 5 的业主记录 select * from…

Redis-运维

转自 极客时间 Redis 亚风 原文视频&#xff1a;https://u.geekbang.org/lesson/535?article681062 Redis 同步 Redis主从数据同步,主从第⼀次同步是全量同步 replicaof 主机 端口 #当前这个机器做Master的备份master如何判断slave是不是第⼀次来同步数据&#xff1a; Repl…

linux循环调度执行

9.2 循环调度执行 9.2.1 简介 cron的概念和crontab是不可分割的。 ​ crontab是一个命令&#xff0c;常见于Unix和Linux的操作系统之中用于设置周期性被执行的指令。 ​ 该命令从标准输入设备读取指令&#xff0c;并将其存放于“crontab”文件中&#xff0c;以供之后读取和执…

基于AT89C51单片机的8位密码锁仿真与实物制作

点击链接获取Keil源码与Project Backups仿真图&#xff1a; https://download.csdn.net/download/qq_64505944/88657969?spm1001.2014.3001.5503 源码获取 C 源码仿真图毕业设计实物制作步骤01 摘要 在日常的生活和工作中, 住宅与部门的安全防范、单位的文件档案、财务报表…

在使用 npm install的时候提示node-sass command faile 解决方案

在使用npm install的时候错误提示node-sass 相关的。错误信息如下图&#xff1a; 解决方法&#xff08;PS&#xff1a;凯哥的不适用&#xff09; 出现这种问题基本是由于node版本与sass版本不匹配导致的 方案1&#xff1a;卸载node&#xff0c;安装对应版本 方案2&#xff1…

关键字:protected关键字

在 Java 中&#xff0c;protected 是一个访问修饰符&#xff0c;用于修饰类成员&#xff08;成员变量、成员方法和构造方法&#xff09;。当一个类成员被声明为 protected 时&#xff0c;它可以在同一包中的其他类以及子类中被访问。 以下是 protected 关键字的解析&#xff1a…

5. 创建型模式 - 单例模式

亦称&#xff1a; 单件模式、Singleton 意图 单例模式是一种创建型设计模式&#xff0c; 让你能够保证一个类只有一个实例&#xff0c; 并提供一个访问该实例的全局节点。 问题 单例模式同时解决了两个问题&#xff0c; 所以违反了单一职责原则&#xff1a; 保证一个类只有一…

隧道裂缝检测_2【C++PCL】

作者:迅卓科技 简介:本人从事过多项点云项目,并且负责的项目均已得到好评! 公众号:迅卓科技,一个可以让您可以学习点云的好地方 1.前言 我们团队注重每一个细节,确保代码的可读性、可维护性和可扩展性达到最高标准。我们严格遵循行业最佳实践,采用模块化和面向对象的设…

k8s集群部署成功后某个节点突然出现notready状态的问题原因分析和解决办法

文章目录 1、问题描述2、查看node03的日志3、错误原因分析4、解决办法 1、问题描述 k8s集群配置为 一主三个节点&#xff1b;刚开始运行一直正常&#xff1b;某天突然node03主机状态变为notready&#xff0c;问题如下&#xff1a; 在master节点使用&#xff1a; #master节点…

Vue3学习(后端开发)

目录 一、安装Node.js 二、创建Vue3工程 三、用VSCode打开 四、源代码目录src 五、入门案例——手写src 六、测试案例 七、ref和reactive的区别 一、安装Node.js 下载20.10.0 LTS版本 https://nodejs.org/en 使用node命令检验安装是否成功 node 二、创建Vue3工程 在…

OLED显示原理7T1C基础分析(PWM与DC调光)

文章目录 一、7T1C设计要点分析1、先回顾一下上篇 发光过程三个阶段---复位、补偿、发光2、设计关键点一&#xff1a;复位、补偿、发光三阶段 控制信号严格分离3、基本亮度控制策略---DC调光 && PWM调光4、PWM调光频率 之 低频PWM/高频PWM---EM信号的控制细节5、功耗优…

蓝桥小课堂-平方和【算法赛】

问题描述 蓝桥小课堂开课啦&#xff01; 平方和公式是一种用于计算连续整数的平方和的数学公式。它可以帮助我们快速求解从 1 到 n 的整数的平方和&#xff0c;其中 n 是一个正整数。 平方和公式的表达式如下&#xff1a; 这个公式可以简化计算过程&#xff0c;避免逐个计算…

HarmonyOs4.0基础(一)

目录 一、HarmonyOs系统定义 1.1系统的技术特性(三大特征) 1.1.1、硬件互助、资源共享 1.1.2、一次开发、多端部署(面向开发者) 1.1.3、统一OS&#xff0c;弹性部署(支持多种API&#xff1a;ArkTs、JS、C/C、Java) 1.2、系统的技术架构 二、Harmony OS项目搭建 2.1、(D…

Elasticsearch的分片平衡问题解决

2023年11月份在某电商系统生产中的Elasticsearch&#xff08;以下简称ES&#xff09;集群突然&#xff0c;出现了大量慢查询告警&#xff0c;导致请求堆积。经过几天的排查发现了ES节点主分片和副本分片分布存在不均匀的问题。当然了暂未有定论是由于分片不均衡导致了性能下降&…