【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

news2025/1/10 2:55:26

#012 TensorFlow 2.0 中的 TF 迁移学习

一、说明

        在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么它被称为迁移学习;我们将现有模型的学习转移到新的数据集中。

教程概述:

  1. 介绍
  2. 使用内置的 TensorFlow 模型进行迁移学习
  3. 使用 TensorFlow Hub 进行迁移学习

二、为什么迁移学习简介

之        我们已经探讨了如何使用数据增强来提高模型性能。现在的问题是,“如果我们没有足够的数据来从头开始训练我们的网络怎么办?

对        此的解决方案是使用迁移学习方法。一篇更具理论意义的帖子已经发布在我们的博客上。如果需要,请查看它以刷新一些想法。我们可以使用迁移学习将知识从一些预先训练好的开源网络转移到我们自己的简历问题中。

计        算机视觉研究社区在互联网上发布了许多数据集,如Imagenet或MS Coco或Pascal数据集。许多计算机视觉研究人员已经在这些数据集上训练了他们的算法。有时,此培训需要数周时间,并且可能需要许多 GPU。事实上,其他人已经完成了这项任务并经历了痛苦的高性能研究过程,这意味着我们经常可以下载开源权重。

有        很多网络已经过训练。例如,Imagenet 数据集,它由 1000 个不同的类和超过 14 万张图像组成。因此,网络可能有一个 softmax 单元,它输出一千个可能的类之一。我们可以做的是摆脱softmax层并创建我们自己的输出单元来表示例如猫或

        由于我们使用下载的权重,我们将只训练与我们的输出层关联的参数,在我们的例子中,这将是一个 sigmoid 输出层。

三、使用内置的 TensorFlow 模型进行迁移学习

        让我们首先为训练准备数据集。我们将使用 wget.download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile
wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")

with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:
    zip_ref.extractall()

base_dir = 'cats_and_dogs_filtered'

train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        现在,让我们导入所有必需的库并构建模型。我们将使用一个名为MobileNetV2的预训练网络,该网络在ImageNet数据集上进行训练。在这里,我们希望使用除顶部分类图层之外的所有图层,因此我们不会将它们包含在我们的网络中。

        obileNetV2 体系结构概述:https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html

这        个模型和其他预训练模型已经在TensorFlow中可用。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

base_model = MobileNetV2(input_shape=(224, 224, 3),
                                      include_top=False,
                                      weights='imagenet')

base_model.summary()


输出:
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]  

        果最后一层输出不同数量的类,那么我们需要有自己的输出单元来输出以下类:。有几种方法可以做到这一点:

  • 取最后几层的权重,将它们用作初始化并进行梯度下降。这样,我们将重新训练网络的一部分。
  • 删除最后几层的权重,使用我们自己的新隐藏单元和我们自己的最终 sigmoid(或 softmax)输出。通过这种方式,我们可以更改输出的数量。

        因此,这两种方法中的任何一种都值得尝试。

        现在让我们冻结预训练层并添加一个新层,称为 GlobalAveragePooling2D,之后是具有 sigmoid 激活函数的 Dense 层。

现        

base_model.trainable = False

model = Sequential([base_model,
                    GlobalAveragePooling2D(),
                    Dense(1, activation='sigmoid')])
model.summary()

    

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

          现在是训练步骤的时候了。我们将使用图像数据生成器来迭代图像。现在没有必要为大量时期训练网络,因为我们已经预先训练了一部分网络。

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')

validation_generator = val_datagen.flow_from_directory(
        validation_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary')

history = model.fit(
      train_generator,
      epochs=6,
      validation_data=validation_generator,
      verbose=2)

让        我们看看结果。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(accuracy))

plt.plot(epochs, accuracy, label="Training")
plt.plot(epochs, val_accuracy, label="Validation")
plt.legend()
plt.title('Training and validation accuracy')
plt.figure()

plt.plot(epochs, loss, label="Training")
plt.plot(epochs, val_loss, label="Validation")
plt.legend()
plt.title('Training and validation loss')

3.Text(0.5, 1.0, 'Training and validation loss')

四、 使用 TensorFlow Hub 进行迁移学习

        访问预训练模型的另一种方法是TensorFlow Hub。TensorFlow Hub是一个库,用于发布,发现和使用机器学习模型的可重用部分。您可以在此处找到更多预训练模型。

        我们将冻结图层并添加一个用于分类的新图层。全连接网络的输入称为瓶颈要素。它们表示网络中最后一个卷积层的激活图。

我们为给定数量的 epoch 训练模型。

        在这里,我们也可以使用 TensorBoard,但这次让我们保持简单。

        最后,让我们可视化一些预测。在此数据集中,猫标记为 0,狗标记为 1。我们将用蓝色显示正确的预测,用红色显示假。

预测

五、总结

        正如我们在上面看到的,使用迁移学习可以帮助我们在短时间内取得非常好的结果。使用数据增强,可以进一步增强结果。

         在下一篇文章中,我们将展示如何创建网络并将其转换以在移动设备上使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1084048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

omnipathr官网教程 mistr

github python版本 omnipath tutorials Issue #17 saezlab/omnipath (github.com) R版本 saezlab/OmnipathR: R client for the OmniPath web service (github.com)https://github.com/saezlab/OmnipathR GitHub - saezlab/OmnipathR: R client for the OmniPath web s…

Prettier插件使用

一、前言 由于之前使用的Beautify格式化插件已经没有在维护了,所以这里再分享一个Formatter插件-Prettier。 二、插件安装 首先在扩展(ctrlShiftX)中搜索关键词Prettier,点击安装。 三、插件使用配置 首先在VSCode编辑器中依次打开菜单文件-首选项-…

Oauth2.0单点登录的解决方案 安当加密

上海安当技术有限公司的ASP身份认证系统提供针对Oauth2.0单点登录的解决方案。该解决方案可以帮助客户实现以下目标: 统一的用户管理:Oauth2.0单点登录可以提供一个统一的用户管理平台,使得用户只需要在一个平台上进行注册和身份认证&#x…

基于Java使用SpringBoot+Vue框架实现的前后端分离的美食分享平台

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 在当今社会&#xff0…

打造类ChatGPT服务,本地部署大语言模型(LLM),如何远程访问?

ChatGPT的成功,让越来越多的人开始关注大语言模型(LLM)。如果拥有了属于自己的大语言模型,就可以对其进行一些专属优化。例如:打造属于自己的AI助理,或是满足企业自身的业务及信息安全需求。 所以&#xff…

天猫商品评论数据接口,天猫商品评论API接口,天猫API接口

天猫商品评论内容数据接口的步骤,但是可以提供淘宝商品评论内容数据接口的步骤: 授权获得淘宝开放平台API所需的权限。获取AppKey和AppSecret等认证信息。发送HTTP请求,获取所需评价信息。对获取到的评价信息进行处理和解析。结果处理&#…

行业顶流|AI数字人直播,低成本、高回报的新趋势

去年,元宇宙的概念炒得特别火,落地却寥寥无几,今年的AI数字人是否也只是一时炒作的概念呢? 在这个信息时代,科技的发展总是伴随着各种概念的冒起。元宇宙作为其中之一,的确在一瞬间扑面而来,引…

《机器学习》第5章 神经网络

文章目录 5.1 神经元模型5.2 感知机与多层网络5.3 误差逆传播算法5.4 全局最小与局部最小5.5 其他常见神经网络RBF网络ART网络SOM网络级联相关网络Elman网络Boltzmann机 5.6 深度学习 5.1 神经元模型 神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它…

数字化转型,河北吉力宝—传统行业的自我救赎新标杆

近年来,国家出台了各项政策支持企业数字化转型,“十四五”计划更是将建设数字经济作为重要发展目标,中国人工智能产业进入爆发式增长阶段,市场潜力巨大。随着数字化时代的到来,加快发展数字经济成为把握新一轮科技革命和产业变革新机遇的战略选择。 健康卫生事件后…

React如何优化减少组件间的重新Render

目前写了不少React的项目,发现React有些特点更灵活和注重细节,很多东西需要有一定的内功才能掌握好;比如在项目中常常遇到的组件重复渲染,有时候组件重复渲染如果内容是纯文本,不打印日志就不容易发现重复渲染了&#…

Maven - 5 分钟快速通关

目录 一、Maven 1.1、 基础语法 1.2、聚合 1.3、继承 1.4、自定义属性 一、Maven 1.1、 基础语法 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/…

RN(React Native)的应用程序在雷电模拟器可以运行,安卓真机运行失败问题解决记录

yarn react-native build-android打包的apk在真机安卓运行提示&#xff1a; Unable to load script . Make sure you re either running Metro ( run npx react - native start ) or that your bundle index . android . bundle is packaged correctly for release . jn…

基于JavaWeb+SpringBoot+Vue超市管理系统的设计和实现

基于JavaWebSpringBootVue超市管理系统的设计和实现 源码传送入口前言主要技术系统设计功能截图Lun文目录订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码传送入口 前言 摘 要 科技进步的飞速发展引起人们日常生活的巨大变化&#xff0c;电子信息技术的飞速发…

适合学生写作业的台灯有哪些?高品质学生读写台灯推荐

不得不说如今我国青少年儿童的近视率还是非常高的&#xff0c;据国家卫健委疾控局数据&#xff0c;我国儿童青少年总体近视率为52.7%&#xff0c;其中6岁儿童为14.3%&#xff0c;小学生为35.6%&#xff0c;初中生为71.1%&#xff0c;高中生为80.5%&#xff0c;造成近视的原因不…

PLC编程速成(二)

目录 操作符 什么是操作符&#xff1f; 变量表&#xff08;数据类型&#xff09; 常用的类型&#xff1a; 变量表图 设置复位指令 如何重复双线圈与解决复双线圈问题&#xff1f; 解决复双线圈 ​编辑 重复双线圈 置复位指令&#xff08;有置位就存在复位&#xff09;…

声量暴涨130%,小红书「待爆」赛道创作指南

近年来&#xff0c;小红书影视板块展现出了旺盛的生命力。热门赛道逼近饱和的当下&#xff0c;内容如何不断推陈出新&#xff0c;成为营销困局。 本期&#xff0c;千瓜将锁定蓄势待发的影视板块&#xff0c;梳理“影视”内容打造方式&#xff0c;助力品牌开疆扩土&#xff0c;抢…

【图像误差测量】测量 2 张图像之间的差异,并测量图像质量(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

离散数学 学习 之 递推方程和生成函数

递推方程 注意这里的特征根一定不是相等 特解的话一般要去设出基本的形式 这是0 次多项式 生成函数

Kepler.gl笔记:地图交互

1 双图 点第一个图标&#xff0c;进入双图 双图可以选择各自显示哪些layer 2 2D图转3D图 点击第二个图标 鼠标拖拽是控制位置 ctrl鼠标拖拽是旋转 3 显示图例

10_博客管理系统

1 项目展示 Express框架可以开发各种不同类型的项目&#xff0c;博客管理系统&#xff08;Blog Management System&#xff09;就是一个比较典型的项目。许多热爱分享技术的程序员都在建立自己的博客&#xff0c;用来发表一些技术文章。 主要完成用户登录、用户管理、文章管理…