使用TensorFlow训练深度学习模型实战(下)

news2024/11/29 10:38:36

大家好,本文接TensorFlow训练深度学习模型的上半部分继续进行讲述,下面将介绍有关定义深度学习模型、训练模型和评估模型的内容。

定义深度学习模型

数据准备完成后,下一步是使用TensorFlow搭建神经网络模型,搭建模型有两个选项:

可以使用各种层,包括Dense、Conv2D和LSTM,从头开始搭建模型。这些层定义了模型的架构及数据流经过它的方式,可基于TensorFlow Hub提供的预训练模型搭建模型。这些模型已经在大型数据集上进行了训练,并可以在特定数据集上进行微调,以达到在较短的训练时间内达到较高的准确度。

可以根据TensorFlow Hub中的预训练模型来建立模型。这些模型已经在大型数据集上进行了训练,并且可以在你的特定数据集上进行微调,以达到较少的训练时间,达到较高的准确性。

  • 从头开始定义深度学习模型

TensorFlow中的tf.keras.Sequential函数允许我们逐层定义神经网络模型,我们可以选择各种层,如Dense、Conv2D和LSTM,来搭建定制的模型架构。以下是示例: 

# 定义模型架构
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10)
])

在这个示例中,我们定义了一个模型,包含以下六个层(4个隐藏层):

  1. Conv2D层,具有32个过滤器,3x3的内核大小和ReLU激活。此层以形状为(28,28,1)的输入图像作为输入。

  2. MaxPooling2D层,具有默认的2x2池大小。此层对从上一层获得的特征映射进行下采样。

  3. Flatten层,将2D特征映射展平为1D向量。

  4. Dense层,具有128个神经元和ReLU激活。此层对展平的特征映射执行完全连接操作。

  5. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  6. Dense层,具有十个神经元,无激活函数。此层表示模型的输出层,神经元的数量对应于分类任务中的类别数目。

这个模型遵循典型的卷积神经网络架构,包括多个卷积层和池化层,以及一个或多个全连接层。

  • 从预训练模型定义深度学习模型 

利用TensorFlow Hub提供的预训练模型可能是一个不错的选择,因为它们已经在大量的数据集上进行了训练,可以帮助在减少训练时间的同时实现高准确度。在实现任何这些模型之前,让我们先了解一些TensorFlow Hub提供的常见预训练模型。

  1. VGG:The Visual Geometry Group(VGG)模型是由牛津大学开发的。这些模型广泛用于图像分类任务,并在各种基准数据集上取得了最先进的结果。

  2. ResNet:The Residual Network(ResNet)模型是由微软研究院开发的。这些模型具有独特的架构,可以训练非常深的神经网络(高达1000层)。

  3. Inception:Inception模型是由Google开发的。这些模型具有独特的架构,使用不同尺度的多个并行卷积,Inception模型广泛用于目标检测和图像分类任务。

  4. MobileNet:MobileNet模型是由Google开发的。这些模型具有针对移动设备和嵌入式设备进行优化的独特架构,MobileNet模型广泛用于移动设备上的图像分类和目标检测任务。

可以通过向预训练模型添加额外层并在特定数据集上训练模型来应用迁移学习。与从头开始训练模型相比,这种技术可以节省大量时间和计算资源。但是,在选择预训练模型并将数据集转换为该格式以确保兼容之前,了解预训练模型所需的输入格式非常重要。

在这个示例中,MobileNet模型被作为基本模型使用。在使用基本模型之前,检查模型所需的格式非常重要, 在本示例中,格式为(224,224,3)。然而,MNIST数据集是一个灰度图像,大小为(28,28,1),其中单个值表示像素的亮度。图像大小也比所需的格式要小得多。因此,需要重新调整数据集。以下是调整大小的主要思路:

使用image.resize函数将图像调整为所需的大小。该函数使用双线性插值来保留原始图像中的信息,同时将其调整为新大小。因此,此步骤可以将原始形状(28,28,1)调整为(224,224,1)的形状。

使用image.grayscale_to_rgb函数将图像转换为新的RGB图像,通过将单个灰度通道复制到新的RGB图像的所有三个通道中,从而将原始形状(224,224,1)调整为(224,224,3)的形状。

# 调整输入图像的大小为224x224,并将其转换为三通道的RGB图像
X_train = tf.image.grayscale_to_rgb(tf.image.resize(X_train, [224, 224]))
X_test = tf.image.grayscale_to_rgb(tf.image.resize(X_test, [224, 224]))

 现在让我们基于MobileNet模型定义我们的模型:

# 加载MobileNet模型,不包括顶层
base_model = MobileNet(include_top=False, input_shape=(224, 224, 3))

# 添加一个全局平均池化层和一个全连接输出层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)

# 将基础模型和新层结合起来,创建完整的模型
model = tf.keras.models.Model(inputs=base_model.input, outputs=x)

# 冻结基础模型中的各层
for layer in base_model.layers:
    layer.trainable = False

在上面的示例中,我们定义了一个模型,如下所示:

  1. 使用MobileNet()定义基本模型

  2. GlobalAveragePooling2D层,使用基本模型的最后一个卷积层的输出,计算每个特征映射的平均值,从而得到一个固定长度的向量,总结了特征映射中的空间信息。

  3. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  4. Dense层,使用十个单元的完全连接层和softmax激活。它接收来自上一层的输出并生成覆盖十个可能类别的概率分布。

编译和训练模型 

在创建模型之后,必须通过指定在训练期间使用的损失函数、优化器和指标来编译它。以下是一个编译模型的示例代码:

# 编译该模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

由于这是一个多分类问题,因此此示例代码使用了稀疏交叉熵损失函数,我们使用的是Adam优化器和准确率指标。

在训练模型之后可以在测试集上评估它,以查看它在未见过的数据上的表现如何,以下是一个评估模型的示例代码:

# 在测试数据上评估该模型
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test loss: ', test_loss)
print('Test Acc: ', test_acc)

 在此示例代码中,我们在测试集上评估模型,并输出测试损失和准确率。

进行预测

一旦训练和评估了模型,就可以使用它来预测新数据。以下是一个进行预测的示例代码:

# 对新数据进行预测
y_pred = model.predict(X_test)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)

在此示例代码中,我们在模型上使用predict()方法对整个测试集进行预测。

如果我们想要预测单个图像并返回预测标签与真实标签,那么就需要对Keras模型的predict()方法进行更改。因为Keras模型的predict()方法期望输入数据形式为一批图像,而我们想要传递单个图像给predict()方法,所以需要将其重新调整为批次大小为1。

def predict_and_compare(model, X_test, y_test, index):
    # 从X_test中获取给定索引的例子
    example = X_test[index]

    # 将例子重塑为预期的输入形状
    example = np.reshape(example, (1, 28, 28, 1))

    # 预测这个例子的标签
    y_pred = model.predict(example)

    # 将预测的概率转换为类别标签
    y_pred_label = np.argmax(y_pred, axis=1)[0]

    # 使用索引从y_test获取真实标签
    y_test_array = y_test.values
    # Get the label for the first example in the test set 
    y_true = y_test_array[index]

    # 输出预测的和真实的标签
    print("Predicted label:", y_pred_label)
    print("True label:", y_true)
    
    # 返回预测的和真实的标签
    return y_pred_label, y_true
  
# 预测并比较测试集中第一个例子的标签
y_pred_label, y_true = predict_and_compare(model, X_test, y_test, 0)

在上面的示例中,我们通过添加一个额外的维度来代表批次大小,从而将输入图像从(28,28,1)调整为(1,28,28,1)。这样,我们就可以传递单个图像给predict()方法,并获得该图像的预测结果。当我们调用上面的函数时,可以自定义要预测的图像:

 这就是在TensorFlow中实现深度学习的步骤。当然,这只是一个基本示例。你可以搭建具有更多层、不同类型的层和不同超参数的更复杂的模型,以便在数据集上获得更好的性能。

综上,本文我们演示了如何对数据进行预处理、搭建和训练模型、在单独的测试集上评估其性能以及使用简单的卷积神经网络(CNN)进行图像分类的预测,通过学习可以获得如何在TensorFlow中构建深度学习模型以及如何将这些概念应用于真实世界数据集的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/792953.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

年轻小伙爆肝ARST

关于 ARTS 的释义 —— 每周完成一个 ARTS: ⭐️● Algorithm: 每周至少做一个 LeetCode 的算法题 ⭐️● Review: 阅读并点评至少一篇技术文章 ⭐️● Tips: 学习至少一个技术技巧 ⭐️● Share: 分享一篇有观点和思考的技术文章 希望通过此次活动能聚集一波热爱技…

GitLab 删除项目

1.点击头像 2.点击Profile 3.选择要删除的项目点进去 4.settings-general-Advances-expand 5.然后在弹出框中输入你要删除的项目名称即可

【WEB开发】Java获取高德POI(关键词搜索法)实现数据展示

前言 该篇文章是关键词搜索法获取高德poi,但鉴于无法突破200条记录的上限,所以采用了本方法进行区/县循环检索。开始之前我们首先需要明白一些常识 poi是兴趣点,它本身除了经纬度,还记录了一些信息,如名称、地址、联…

mirror功能

实现方式 mirror逻辑的工作阶段: ngx在log phase之后(在ngx_http_free_request处调用)已完成向client端返回response,在log phase之后完成close connection(短链接),在该阶段处理mirror逻辑不…

【Redis】高级篇: 一篇文章讲清楚Redis的单线程和多线程

目录 面试题 Redis到底是多线程还是单线程? 简单回答 详解 Redis的“单线程” Redis为什么选择单线程? 后来Redis为什么又逐渐加入了多线程特性? Redis为什么快? 回答 IO多路复用 Unix网络编程的5种IO模型 主线程和IO…

温湿度传感器的工作原理及应用领域你了解多少呢?

传感器是一种将物理量转换为电信号的装置,用于检测温度、湿度、压力、光强、震动等物理量。它能够将检测到的物理量转换为电信号,并输送到计算机、单片机等设备进行分析和处理。生产生活中,不同的场所和环境对温湿度有着特殊要求,…

看完即会,抓取微信小程序数据包教程

在给VIP学员答疑的时候,有很多小伙伴问到能不能抓取到微信小程序数据呢?答案当然是肯定的,通过Fiddler或者Charles这些主流的抓包工具都可以抓得到,在IOS平台抓取微信小程序和https请求都是一样的设置,接下来给大家通过…

【代码随想录day20】验证二叉搜索树

题目 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。 节点的右子树只包含 大于 当前节点的数。 所有左子树和右子树自身必须也是二叉搜索树。 思路 最开始想简单…

Linux 学习记录58(ARM篇)

Linux 学习记录58(ARM篇) 本文目录 Linux 学习记录58(ARM篇)一、GIC相关寄存器1. 系统框图2. 中断号对应关系 二、GICD寄存器1. GICD_CTLR2. GICD_ISENABLERx3. GICD_IPRIORITYRx4. GICD_ITARGETSRx5. GICD_ICPENDRx 三、GICC寄存器1. GICC_PMR2. GICC_CTLR3. GICC_IAR4. GICC_…

JAVA面试总结-Redis篇章(五)——持久化

Java面试总结-Redis篇章(五)——持久化 1.RDBRDB全称Redis Database Backup file (Redis数据备份文件),也被叫做Redis数据快照。简单来说就是把内存中的所有数据都记录到磁盘中。当Redis实例故障重启后,从磁盘读取快照文件&#x…

持续贡献开源力量,棱镜七彩加入openKylin

近日,棱镜七彩签署 openKylin 社区 CLA(Contributor License Agreement 贡献者许可协议),正式加入openKylin 开源社区。 棱镜七彩成立于2016年,是一家专注于开源安全、软件供应链安全的创新型科技企业。自成立以来&…

短视频账号矩阵系统源码开发部署路径

一、短视频批量剪辑的开发逻辑算法 1.视频剪辑之开发算法 自己研发视频剪辑是指通过对视频素材进行剪切、调整、合并等操作,利用后台计算机算法,进行抽帧抽组抽序进行排列以达到对视频内容进行修改和优化的目的。自己研发的视频剪辑工具可以通过后台码…

系统集成项目管理工程师挣值分析笔记大全

系统集成项目管理工程师挣值分析笔记大全 挣值分析是一种项目管理技术,用于量化和监控项目绩效。它通过比较计划值(PV)、实际成本(AC)和挣值(EV)三个参数来评估项目的进展情况和成本绩效。 挣值…

系统架构设计师-软件架构设计(3)

目录 一、软件架构风格(其它分类) 1、闭环控制结构(过程控制) 2、C2风格 3、MDA(模型驱动架构 Model Driven Architecture) 4、特定领域软件架构(DSSA) 4.1 DSSA基本活动及产出物…

《CUDA C++ Programming Guide》第一章 CUDA介绍

第一章 CUDA介绍 1.1 使用GPUs的好处 在相同的价格和功耗范围内,图形处理器 GPU 比 CPU 提供了更高的指令吞吐量和内存带宽, 许多应用程序利用这些更高的功能在 GPU 上比在 CPU 上运行得更快。相比较其他计算设备,如 FPGA,也是非常节能的&a…

协议和模型

1 规则 1.1 通信基础知识 不同网络的规模、形状和功能都存在很大差异。它们可以复杂到通过互联网来连接设备,也可以简单到直接将两台计算机用一根电缆连接,或者是介于这两种之间。然而,只是完成终端设备之间的有线或无线物理连接并不足以实…

MATLAB基础知识回顾

目录 1.帮助命令 2.数据类型 3.元胞数组和结构体 4.矩阵操作 4.1 矩阵的定义与构造 4.2 矩阵的四则运算 4.3 矩阵的下标 5.程序结构 5.1 for循环结构 5.2 分支结构 7.基本绘图操作 7.1.二维平面绘图 6.2 三维立体绘图 7.图形的保存与导出 8.补充 语句后⾯加;的作…

Kaggle图表内容识别大赛TOP方案汇总

赛题名称:Benetech - Making Graphs Accessible 赛题链接:https://www.kaggle.com/competitions/benetech-making-graphs-accessible 赛题背景 数以百万计的学生有学习、身体或视力障碍,导致人们无法阅读传统印刷品。这些学生无法访问科学…

人机合一Linux

未来云系统成为主流,维护电脑这种充满时代特性的技术,完全不重要了。 无论是学习还是工作,电脑都是IT人必不可少的重要武器,一台好电脑除了自身配置要经得起考验,后期主人对它的维护也是决定它寿命的重要因素&#xff…

【01trie】CF1851F

Problem - F - Codeforces 题意: 思路: 首先最大异或对可以用01trie解决 trie树入门_lamentropetion的博客-CSDN博客 ai xor x 和 aj xor x 都必须为1 因此可以转换为 ai 和 aj 最小异或对问题 改一下01trie的板子即可 主要修改部分在query上 p…