政安晨的机器学习笔记——跟着演练快速理解TensorFlow(适合新手入门)

news2024/11/15 18:14:19

准备工作        

本笔记是假设您已经安装了Windows系统或Ubuntu系统的Anaconda(或 Miniconda)、Jupyter Notebook、TensorFLow,稍微了解Python语言,并可以进行一点点操作的基础上进行的。

        如果您还不具备这个条件,去看我的政安晨笔记里关于准备工作的文章:

基于Anaconda安装TensorFlow并尝试一个神经网络小实例icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/135841281示例讲解机器学习工具Jupyter Notebook入门(超级详细)icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/135880886实例讲解深度学习工具PyTorch在Ubuntu系统上的安装入门(基于Miniconda)(非常详细)icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/135887509

基于Ubuntu系统的Miniconda安装Jupyter Notebookicon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/135919533

基于Ubuntu系统的Miniconda安装TensorFlow并使用Jupyter Notebook在多个Conda虚拟环境下管理测试icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/135905122

当您准备好了之后,让咱们开始接下来有趣的旅程。

一、走一个机器学习里的“Hello World”

打开Jupyter Notebook,新建一个笔记文件:tf-hellloworld。

虽然这个程序我在我另外的文章中示例过了,但它比较典型,可以着重使用一下

将下述代码复制到笔记本的cell单元格中:

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

如下图所示:

完成训练后,您也完成了这个Hello World,同时也证明您电脑中的TensorFlow是可用的,咱们就可以往下继续了。

二、再做一个简单的机器学习示例并说明一下

咱们新建一个文件(ExampleofKeras):

在笔记的单元格中输入如下代码:

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

predictions = model(x_train[:1]).numpy()
predictions

tf.nn.softmax(predictions).numpy()

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

loss_fn(y_train[:1], predictions).numpy()

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test,  y_test, verbose=2)

probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

probability_model(x_test[:5])

保存并执行一下:

上面这个程序做了如下这么几件事:

  1. 加载一个预构建的数据集。
  2. 构建对图像进行分类的神经网络机器学习模型。
  3. 训练此神经网络。
  4. 评估模型的准确率。

接下来分步骤分析一下:

第一步:设置 TensorFlow

首先将 TensorFlow 导入到您的程序:

import tensorflow as tf

第二步: 加载数据集

加载并准备MNIST数据集。将样本数据从整数转换为浮点数:

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

第三步:构建机器学习模型

通过堆叠层来构建tf.keras.Sequential模型。

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

对于每个样本,模型都会返回一个包含logits或log-odds分数的向量,每个类一个。

predictions = model(x_train[:1]).numpy()
predictions

logits:

分类模型生成的原始(非规范化)预测向量,通常会将其传递给规范化函数。如果模型解决的是多类分类问题,对数通常会成为 softmax 函数的输入。然后,softmax 函数会生成一个(归一化)概率向量,每个可能的类别都有一个值。

log-odds:

某些事件发生的几率的对数。

tf.nn.softmax函数将这些 logits 转换为每个类的概率

tf.nn.softmax(predictions).numpy()

虽然还可以将tf.nn.softmax烘焙到网络最后一层的激活函数中。虽然这可以使模型输出更易解释,但不建议使用这种方式,因为在使用 softmax 输出时不可能为所有模型提供精确且数值稳定的损失计算。

使用SparseCategoricalCrossentropy为训练定义损失函数,它会接受 logits 向量和 True 索引,并为每个样本返回一个标量损失。

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

此损失等于 true 类的负对数概率:如果模型确定类正确,则损失为零。

这个未经训练的模型给出的概率接近随机(每个类为 1/10),因此初始损失应该接近 -tf.math.log(1/10) ~= 2.3

loss_fn(y_train[:1], predictions).numpy()

在开始训练之前,使用 Keras model.compile 配置和编译模型。将optimizer类设置为 adam,将 loss 设置为您之前定义的 loss_fn 函数,并通过将 metrics 参数设置为 accuracy 来指定要为模型评估的指标。

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

第四步:训练并评估模型

使用model.fit方法调整您的模型参数并最小化损失

model.fit(x_train, y_train, epochs=5)

 方法通常在 "Validation-set" 或 "Test-set" 上检查模型性能。

model.evaluate(x_test,  y_test, verbose=2)

Validation-set 说明

针对训练好的模型进行初始评估的数据集子集。通常情况下,在根据测试集评估模型之前,要根据验证集多次评估训练有素的模型。
传统上,数据集中的示例分为以下三个不同的子集:
训练集
验证集
测试集
理想情况下,数据集中的每个示例应只属于前面的一个子集。例如,一个例子不应同时属于训练集和验证集。

Test-set 说明

数据集的子集,用于测试训练有素的模型。
传统上,我们将数据集中的示例分为以下三个不同的子集:
训练集
验证集
测试集
数据集中的每个示例只能属于前面的一个子集。例如,一个例子不应同时属于训练集和测试集。
训练集和验证集都与模型的训练密切相关。由于测试集仅与训练间接相关,因此测试损失是一个比训练损失或验证损失更少偏差、更高质量的指标。

现在,这个照片分类器的准确度已经达到 98%。

如果您想让模型返回概率,可以封装经过训练的模型,并将 softmax 附加到该模型:

probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])
probability_model(x_test[:5])

结论

恭喜小伙伴!您已经利用Keras API 基于预构建数据集成功训练了一个机器学习模型。这是全新的开始,是您在机器学习领域深入探索的第一步。这是全新的领域,充满挑战,但同时,也充满着机会。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1424036.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android进阶之路 - ViewPager2 比 ViewPager 强在哪?

我记得前年(2022)面试的时候有被问到 ViewPager 和 ViewPager2 有什么区别?当时因为之前工作一直在开发售货机相关的项目,使用的技术要求并不高,所以一直没去了解过 ViewPager2~ 去年的时候正好有相关的功能需求&#…

安卓线性布局LinearLayout

<?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:tools"http://schemas.android.com/tools"android:layout_width"match_parent"android:…

关于谷歌新版调试用具(Chrome Dev Tool ),网络选项(chrome-network)默认开启下拉模式的设置

今天在使用谷歌浏览器进行调试的时候&#xff0c;打开调试工具网络选项发现过滤不同模式的选项卡不见了&#xff0c;转而变成一个下拉式选项&#xff0c;如下图 这样一来使得切换不同类型查看的时候变得非常不方便&#xff0c;然后网上查了一下发现这个功能谷歌在很早版本就已…

如果我要访问一个网址,那么在网络中会有哪些过程

访问一个网址是我们日常网络使用中非常常见的操作&#xff0c;背后涉及到一系列精密而复杂的步骤。这个过程包括DNS解析、建立TCP连接、发起HTTP请求、服务器处理请求、服务器响应、浏览器渲染等环节。在这篇文章中&#xff0c;我们将深入探讨这些步骤&#xff0c;并解释它们在…

Windows Server 2025 Azure Arc 介绍

Azure Arc 是一个扩展 Azure 平台的桥梁&#xff0c;可帮助你构建可灵活地跨数据中心、边缘和多云环境运行的应用程序和服务。使用一致的开发、操作和安全模型来开发云原生应用程序。 Azure Arc 可在新的和现有的硬件、虚拟化和 Kubernetes 平台、物联网设备和集成系统上运行。…

web应用课——(第四讲:中期项目——拳皇)

代码AC Git地址&#xff1a;拳皇——AC Git链接

DataX介绍

一、介绍 DataX 是一个异构数据源离线同步工具&#xff0c;致力于实现包括关系型数据库(MySQL、Oracle等)、HDFS、Hive、ODPS、HBase、FTP等各种异构数据源之间稳定高效的数据同步功能。 github地址 详细文档 操作手册 支持数据框架如下&#xff1a; 架构 Reader&#xff1…

LLM 推理优化探微 (1) :Transformer 解码器的推理过程详解

编者按&#xff1a;随着 LLM 赋能越来越多需要实时决策和响应的应用场景&#xff0c;以及用户体验不佳、成本过高、资源受限等问题的出现&#xff0c;大模型高效推理已成为一个重要的研究课题。为此&#xff0c;Baihai IDP 推出 Pierre Lienhart 的系列文章&#xff0c;从多个维…

2024年小年是哪一天?小年习俗记到手机便签

随着春节的临近&#xff0c;我们即将迎来一个重要的传统节日——“小年”。那么2024年小年是哪一天呢&#xff1f;关于2024年小年的具体日期&#xff0c;地域不同&#xff0c;节日时间有所不同。在北方&#xff0c;小年通常是在腊月二十三&#xff0c;即2月2日&#xff1b;而在…

locust--python实现的分布式性能测试工具

1.locust特点&#xff1a; 1.1 支持Python编写测试用例方案&#xff1b; 1.2 使用requests发送http请求&#xff1b; 1.3 使用协程实现&#xff0c;高并发时消耗更低&#xff1b; 1.4 使用Flask提供 Web UI&#xff1b; 1.5 有第三方插件支持扩展&#xff1b; 2.创建locust 性能…

【MySQL】学习并使用聚合函数和DQL进行分组查询

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-t8K8tl6eNwqdFmcD {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

canvas自定义扩展方法:文字自动换行

查看专栏目录 canvas实例应用100专栏&#xff0c;提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重…

15. 三数之和(力扣LeetCode)

文章目录 15. 三数之和题目描述双指针去重逻辑的思考a的去重b与c的去重 15. 三数之和 题目描述 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请 …

hbuilderx uniapp运行到真机控制台显示手机端调试基座版本号1.0.0,调用uni.share提示打包时未添加share模块

记录一个困扰了几天的一个蠢问题&#xff0c;发现真相的我又气又笑。 由于刚开始接触uniapp 移动端开发&#xff0c;有个需求需要使用uni.share API&#xff0c;但是我运行项目老提示打包时没配置share模块 我确实没在manifest内配置。网上搜了一些资料&#xff0c;但是我看官…

MySQL判断两个时间段是否重合

前提 新增的数据不能和数据库的时间有重合部分。 如图&#xff0c;4种重合情况和2种不重合情况。 时间段 a&#xff0c;b 数据库字段 start_time&#xff0c;end_time 第一种写法 列举每一种重合的情况&#xff1a; SELECT * FROM table WHERE(start_time > a and en…

大数据开发之离线数仓项目(用户行为采集平台)(可面试使用)

第 1 章&#xff1a;数据仓库概念 数据仓库&#xff0c;是为企业指定决策&#xff0c;提供数据支持的&#xff0c;可以帮助企业&#xff0c;改进业务流程、提高产品质量等。 数据仓库的输入数据通常包括&#xff1a;业务数据、用户行为数据和爬虫数据等。 业务数据&#xff1a…

写静态页面——粘性定位练习

0、效果&#xff1a; 1、HTML代码&#xff1a;为了简洁采用内部样式 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"…

企业网络基础架构监控工具

IT 基础架构已成为提供基本业务服务的基石&#xff0c;无论是内部管理操作还是为客户托管的应用程序服务&#xff0c;监控 IT 基础设施至关重要&#xff0c;并且已经建立起来&#xff0c;SMB IT 基础架构需要简单的网络监控工具来监控性能和报告问题。通常&#xff0c;几个 IT …

【HTML】自定义属性(data)

自定义属性 data: 的用法&#xff08;如何设置,如何获取) &#xff0c;有何优势&#xff1f; data-* 的值的获取和设置&#xff0c;2种方法: 传统方法 getAttribute() 获取 data- 属性值; setAttribute() 设置 data- 属性值getAttribute() 获取 data- 属性值; setAttribute()…

强大的虚拟机Parallels Desktop 19 mac中文激活

Parallels Desktop是一款功能全面、易于使用的虚拟机软件&#xff0c;它为用户提供了在Mac电脑上同时运行多个操作系统的便利。 软件下载&#xff1a;Parallels Desktop 19 mac中文激活版下载 Parallels Desktop 19 mac具有快速启动和关闭虚拟机的能力&#xff0c;让用户能够迅…