政安晨:【Keras机器学习实践要点】(十二)—— 迁移学习和微调

news2024/11/29 2:50:59

目录

设置

介绍

冻结层:了解可训练属性

可训练属性的递归设置

典型的迁移学习工作流程

微调

关于compile()和trainable的重要说明

BatchNormalization层的重要注意事项


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文是Keras 迁移学习和微调的完全指南文章。

设置

import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

Keras是一种用于构建深度学习模型的高级神经网络API。迁移学习是指在一个任务上训练好的模型或特征提取器用于另一个相关任务上。

Keras迁移学习是利用预训练模型的特征提取能力来加速模型训练。预训练模型通常是在大规模数据集上训练的,并且已经提取出了一些有用的特征。迁移学习可以通过利用这些特征来降低新任务的数据需求和训练时间。

Keras提供了一些流行的预训练模型,如VGG16、ResNet50和InceptionV3等。这些模型可以直接在Keras中加载,并且可以通过设置参数来冻结一部分或全部层,以便在新任务上进行微调。

迁移学习的步骤包括加载预训练模型、修改模型结构(根据新任务的要求)并选择要冻结的层、添加新的全连接层或输出层、在新数据集上进行训练和微调模型。在训练过程中,可以选择冻结一些层,只训练新添加的层,以避免破坏原始特征提取能力。

Keras迁移学习的好处包括:

  1. 加速模型训练因为可以利用预训练模型的特征提取能力。
  2. 避免从头开始训练模型减少数据需求和计算资源。
  3. 可以应用在小数据集上而不需要大规模数据集。

总而言之,Keras迁移学习是一种利用预训练模型的特征提取能力来加速深度学习模型训练的方法。它可以帮助我们在新任务上更快、更有效地构建和训练模型。

介绍

迁移学习包括利用在一个问题上学习到的特征,在一个新的、类似的问题上加以利用。

例如,从一个识别浣熊的模型中学习到的特征可能有助于启动一个用于识别塔努基鱼的模型。

迁移学习通常用于数据集数据太少,无法从头开始训练完整模型的任务。

在深度学习中,迁移学习最常见的表现形式是以下工作流程:
 

× 从先前训练好的模型中提取图层。
× 冻结它们,以避免在未来的训练中破坏它们所包含的任何信息。
× 在冻结层上添加一些新的、可训练的层。它们将学会在新数据集上把旧特征转化为预测结果。
× 在数据集上训练新层。

最后一个可选步骤是微调,包括解冻上述获得的整个模型(或部分模型),并以极低的学习率在新数据上对其进行重新训练。通过逐步调整预训练特征以适应新数据,这有可能实现有意义的改进。

首先,我们将详细介绍 Keras 可训练 API,它是大多数迁移学习和微调工作流程的基础。

然后,我们将通过在 ImageNet 数据集上预训练模型,并在 Kaggle "猫与狗 "分类数据集上重新训练模型来演示典型的工作流程。

冻结层:了解可训练属性

层和模型有三种权重属性:

weights 是层的所有权重变量的列表。
trainable_weights(可训练权重)是指在训练过程中为了最小化损失而更新(通过梯度下降)的权重列表。
non_trainable_weights(不可训练权重)是不需要训练的权重列表。通常情况下,模型会在前向传递过程中更新这些权重。

例如密集层有 2 个可训练的权重(内核和偏置)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

结果如下:

weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般来说,所有权重都是可训练权重。唯一具有不可训练权重的内置层是批归一化层。它在训练过程中使用不可训练权重来跟踪输入的均值和方差。以后咱们也会了解如何在自己的自定义层中使用不可训练权重。

示例:将可训练设置为假

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

结果如下:

weights: 2
trainable_weights: 0
non_trainable_weights: 2

当可训练权重变为不可训练权重时,其值在训练过程中不再更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)

演绎如下:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615

请勿将 layer.trainable 属性与 layer.__call__() 中的参数 training 混淆(后者控制层是以推理模式还是训练模式运行前向传递)。

可训练属性的递归设置

如果在模型上或任何有子图层的图层上设置 trainable = False,所有子图层也将变得不可训练。

示例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的迁移学习工作流程

这就引出了如何在 Keras 中实现典型的迁移学习工作流程:

实例化基础模型并加载预训练权重。
通过设置可训练 = 假,冻结基础模型中的所有层。
在基础模型中一个(或多个)层的输出基础上创建一个新模型。
在新数据集上训练新模型。

请注意,也可以采用另一种更轻便的工作流程:

实例化一个基础模型,并将预先训练好的权重载入其中。
通过它运行新数据集,并记录基础模型中一个(或多个)层的输出。这就是所谓的特征提取。
将该输出作为一个新的、更小的模型的输入数据。

第二个工作流程的主要优势在于,你只需在数据上运行一次基础模型,而不是每个历元训练一次。因此速度更快,成本更低。
 

不过,第二种工作流程的一个问题是,它无法在训练过程中动态修改新模型的输入数据,而这在进行数据扩增时是必需的。当新数据集的数据太少,无法从头开始训练一个完整的模型时,迁移学习通常会被用于这种任务,在这种情况下,数据扩增就显得非常重要。

因此,在下文中,我们将重点介绍第一种工作流程。

下面是 Keras 中的第一个工作流程:


首先,实例化一个带有预训练权重的基础模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

然后,冻结基本模型。

base_model.trainable = False

在最开始处创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新数据上训练新数据。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦模型在新数据上收敛,你可以尝试解冻所有或部分基础模型,并使用非常低的学习率对整个模型进行端到端的重新训练。

这是一个可选的最后一步,可能会给你带来渐进的改进。但也有可能很快出现过拟合的情况,要记住这一点。

在模型的冻结层收敛之后才进行此步骤非常关键。如果将随机初始化的可训练层与保存预训练特征的可训练层混合,随机初始化的层将在训练过程中导致非常大的梯度更新,破坏预训练特征。

此阶段使用非常低的学习率也非常关键,因为你正在训练一个比第一轮训练中要大得多的模型,而且通常使用的数据集非常小。因此,如果应用大的权重更新,很容易出现过拟合的情况。在这里,你只想以增量的方式重新适应预训练的权重。

这就是如何实施整个基础模型的微调过程:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于compile()和trainable的重要说明

在模型上调用compile()旨在“冻结”该模型的行为。这意味着在模型编译时可训练属性的值应保持不变,直到再次调用compile()。因此,如果您更改任何可训练值,请确保再次调用compile()以使您的更改生效。

BatchNormalization层的重要注意事项

许多图像模型都包含BatchNormalization层。在各个方面,该层都是一个特例。以下是一些需要记住的事情。

1. BatchNormalization包含2个不可训练的权重,在训练过程中这些权重会更新。这些变量用于跟踪输入的均值和方差。

2. 当你设置bn_layer.trainable = False时,BatchNormalization层将以推理模式运行,并且不会更新其均值和方差统计数据。这与一般情况下的其他层不同,因为权重的可训练性和推理/训练模式是两个不相关的概念。但是,在BatchNormalization层的情况下,这两者是相互关联的。

3. 当你解冻包含BatchNormalization层的模型以进行微调时,应该在调用基本模型时将BatchNormalization层保持在推理模式中,即通过传递training=False来实现。否则,对不可训练权重的更新将突然破坏模型所学到的内容。

后面的文章中,咱们将开展一个端到端示例,在那里面,您将看到这种模式的实际应用。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1559929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Uni-app的体育场馆预约系统的设计与实现

文章目录 基于Uni-app的体育场馆预约系统的设计与实现1、前言介绍2、开发技术简介3、系统功能图3、功能实现4、库表设计5、关键代码6、源码获取7、 🎉写在最后 基于Uni-app的体育场馆预约系统的设计与实现 1、前言介绍 伴随着信息技术与互联网技术的不断发展&#…

轻量应用服务器16核32G28M腾讯云租用优惠价格4224元15个月

腾讯云16核32G服务器租用价格4224元15个月,买一年送3个月,配置为:轻量16核32G28M、380GB SSD盘、6000GB月流量、28M带宽,腾讯云优惠活动 yunfuwuqiba.com/go/txy 活动链接打开如下图: 腾讯云16核32G服务器租用价格 腾讯…

Nginx的反向代理

Nginx的反向代理 location ^~ /aaa {proxy_pass http://192.168.15.78/; } 1. 跨域 2.Nginx 代理服务器缓存 3.Nginx 负载均衡 4. 动静分离 Nginx的跨域 跨源资源共享 (CORS) 是一种机制,它使用额外的 HTTP 标头让用户代理获得访问来自不同来域的服务器上选定资…

怎么快速上手虚拟化(容器)技术——以 Docker 为例

Docker 整体介绍 Docker 是一种使用 Go 语言开发的容器工具。所谓容器,实际上是一种虚拟化技术,用于为应用提供虚拟化的运行环境,相较于虚拟机具有轻量级、低延迟的特性。 下面是对上述介绍的说明: 应用程序运行需要一定的依赖…

qtcreator的信号槽链接

在ui文件中简单创建一个信号槽连接并保存可以在ui_mainwindow.h下 class Ui_MainWindow 类 void setupUi(QMainWindow *MainWindow)函数 找到对应代码 QObject::connect(pushButton, SIGNAL(clicked()), MainWindow, SLOT(close())); 下拉,由于 class MainWind…

@Transactional使用细节

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 动态代理回顾 Spring的声明式事务管理是建立在 AOP 的基础之上的。Spring AOP是通过动态代理实现的。如果代理对象实现了接口,则使用JDK的动态代理;…

SpringBoot整合knife4J 3.0.3

Knife4j的前身是swagger-bootstrap-ui,前身swagger-bootstrap-ui是一个纯swagger-ui的ui皮肤项目。项目正式更名为knife4j,取名knife4j是希望她能像一把匕首一样小巧,轻量,并且功能强悍,更名也是希望把她做成一个为Swagger接口文档服务的通用性解决方案,不仅仅只是专注于前端Ui…

【IP组播】PIM-SM的RP、RPF校验

目录 一:PIM-SM的RP 原理概述 实验目的 实验内容 实验拓扑 1.基本配置 2.配置IGP 3.配置PIM-SM和静态RP 4.配置动态RP 5.配置Anycast RP 二: RPF校验 原理概述 实验目的 实验内容 实验拓扑 1.基本配置 2.配置IGP 3.配置PIM-DM 4.RPF校…

LeetCode_876(链表的中间结点)

//双指针//时间复杂度O(n) 空间复杂度O(1)public ListNode middleNode(ListNode head) {ListNode slowhead,fast head;while (fast!null && fast.next!null){slow slow.next;fast fast.next.next;}return slow;} 1->2->3->4->5->null 快指针移动两个…

如何创建一个TCP多人聊天室?

一、什么是TCP? TCP(Transmission Control Protocol)是一种可靠的 面向连接的协议 ,可以保证数据在传输过程中不会丢失、重复或乱序。 利用TCP实现简单聊天程序,需要客户端和服务器端之间建立TCP连接,并通…

一条SQL在MySQL中的执行过程

图解: 第⼀步:连接器 过程 1. 建⽴连接:与客户端进⾏ TCP 三次握⼿建⽴连接; 2. 校验密码:校验客户端的⽤户名和密码,如果⽤户名或密码不对,则会报错;3. 权限判断&#xff1a…

HCIP【GRE VPN配置】

目录 实验要求: 实验配置思路: 实验配置过程: 一、按照图式配置所有设备的IP地址 (1)首先配置每个接口的IP地址 (2)配置静态路由使公网可通 二、在公网的基础上创建GRE VPN隧道&#xff0…

【yy讲解PostCSS是如何安装和使用】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

保研线性代数机器学习基础复习2

1.什么是群(Group)? 对于一个集合 G 以及集合上的操作 ,如果G G-> G,那么称(G,)为一个群,并且满足如下性质: 封闭性:结合性:中性…

Linux多进程通信(1)——无名管道及有名管道使用例程

管道是半双工通信,如果需要 双向通信,则需要建立两个管道, 无名管道:只能父子进程间通信,且是非永久性管道通信结构,当它访问的进程全部终止时,管道也随之被撤销 有名管道:进程间不需…

【核弹级软安全事件】XZ Utils库中发现秘密后门,影响主要Linux发行版,软件供应链安全大事件

Red Hat 发布了一份“紧急安全警报”,警告称两款流行的数据压缩库XZ Utils(先前称为LZMA Utils)的两个版本已被植入恶意代码后门,这些代码旨在允许未授权的远程访问。 此次软件供应链攻击被追踪为CVE-2024-3094,其CVS…

从大厂裸辞半年,我靠它成功赚到了第一桶金,如果你失业了,建议这样做,不然时间太久了就完了

程序员接私活和创业是许多技术从业者关注的话题。下面我将介绍一些程序员接私活和创业的渠道和建议: 接私活的渠道: 自媒体平台: 可以利用社交媒体、个人博客、技术社区等平台展示自己的作品和技能,吸引潜在客户。自由工作平台&…

C#(winform) 调用MATLAB函数

测试环境 VisualStudio2022 / .NET Framework 4.7.2 Matlab2021b 参考:C# Matlab 相互调用 Matlab 1、编写Matlab函数 可以没有任何参数单纯定义matlab处理的函数,输出的数据都存在TXT中用以后期读取数据 function [result,m,n] TEST(list) % 计算…

数据分析之Tebleau 的度量名称和度量值

度量名称 包含所有的维度 度量值 包含所有的度量 度量名称包含上面所有的维度,度量值包含上面所有的度量 当同时创建两个或两个以上度量或维度时,会自动创建度量名称和度量值 拖入省份为行(这会是还没有值的) 可以直接将销售金额拖到数值这里 或者将销售…

Kafka 学习之:基于 flask 框架通过具体案例详解生产消费者模型,这一篇文章就够了

文章目录 案例信息介绍后端异步处理请求和后端同步处理请求同步方式异步方式 环境文件目录配置.envrequirements.txt 完整代码ext.pyapp.pykafka_create_user.py 运行方式本地安装 kafka运行 app.py使用 postman 测试建立 http 长连接,等待后端处理结果发送 RAW DAT…