【Python深度学习(第二版)(3)】初识神经网络之深度学习hello world

news2024/11/28 22:33:37

文章目录

  • 一. 训练Keras中的MNIST数据集
  • 二. 工作流程
    • 1. 构建神经网络
    • 2. 准备图像数据
    • 3. 训练模型
    • 4. 利用模型进行预测
    • 5. (新数据上)评估模型精度

本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了解。

本节相关概念:

  • 样本
  • 标签
  • 层(layer)
  • 数据蒸馏
  • 密集连接
  • 10路softmax分类层
  • 编译(compilation)步骤的3个参数
  • 损失值、精度
  • 过拟合

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。

在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。你可以将“解决”MNIST问题看作深度学习的“Hello World”,用来验证你的算法正在按预期运行。下图给出了MNIST数据集的一些样本。

在这里插入图片描述

说明

在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample)与某个样本对应的类叫作标签(label)(即描述:样本属于哪个类别)。

 

你不需要现在就尝试在计算机上运行这个例子。之后的文章会具体分析。

 

一. 训练Keras中的MNIST数据集

from tensorflow.keras.datasets import mnist 

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。

图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。

 
看一下训练数据:

>>> train_images.shape 
 (60000, 28, 28) 
>>> len(train_labels) 
60000 
>>> train_labels 
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

 

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28) 
>>> len(test_labels) 
10000 
>>> test_labels 
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

 

二. 工作流程

我们的工作流程如下:
首先,将训练数据(train_images和train_labels)输入神经网络;
然后,神经网络学习将图像和标签关联在一起;
最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
 


具体的代码我们可以在 deep-learning-with-python-notebooks 中直接运行。


1. 构建神经网络

下面我们来构建神经网络,如下:

from tensorflow import keras 
from tensorflow.keras import layers 

model = keras.Sequential(
						 [ layers.Dense(512, activation="relu"), 
						 layers.Dense(10, activation="softmax") ])

 

神经网络的核心组件是层(layer)

具体来说,层从输入数据中提取表示。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)。
 

本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。

第2层是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中某一个的概率。
 

在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数

  • 优化器(optimizer):模型基于训练数据来自我更新的机制,其目的是提高模型性能。
  • 损失函数(loss function):模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度(accuracy),即正确分类的图像所占比例。后面两章会详细介绍损失函数和优化器的确切用途。

如下代码展示了编译步骤。


model.compile(
			  optimizer="rmsprop", 
			  loss="sparse_categorical_crossentropy", 
			  metrics=["accuracy"]
			  )

 

2. 准备图像数据

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0,255]。我们将把它变换为一个float32数组,其形状为(60000, 28 *28),取值范围是[0, 1]。

下面准备图像数据,如代码所示。

train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype("float32") / 255 

test_images = test_images.reshape((10000, 28 * 28)) 
test_images = test_images.astype("float32") / 255

 

3. 训练模型

在Keras中,通过调用模型的fit方法调用数据,训练模型。

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128) 
Epoch 1/5 
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273 Epoch 2/5 
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss),另一个是模型在训练数据上的精度(acc)。我们很快就在训练数据上达到了0.989(98.9%)的精度。

现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(如下代码)。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

 

4. 利用模型进行预测

>>> test_digits = test_images[0:10] 
>>> predictions = model.predict(test_digits) 
>>> predictions[0] 
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06, 2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01, 2.6657624e-08, 3.8127661e-07], dtype=float32)

如上代码我们对11个test_images图片进行预测,是什么数字,我们拿到第一个图片预测的概率数组,其中索引为7时,概率最大(0.99999106,几乎等于1),所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax() 
7 
>>> predictions[0][7] 
0.99999106

这里我们检查测试标签是否与之一致:

>>> test_labels[0] 

7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如下代码所示。

 

5. (新数据上)评估模型精度

>>> test_loss, test_acc = model.evaluate(test_images, test_labels) 
>>> print(f"test_acc: {test_acc}") 
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。

过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差。

第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。

 

之后的文章我们将详细描述每一个步骤的原理,并且将学到张量(输入模型的数据存储对象)、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)。
 

参考:
《Python深度学习(第二版)》–弗朗索瓦·肖莱
https://www.redhat.com/zh/topics/digital-transformation/what-is-deep-learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1659090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springcloud整合网关(springcloud-gateway)

pom引入依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId></dependency><!-- 服务注册 --><dependency><groupId>com.alibaba.cloud</groupId&…

ALV 排序、汇总

目录 前言 实战 汇总 分类汇总 排序 分类汇总分隔方式&#xff08;仅适用于LIST ALV&#xff09; 完整代码&#xff1a; 前言 在SAP ABAP ALV中&#xff0c;排序和汇总是两个关键特性&#xff0c;用于组织和分析数据显示。 排序 排序功能允许用户根据一个或多个…

Redis是单线程吗?为什么6.0之后引入了多线程?

Redis是单线程吗&#xff1f;为什么6.0之后引入了多线程&#xff1f; Redis 是单线程吗&#xff1f;Redis 单线程模式是怎样的&#xff1f;Redis 采用单线程为什么还这么快&#xff1f;Redis 6.0 之前为什么使用单线程&#xff1f;Redis 6.0 之后为什么引入了多线程&#xff1f…

从Apache HttpClient类库,说一说springboot应用程序中的AutoConfiguration的封装

一、背景 在使用httpclient框架请求http接口的时候&#xff0c;我们往往会需要自定义配置httpclient&#xff0c;而非直接使用。 <dependency><groupId>org.apache.httpcomponents</groupId><artifactId>httpclient</artifactId><version>…

Language2Pose: Natural Language Grounded Pose Forecasting # 论文阅读

URL https://arxiv.org/pdf/1907.01108 TD;DR 19 年 7 月 cmu 的文章&#xff0c;提出一种基于 natural language 生成 3D 动作序列的方法。通过一个简单的 CNN 模型应该就可以实现 Model & Method 首先定义一下任务&#xff1a; 输入&#xff1a;用户的自然语言&…

链表第4/9题--翻转链表--双指针法

LeetCode206&#xff1a;给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例…

科沃斯梦碎“扫地茅”,钱东奇跌落“风口”

昔日“扫地茅“不香了&#xff0c;科沃斯跌落神坛。 4月27日&#xff0c;科沃斯发布2023年报显示&#xff1a;2023年&#xff0c;科沃斯的营收为155.02亿元&#xff0c;同比增加1.16%&#xff1b;同期&#xff0c;净利为6.10亿元&#xff0c;同比减少63.96%。科沃斯的经营业绩…

HR招聘面试测评,如何判断候选人的创新能力?

创新能力代表着一个人的未来发展潜力&#xff0c;创新能力突出的人&#xff0c;未来的上限就可能更高。而对于一个公司而言&#xff0c;一个具有创新能力的员工&#xff0c;会给公司带来新方案&#xff0c;新思路&#xff0c;对公司的长远发展拥有着十分积极的作用。 而在挑选…

【荣耀笔试题汇总】2024-05-09-荣耀春招笔试题-三语言题解(CPP/Python/Java)

&#x1f36d; 大家好这里是清隆学长 &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新荣耀近期的春秋招笔试题汇总&#xff5e; &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f49…

解析Spring中的循环依赖问题:初探三级缓存

什么是循环依赖&#xff1f; 这个情况很简单&#xff0c;即A对象依赖B对象&#xff0c;同时B对象也依赖A对象&#xff0c;让我们来简单看一下。 // A依赖了B class A{public B b; }// B依赖了A class B{public A a; }这种循环依赖可能会引发问题吗&#xff1f; 在没有考虑Sp…

信息系统项目管理师0097:价值交付系统(6项目管理概论—6.4价值驱动的项目管理知识体系—6.4.6价值交付系统)

点击查看专栏目录 文章目录 6.4.6价值交付系统1.创造价值2.价值交付组件3.信息流6.4.6价值交付系统 价值交付系统描述了项目如何在系统内运作,为组织及其干系人创造价值。价值交付系统包括项目如何创造价值、价值交付组件和信息流。 1.创造价值 项目存在于组织中,包括政府机构…

kkkkkkkkkkkk564

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起探讨和分享Linux C/C/Python/Shell编程、机器人技术、机器学习、机器视觉、嵌入式AI相关领域的知识和技术。 人工智能与机器学习 &#x1f4dd;人工智能相关概念☞什么是人工智能、机器学习、深度学习☞人工智能发…

【LeetCode:LCR 077. 排序链表 + 链表】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

发表博客之:gemm/threadblock/threadblock_swizzle.h 文件夹讲解,cutlass深入讲解

文章目录 [发表博客之&#xff1a;gemm/threadblock/threadblock_swizzle.h 文件夹讲解&#xff0c;cutlass深入讲解](https://cyj666.blog.csdn.net/article/details/138514145)先来看一下最简单的struct GemmIdentityThreadblockSwizzle结构体 发表博客之&#xff1a;gemm/th…

STM32G030C8T6:EEPROM读写实验(I2C通信)

本专栏记录STM32开发各个功能的详细过程&#xff0c;方便自己后续查看&#xff0c;当然也供正在入门STM32单片机的兄弟们参考&#xff1b; 本小节的目标是&#xff0c;系统主频64 MHZ,采用高速外部晶振&#xff0c;实现PB11,PB10 引脚模拟I2C 时序&#xff0c;对M24C08 的EEPRO…

就业班 第三阶段(zabbix) 2401--5.9 day1 普通集zabbix 5.0部署 nginx部署+agent部署

文章目录 环境一、zabbix 5.0 部署1、安装yum源2、安装相关软件3、数据库安装和配置mariaDB数据库mysql57数据库 安装mysql万能卸载mysql代码&#xff1a;启动mysql并初始化4、数据表导入5、修改配置&#xff0c;启动服务6、配置 web GUI7、浏览器访问注意数据加密的选项不要勾…

基于Springboot的滴答拍摄影

基于SpringbootVue的滴答拍摄影设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录 首页 摄影作品 摄影服务 摄影论坛 后台登录 后台首页 用户管理 摄影师管理 摄影作…

谷歌继续将生成式人工智能融入网络安全

谷歌正在将多个威胁情报流与 Gemini 生成人工智能模型相结合&#xff0c;以创建新的云服务。 Google 威胁情报服务旨在帮助安全团队快速准确地整理大量数据&#xff0c;以便更好地保护组织免受网络攻击。 本周在旧金山举行的 RSA 会议上推出的 Google 威胁情报服务吸收了 Mand…

旅游组团奖励标准,申报条件!利川市旅游组团奖励办法

利川市旅游组团奖励有哪些&#xff1f;关于利川市旅游组团奖励标准&#xff0c;申报条件整理如下&#xff1a; 第一条根据《湖北省人民政府办公厅印发关于更好服务市场主体推动经济稳健发展若干政策措施的通知》&#xff08;鄂政办发〔2022〕54号&#xff09;、《恩施州人民政府…

力扣2105---给植物浇水II(Java、模拟、双指针)

题目描述&#xff1a; Alice 和 Bob 打算给花园里的 n 株植物浇水。植物排成一行&#xff0c;从左到右进行标记&#xff0c;编号从 0 到 n - 1 。其中&#xff0c;第 i 株植物的位置是 x i 。 每一株植物都需要浇特定量的水。Alice 和 Bob 每人有一个水罐&#xff0c;最初是…