干货!深入学习必学的模型微调

news2025/1/11 14:32:09

学习目标

  • 知道微调的原理
  • 能够利用微调模型来完成图像的分类任务

1.微调

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

 

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

2.热狗识别

接下来我们来实践一个具体的例子:热狗识别。将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的工具包。

import tensorflow as tf
import numpy as np

2.1 获取数据集

我们首先将数据集放在路径hotdog/data之下:

 

 

每个类别文件夹里面是图像文件。

上一节中我们介绍了ImageDataGenerator进行图像增强,我们可以通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None)

主要参数:

  • directory: 目标文件夹路径,对于每一个类对应一个子文件夹,该子文件夹中任何JPG、PNG、BNP、PPM的图片都可以读取。
  • target_size: 默认为(256, 256),图像将被resize成该尺寸。
  • batch_size: batch数据的大小,默认32。
  • shuffle: 是否打乱数据,默认为True。

我们创建两个
tf.keras.preprocessing.image.ImageDataGenerator实例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高和宽均为224像素的输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化。

# 获取数据集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 获取训练集数据
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 获取测试集数据
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 创建imageDataGenerator进行图像处理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 设置参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 获取训练数据
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)
# 获取测试数据
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)

下面我们随机取1个batch的图片然后绘制出来。

import matplotlib.pyplot as plt
# 显示图像
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(15):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.axis('off')
# 随机选择一个batch的图像        
image_batch, label_batch = next(train_data_gen)
# 图像显示
show_batch(image_batch, label_batch)

 

2.2 模型构建与训练

我们使用在ImageNet数据集上预训练的ResNet-50作为源模型。这里指定weights='imagenet'来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

Keras应用程序(keras.applications)是具有预先训练权值的固定架构,该类封装了很多重量级的网络架构,如下图所示:

 

实现时实例化模型架构:

tf.keras.applications.ResNet50(
    include_top=True, weights='imagenet', input_tensor=None, input_shape=None,
    pooling=None, classes=1000, **kwargs
)

主要参数:

  • include_top: 是否包括顶层的全连接层。
  • weights: None 代表随机初始化, 'imagenet' 代表加载在 ImageNet 上预训练的权值。
  • input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必须为 3 个输入通道,且宽高必须不小于 32,比如 (200, 200, 3) 是一个合法的输入尺寸。

在该案例中我们使用resNet50预训练模型构建模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 设置所有层不可训练
for layer in ResNet50.layers:
    layer.trainable = False
# 设置模型
net = tf.keras.models.Sequential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation='softmax'))

接下来我们使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练。

# 模型编译:指定优化器,损失函数和评价指标
net.compile(optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])
# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(
                    train_data_gen,
                    steps_per_epoch=10,
                    epochs=3,
                    validation_data=test_data_gen,
                    validation_steps=10
                    )
Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875

总结

  • 微调是目标模型复制了源模型上除了输出层外的所有模型设计及其参数,并基于目标数据集微调这些参数。而目标模型的输出层需要从头训练。
  • 利用tf.keras中的application实现迁移学习

人工智能自学站:

人工智能开发_人工智能工程师_AI人工智能

人工智能之python编程零基础入门

4天快速入门Python数据挖掘
简单快速入门Python机器学习
AI深度学习自然语言处理NLP零基础入门

AI-OpenCV图像处理10小时零基础入门

3天带你玩转Python深度学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/107842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅析JWT

Cookie-session 我们都知道JWT一般用于用户登录等需要记住的操作,在谈论JWT之前就不得不谈谈以前的cookie-session登录了 。因为http协议是一种无状态协议,即每次服务端接收到客户端的请求时,都是一个全新的请求,服务器并不知道客…

【从零开始学微服务】08.引入微服务架构的时机

大家好,欢迎来到万猫学社,跟我一起学,你也能成为微服务专家。 在了解引入微服务架构的时机之前,架构设计时一般需要遵循的三个原则。 架构设计三个原则 架构设计一般需要遵循以下三个原则: 合适原则:合适…

NeurIPS'22 | APG:面向CTR预估的自适应参数生成网络

丨目录: 摘要 背景 Method 实验 结语▐ 摘要目前基于深度学习的CTR预估模型(即 Deep CTR Models)被广泛的应用于各个应用中。传统的 Deep CTR Models 的学习模式是相对静态的,即所有的样本共享相同的网络参数。然而,由…

IntelliJ IDEA中我最爱的10个快捷操作

1. psvm/main快速生成 main() 方法 在日常开发中,我们经常需要写main()方法,这时候您也可以使用main或者psvm命令快速地帮助我们创建出main()方法。 2.sout快速生成println()方法 打印输出一些内容到控制台也是频率很高的一个行为,我们可以…

Pytest断言

🔴pytest 允许使用标准的python assert 用于验证Python测试中的期望和值。所以并不像unittest的那么丰富。但是我们可以重写。 ❞小例子--介绍 import pytestclass Testnew:def test_num(self):assert 1 "1"def test_dic(self):assert {"QA":…

MySql索引下推知识分享

作者:刘邓忠 Mysql 是大家最常用的数据库,下面为大家带来 mysql 索引下推知识点的分享,以便巩固 mysql 基础知识,如有错误,还请各位大佬们指正。 1 什么是索引下推 索引下推 (Index Condition Pushdown,…

技术分享 | 测试的本质是什么?

本文将分别浅谈不同阶段的业务、不同端的业务、不同类型的业务的测试差异,再抽离其中的测试目标/本质。仅为笔者个人观点,欢迎批评指正。 一、不同阶段业务对测试的需求不同 不同阶段业务对测试的需求不同。这点几乎经历过的人员都心有戚戚焉。 从0到1的…

盘点导致Spring事务失效的4个场景

1,非运行时异常导致事务无法回滚 我们知道,Spring是通过AOP的方式来实现事务的,而在处理事务的过程中,Spring只有捕获到RuntimeException或者Error的时候才会触发回滚操作,如果我们在代码中抛出的是非运行时异常&…

Web前端学习之虚拟DOM如何进化为真实DOM

Vue和React的Render函数中都涉及到了Virtual DOM的概念,Virtual DOM也是性能优化上的重要一环,同时突破了直接操作真实DOM的瓶颈,本文带着以下几个问题来阐述Virtual DOM。 1.为什么要操作虚拟 DOM? 2.什么是虚拟 DOM? 3.手把手教你实现…

Word内容解析之图表数据获取

最近遇到一个问题,Word里有个从Excel直接复制进去的图,但那个Excel已经找不到了,无法通过编辑数据获取到表格的数据。这个其实可以用getdata等软件获取,或者鼠标点在表上的点就可以显示数据,再把数据录下来&#xff0c…

更加灵活、经济、高效的训练 — 新一代搜推广稀疏大模型训练范式GBA

作者:苏文博、张远行 近日,阿里巴巴在国际顶级机器学习会议NeurIPS 2022上发表了新的自研训练模式 Gloabl Batch gradients Aggregation(GBA,论文链接:https://arxiv.org/abs/2205.11048),由阿里…

模拟电子技术(七)波形的发生和信号的转换

(七)波形的发生和信号的转换正弦波振荡电路RC正弦波振荡电路LC正弦波振荡电路正弦波振荡例题电压比较器单限比较器过零比较器一般单限比较器滞回比较器窗口比较器电压比较器例题非正弦波发生电路矩形波发生电路三角波发生电路锯齿波发生电路信号转换电路…

Visual Studio 调试无法启动调试,拒绝访问

方法一 win更新了不兼容 ,卸载更新。 1、单击开始菜单,选择【设置】如下图; 2、然后再进入【更新和安全】选项,如下图; 3、查看已安装更新历史记录,如下图红圈 4、这个页面详细列出了最新的更新&#xf…

绿盟SecXOps安全智能分析技术白皮书 安全分析模型核心服务部署

安全分析模型核心服务部署 ModelOps 对所有的人工智能 模型(图形模型、语言模型、基于规则的模型)以及决策模型的整个生命周期 进行管理,确保对生产中的所有模型进行独立验证和问责,其核心功能涵盖了模型存储、模型测试、模型回滚…

28. 如何使用 SAP OData 服务向 ABAP 服务器上传文件

文章目录 1. 创建对应的自定义数据库表和 ABAP DDIC 结构2. 完成 SEGW 事物码里模型的增强3. 完成必要的 ABAP 编码本教程到目前为止开发的 OData 图书管理服务,可以在 ABAP 系统里对图书数据进行增删改查。 本步骤我们继续介绍如何通过 SAP OData 服务,实现向 ABAP 系统上传…

0.96寸OLED显示屏介绍

OLED显示屏简介 OLED,即有机发光二极管(Organic Light Emitting Diode)。OLED 由于同时具备自发光,不需背光源、对比度高、厚度薄、视角广、反应速度快、可用于挠曲性面板、使用温度范围广、构造及制程较简单等优异之特性&#x…

【C语言】常见字符函数和字符串函数

1.1strlen size_t strlen(const char* str); 字符串已经\0作为结束标志,strlen函数返回的是在字符串中\0前面出现的字符个数(不包含\0)。 参数指向的字符串必须以\0结束。 注意函数的返回值为size_t,是无符号整形(…

五、Vector底层源码详解

文章目录特点底层源码分析有参构造器public Vector(int initialCapacity, int capacityIncrement)有参构造器public Vector(int initialCapacity)有参构造器public Vector(Collection<? extends E> c)无参构造器public Vector()扩容机制特点 底层是elementDate数组线程…

自学Python真的可以吗?

自学当然可以学成功python了&#xff0c;但是前提是你需要认真去学&#xff0c;而不是三天打渔两天晒网的&#xff0c;因为python初学很容易&#xff0c;稍微过几天忘记也很容易&#xff0c;所以一定要坚持学习&#xff0c;并且通过平时多加练习来熟练掌握各个知识点。 一、学…

非零基础自学Golang 第15章 Go命令行工具 15.4 注释文档(doc)

非零基础自学Golang 文章目录非零基础自学Golang第15章 Go命令行工具15.4 注释文档(doc)第15章 Go命令行工具 15.4 注释文档(doc) Go语言文档工具go doc和go fmt一样&#xff0c;也是对godoc的简单封装。 我们通常使用go doc查看指定包的文档。 例如我们查看函数fmt.Println…