【深度学习基础(3)】初识神经网络之深度学习hello world

news2024/11/18 8:36:42

文章目录

  • 一. 训练Keras中的MNIST数据集
  • 二. 工作流程
    • 1. 构建神经网络
    • 2. 准备图像数据
    • 3. 训练模型
    • 4. 利用模型进行预测
    • 5. (新数据上)评估模型精度

本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了解。

本节相关概念:

  • 样本
  • 标签
  • 层(layer)
  • 数据蒸馏
  • 密集连接
  • 10路softmax分类层
  • 编译(compilation)步骤的3个参数
  • 损失值、精度
  • 过拟合

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。

在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。你可以将“解决”MNIST问题看作深度学习的“Hello World”,用来验证你的算法正在按预期运行。下图给出了MNIST数据集的一些样本。

在这里插入图片描述

说明

在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample)与某个样本对应的类叫作标签(label)(即描述:样本属于哪个类别)。

 

你不需要现在就尝试在计算机上运行这个例子。之后的文章会具体分析。

 

一. 训练Keras中的MNIST数据集

from tensorflow.keras.datasets import mnist 

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。

图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。

 
看一下训练数据:

>>> train_images.shape 
 (60000, 28, 28) 
>>> len(train_labels) 
60000 
>>> train_labels 
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

 

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28) 
>>> len(test_labels) 
10000 
>>> test_labels 
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

 

二. 工作流程

我们的工作流程如下:
首先,将训练数据(train_images和train_labels)输入神经网络;
然后,神经网络学习将图像和标签关联在一起;
最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
 

1. 构建神经网络

下面我们来构建神经网络,如下:

from tensorflow import keras 
from tensorflow.keras import layers 

model = keras.Sequential(
						 [ layers.Dense(512, activation="relu"), 
						 layers.Dense(10, activation="softmax") ])

 

神经网络的核心组件是层(layer)

具体来说,层从输入数据中提取表示。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)。
 

本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。

第2层是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中某一个的概率。
 

在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数

  • 优化器(optimizer):模型基于训练数据来自我更新的机制,其目的是提高模型性能。
  • 损失函数(loss function):模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度(accuracy),即正确分类的图像所占比例。后面两章会详细介绍损失函数和优化器的确切用途。

如下代码展示了编译步骤。


model.compile(
			  optimizer="rmsprop", 
			  loss="sparse_categorical_crossentropy", 
			  metrics=["accuracy"]
			  )

 

2. 准备图像数据

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0,255]。我们将把它变换为一个float32数组,其形状为(60000, 28 *28),取值范围是[0, 1]。

下面准备图像数据,如代码所示。

train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype("float32") / 255 

test_images = test_images.reshape((10000, 28 * 28)) 
test_images = test_images.astype("float32") / 255

 

3. 训练模型

在Keras中,通过调用模型的fit方法调用数据,训练模型。

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128) 
Epoch 1/5 
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273 Epoch 2/5 
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss),另一个是模型在训练数据上的精度(acc)。我们很快就在训练数据上达到了0.989(98.9%)的精度。

现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(如下代码)。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

 

4. 利用模型进行预测

>>> test_digits = test_images[0:10] 
>>> predictions = model.predict(test_digits) 
>>> predictions[0] 
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06, 2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01, 2.6657624e-08, 3.8127661e-07], dtype=float32)

如上代码我们对11个test_images图片进行预测,是什么数字,我们拿到第一个图片预测的概率数组,其中索引为7时,概率最大(0.99999106,几乎等于1),所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax() 
7 
>>> predictions[0][7] 
0.99999106

这里我们检查测试标签是否与之一致:

>>> test_labels[0] 

7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如下代码所示。

 

5. (新数据上)评估模型精度

>>> test_loss, test_acc = model.evaluate(test_images, test_labels) 
>>> print(f"test_acc: {test_acc}") 
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。

过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差。

第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。

 

之后的文章我们将详细描述每一个步骤的原理,并且将学到张量(输入模型的数据存储对象)、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1642632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智慧文旅展现文化新风貌,科技助力旅行品质升级:借助智慧技术,文旅产业焕发新生机,为旅行者带来更高品质的文化体验之旅

一、引言 在数字化、智能化的浪潮下,文旅产业正迎来前所未有的发展机遇。智慧文旅作为文旅产业与信息技术深度融合的产物,不仅为旅行者带来了全新的文化体验,也为文旅产业注入了新的活力。本文旨在探讨智慧文旅如何借助智慧技术展现文化新风…

R语言中,查看经安装的包,查看已经加载的包,查看特定包是否已经安装,安装包,更新包,卸载包

创建于:2024.5.4 R语言中,查看经安装的包,查看已经加载的包,查看特定包是否已经安装,安装包,更新包,卸载包 文章目录 1. 查看经安装的包2. 查看已经加载的包3. 查看特定包是否已经安装4. 安装包…

将要上市的自动驾驶新书《自动驾驶系统开发》中摘录各章片段 1

以下摘录一些章节片段: 1. 概论 自动驾驶系统的认知中有一些模糊的地方,比如自动驾驶系统如何定义的问题,自动驾驶的研发为什么会有那么多的子模块,怎么才算自动驾驶落地等等。本章想先给读者一个概括介绍,了解自动驾…

18 内核开发-内核重点数据结构学习

课程简介: Linux内核开发入门是一门旨在帮助学习者从最基本的知识开始学习Linux内核开发的入门课程。该课程旨在为对Linux内核开发感兴趣的初学者提供一个扎实的基础,让他们能够理解和参与到Linux内核的开发过程中。 课程特点: 1. 入门级别&…

Java高阶私房菜:JVM性能优化案例及讲解

目录 核心思想 优化思考方向 压测环境准备 堆大小配置调优 调优前 调优后 分析结论 垃圾收集器配置调优 调优前 调优后 分析结论 JVM性能优化是一项复杂且耗时的工作,该环节没办法一蹴而就,它需要耐心雕琢,逐步优化至理想状态。“…

【Gateway远程开发】0.5GB of free space is necessary to run the IDE.

【Gateway远程开发】0.5GB of free space is necessary to run the IDE. 报错 0.5GB of free space is necessary to run the IDE. Make sure that there’s enough space in following paths: /root/.cache/JetBrains /root/.config/JetBrains 原因 下面两个路径的空间不…

WPF之绑定验证(错误模板使用)

1,前言: 默认情况下,WPF XAML 中使用的绑定并未开启绑定验证,这样导致用户在UI上对绑定的属性进行赋值时即使因不符合规范内部已抛出异常(此情况仅限WPF中的数据绑定操作),也被程序默认忽略&…

Linux设置脚本任意位置执行

记得备份 !!!!!!!!!!!!!! 修改文件之后记得用 source 文件名 刷新 注意:刷新文件之后在当前窗口…

02.zabbix配置web界面

zabbix配置web界面 访问搭建好的地址: http://192.168.111.66/zabbix 检查配置都是正常,下一步 对应的信息,我设置的密码是:123456,下一步即可; 给服务器随意设置一个名字,下一步 检查数据…

022、Python+fastapi,第一个Python项目走向第22步:ubuntu 24.04 docker 安装mysql8集群、redis集群(三)

这次来安装mysql8了,以前安装不是docker安装,这个我也是第一次,人人都有第一次嚒 前言 前面的redis安装还是花了点时间的,主要是网上教程,各有各的好,大家千万别取其长处,个人觉得这个环境影响…

一、RocketMQ基本概述与部署

RocketMQ基本概述与安装 一、概述1.MQ概述1.1 用途1.2 常见MQ产品1.3 MQ常用的协议 2.RocketMQ概述2.1 发展历程 二、相关概念1.基本概念1.1 消息(Message)1.2 主题(Topic)1.3 标签(Tag)1.4 队列&#xff0…

stamps做sbas-insar,时序沉降图怎么画?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

【计算机网络】计算机网络的定义和分类

一.定义 计算机网络并没有一个精确和统一的定义,在计算机网络发展的不同阶段,人们对计算机网络给出了不同的定义,这些定义反映了当时计算机网络技术的发展水平。 例如计算机网络早期的一个最简单定义:计算机网络是一些互连的、自…

短视频素材去哪里找免费?短视频素材从哪儿下载?

在这个数字内容为王的时代,视频已经成为沟通信息和吸引观众的强大工具。无论是在市场营销、教育还是娱乐领域,高质量的视频素材都是制作引人注目内容的关键。以下列出的网站提供多样的视频素材,帮助您增强视觉叙述,并在竞争激烈的…

查找算法与排序算法

查找算法 二分查找 (要求熟练) // C// 二分查找法(递归实现) int binarySearch(int *nums, int target, int left, int right) // left代表左边界,right代表右边界 {if (left > right) return -1; // 如果左边大于右边,那么…

Docker部署nginx并且实现https访问

实验环境: 在已有的docker环境和nginx镜像的基础上进行操作 1、生成私钥 (1)openssl genrsa -out key.pem 2048 生成证书签名请求 (CSR) 并自签证书: (2)openssl req -new -x509 -key key.pem -out cert.pem -day…

Vitis HLS 学习笔记--HLS流水线基本用法

目录 1. 简介 2. 示例 2.1 对内层循环打拍 2.2 对外层循环打拍 2.3 优化数组访问后打拍 3. 总结 1. 简介 本文介绍pipeline的基本用法。pipeline是一种用于提高硬件设计性能的技术。本文介绍了pipeline在累加计算函数中的应用。通过优化内外层循环和数组访问&#xff0c…

C#中.net8WebApi加密解密

尤其在公网之中,数据的安全及其的重要,除过我们使用jwt之外,还可以对传送的数据进行加密,就算别人使用抓包工具,抓到数据,一时半会儿也解密不了数据,当然,加密也影响了效率&#xff…

【Linux】awk命令学习

最近用的比较多,学习总结一下。 文档地址:https://www.gnu.org/software/gawk/manual/gawk.html 一、awk介绍二、语句结构1.条件控制语句1)if2)for3)while4)break&continue&next&exit 2.比较运…

线性数据结构-手写链表-LinkList

为什么需要手写实现数据结构? 其实技术的本身就是基础的积累和搭建的过程,基础扎实 地基平稳 万丈高楼才会久战不衰,做技术能一通百,百通千就不怕有再难得技术了。 一:链表的分类 主要有单向,双向和循环链表…