微调预训练的 NLP 模型

news2025/1/18 9:04:54

动动发财的小手,点个赞吧!

针对任何领域微调预训练 NLP 模型的分步指南

简介

在当今世界,预训练 NLP 模型的可用性极大地简化了使用深度学习技术对文本数据的解释。然而,虽然这些模型在一般任务中表现出色,但它们往往缺乏对特定领域的适应性。本综合指南[1]旨在引导您完成微调预训练 NLP 模型的过程,以提高特定领域的性能。

动机

尽管 BERT 和通用句子编码器 (USE) 等预训练 NLP 模型可以有效捕获语言的复杂性,但由于训练数据集的范围不同,它们在特定领域应用中的性能可能会受到限制。当分析特定领域内的关系时,这种限制变得明显。

例如,在处理就业数据时,我们希望模型能够识别“数据科学家”和“机器学习工程师”角色之间的更接近,或者“Python”和“TensorFlow”之间更强的关联。不幸的是,通用模型常常忽略这些微妙的关系。

下表展示了从基本多语言 USE 模型获得的相似性的差异:

alt

为了解决这个问题,我们可以使用高质量的、特定领域的数据集来微调预训练的模型。这一适应过程显着增强了模型的性能和精度,充分释放了 NLP 模型的潜力。

在处理大型预训练 NLP 模型时,建议首先部署基本模型,并仅在其性能无法满足当前特定问题时才考虑进行微调。

本教程重点介绍使用易于访问的开源数据微调通用句子编码器 (USE) 模型。

可以通过监督学习和强化学习等各种策略来微调 ML 模型。在本教程中,我们将专注于一次(几次)学习方法与用于微调过程的暹罗架构相结合。

理论框架

可以通过监督学习和强化学习等各种策略来微调 ML 模型。在本教程中,我们将专注于一次(几次)学习方法与用于微调过程的暹罗架构相结合。

方法

在本教程中,我们使用暹罗神经网络,它是一种特定类型的人工神经网络。该网络利用共享权重,同时处理两个不同的输入向量来计算可比较的输出向量。受一次性学习的启发,这种方法已被证明在捕获语义相似性方面特别有效,尽管它可能需要更长的训练时间并且缺乏概率输出。

连体神经网络创建了一个“嵌入空间”,其中相关概念紧密定位,使模型能够更好地辨别语义关系。

alt
  • 双分支和共享权重:该架构由两个相同的分支组成,每个分支都包含一个具有共享权重的嵌入层。这些双分支同时处理两个输入,无论是相似的还是不相似的。
  • 相似性和转换:使用预先训练的 NLP 模型将输入转换为向量嵌入。然后该架构计算向量之间的相似度。相似度得分(范围在 -1 到 1 之间)量化两个向量之间的角距离,作为它们语义相似度的度量。
  • 对比损失和学习:模型的学习以“对比损失”为指导,即预期输出(训练数据的相似度得分)与计算出的相似度之间的差异。这种损失指导模型权重的调整,以最大限度地减少损失并提高学习嵌入的质量。

数据概览

为了使用此方法对预训练的 NLP 模型进行微调,训练数据应由文本字符串对组成,并附有它们之间的相似度分数。

训练数据遵循如下所示的格式:

alt

在本教程中,我们使用源自 ESCO 分类数据集的数据集,该数据集已转换为基于不同数据元素之间的关系生成相似性分数。

准备训练数据是微调过程中的关键步骤。假设您有权访问所需的数据以及将其转换为指定格式的方法。由于本文的重点是演示微调过程,因此我们将省略如何使用 ESCO 数据集生成数据的详细信息。

ESCO 数据集可供开发人员自由使用,作为各种应用程序的基础,这些应用程序提供自动完成、建议系统、职位搜索算法和职位匹配算法等服务。本教程中使用的数据集已被转换并作为示例提供,允许不受限制地用于任何目的。

让我们首先检查训练数据:

import pandas as pd

# Read the CSV file into a pandas DataFrame
data = pd.read_csv("./data/training_data.csv")

# Print head
data.head()
alt

起点:基线模型

首先,我们建立多语言通用句子编码器作为我们的基线模型。在进行微调过程之前,必须设置此基线。

在本教程中,我们将使用 STS 基准和相似性可视化示例作为指标来评估通过微调过程实现的更改和改进。

STS 基准数据集由英语句子对组成,每个句子对都与相似度得分相关联。在模型训练过程中,我们评估模型在此基准集上的性能。每次训练运行的持久分数是数据集中预测相似性分数和实际相似性分数之间的皮尔逊相关性。

这些分数确保当模型根据我们特定于上下文的训练数据进行微调时,它保持一定程度的通用性。

# Loads the Universal Sentence Encoder Multilingual module from TensorFlow Hub.
base_model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
base_model = tf.keras.Sequential([
    hub.KerasLayer(base_model_url,
                   input_shape=[],
                   dtype=tf.string,
                   trainable=False)
])

# Defines a list of test sentences. These sentences represent various job titles.
test_text = ['Data Scientist''Data Analyst''Data Engineer',
             'Nurse Practitioner''Registered Nurse''Medical Assistant',
             'Social Media Manager''Marketing Strategist''Product Marketing Manager']

# Creates embeddings for the sentences in the test_text list. 
# The np.array() function is used to convert the result into a numpy array.
# The .tolist() function is used to convert the numpy array into a list, which might be easier to work with.
vectors = np.array(base_model.predict(test_text)).tolist()

# Calls the plot_similarity function to create a similarity plot.
plot_similarity(test_text, vectors, 90"base model")

# Computes STS benchmark score for the base model
pearsonr = sts_benchmark(base_model)
print("STS Benachmark: " + str(pearsonr))
alt

微调模型

下一步涉及使用基线模型构建暹罗模型架构,并使用我们的特定领域数据对其进行微调。

# Load the pre-trained word embedding model
embedding_layer = hub.load(base_model_url)

# Create a Keras layer from the loaded embedding model
shared_embedding_layer = hub.KerasLayer(embedding_layer, trainable=True)

# Define the inputs to the model
left_input = keras.Input(shape=(), dtype=tf.string)
right_input = keras.Input(shape=(), dtype=tf.string)

# Pass the inputs through the shared embedding layer
embedding_left_output = shared_embedding_layer(left_input)
embedding_right_output = shared_embedding_layer(right_input)

# Compute the cosine similarity between the embedding vectors
cosine_similarity = tf.keras.layers.Dot(axes=-1, normalize=True)(
    [embedding_left_output, embedding_right_output]
)

# Convert the cosine similarity to angular distance
pi = tf.constant(math.pi, dtype=tf.float32)
clip_cosine_similarities = tf.clip_by_value(
    cosine_similarity, -0.999990.99999
)
acos_distance = 1.0 - (tf.acos(clip_cosine_similarities) / pi)

# Package the model
encoder = tf.keras.Model([left_input, right_input], acos_distance)

# Compile the model
encoder.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=0.00001,
        beta_1=0.9,
        beta_2=0.9999,
        epsilon=0.0000001,
        amsgrad=False,
        clipnorm=1.0,
        name="Adam",
    ),
    loss=tf.keras.losses.MeanSquaredError(
        reduction=keras.losses.Reduction.AUTO, name="mean_squared_error"
    ),
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(),
        tf.keras.metrics.MeanAbsolutePercentageError(),
    ],
)

# Print the model summary
encoder.summary()
alt
  • Fit model
# Define early stopping callback
early_stop = keras.callbacks.EarlyStopping(
    monitor="loss", patience=3, min_delta=0.001
)

# Define TensorBoard callback
logdir = os.path.join(".""logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Model Input
left_inputs, right_inputs, similarity = process_model_input(data)

# Train the encoder model
history = encoder.fit(
    [left_inputs, right_inputs],
    similarity,
    batch_size=8,
    epochs=20,
    validation_split=0.2,
    callbacks=[early_stop, tensorboard_callback],
)

# Define model input
inputs = keras.Input(shape=[], dtype=tf.string)

# Pass the input through the embedding layer
embedding = hub.KerasLayer(embedding_layer)(inputs)

# Create the tuned model
tuned_model = keras.Model(inputs=inputs, outputs=embedding)

评估结果

现在我们有了微调后的模型,让我们重新评估它并将结果与基本模型的结果进行比较。

# Creates embeddings for the sentences in the test_text list. 
# The np.array() function is used to convert the result into a numpy array.
# The .tolist() function is used to convert the numpy array into a list, which might be easier to work with.
vectors = np.array(tuned_model.predict(test_text)).tolist()

# Calls the plot_similarity function to create a similarity plot.
plot_similarity(test_text, vectors, 90"tuned model")

# Computes STS benchmark score for the tuned model
pearsonr = sts_benchmark(tuned_model)
print("STS Benachmark: " + str(pearsonr))
alt

基于在相对较小的数据集上对模型进行微调,STS 基准分数与基线模型的分数相当,表明调整后的模型仍然具有普适性。然而,相似性可视化显示相似标题之间的相似性得分增强,而不同标题的相似性得分降低。

总结

微调预训练的 NLP 模型以进行领域适应是一种强大的技术,可以提高其在特定上下文中的性能和精度。通过利用高质量的、特定领域的数据集和暹罗神经网络,我们可以增强模型捕获语义相似性的能力。

本教程以通用句子编码器 (USE) 模型为例,提供了微调过程的分步指南。我们探索了理论框架、数据准备、基线模型评估和实际微调过程。结果证明了微调在增强域内相似性得分方面的有效性。

通过遵循此方法并将其适应您的特定领域,您可以释放预训练 NLP 模型的全部潜力,并在自然语言处理任务中取得更好的结果

Reference

[1]

Source: https://towardsdatascience.com/domain-adaption-fine-tune-pre-trained-nlp-models-a06659ca6668

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738793.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue进阶----路由

目录 前端路由的概念与原理 什么是路由 SPA 与前端路由 前端路由 前端路由的工作方式 实现简易的前端路由 vue-router 的基本用法 vue-router vue-router 安装和配置的步骤 声明路由的匹配规则 vue-router 的常见用法 1、路由重定向 2、嵌套路由 3、动态路由匹配 …

Stable Diffusion高阶技能(1)-掌握这些,你也能绘出惊艳画作

开篇 初踏入AI作画的世界,你可能会对如何制造出惊艳的艺术作品而困惑。作为一个前沿技术的探索者,我在这一篇文章中,会和你一同揭秘如何用正确的提示词操控AI的“透视”,将最美的画面展现在你眼前。 技能一、提高图片质量的高阶手法 在数量众多的元素中,我们如何做出最…

Vue组件库Element-常见组件-Form表单

Form表单 Form 表单&#xff1a;由输入框、选择器、单选框、多选框等控件组成&#xff0c;用以收集、检验、提交数据 具体关键代码如下&#xff1a; <template><div><el-row><!-- button 按钮 --><el-button>默认按钮</el-button><e…

DDPM 知识点

Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song Score Matching 系列 (一) Non-normalized 模型估計 | 棒棒生

基于单片机智能饮水机加热系统的设计与实现

功能介绍 以51单片机作为主控系统&#xff1b;LCD1602液晶显示当前水温&#xff0c;定时提醒&#xff0c;水量变化DS18B20检测当前水体温度&#xff1b;水位传感器检测当前水位&#xff1b;继电器驱动加热片进行水温加热&#xff1b;定时提醒喝水&#xff0c;蜂鸣器报警&#x…

一键报警终端怎么样

一键报警终端是一种便携式设备&#xff0c;用于紧急情况下的一键求救。通过一键报警终端&#xff0c;用户可以发送紧急求助信号给预设的联系人或报警中心&#xff0c;以便及时获得救援。一键报警终端的主要功能和特点如下&#xff1a;1. 便携式设计&#xff1a;一键报警终端通常…

【Android studio】学号及姓名的输入保存页面

一、设计需求 设计一个页面有两个编辑框&#xff0c;分别输入学号和姓名。有两个按钮&#xff0c;一个是修改按钮&#xff0c;当按下修改按钮&#xff0c;编辑框可以进行编辑&#xff1b;一个是保存按钮&#xff0c;当按下保存按钮&#xff0c;使编辑框显示当前的内容并且编辑…

在线性能分析工具Arthas基于Springboot安装配置使用和Arthas Tunnel安装配置使用

概要 Arthas 是一款线上监控诊断产品&#xff0c;通过全局视角实时查看应用 load、内存、gc、线程的状态信息&#xff0c;并能在不修改应用代码的情况下&#xff0c;对业务问题进行诊断&#xff0c;包括查看方法调用的出入参、异常&#xff0c;监测方法执行耗时&#xff0c;类加…

Python Web开发入门教程(非常详细)

Python是一种非常流行的编程语言&#xff0c;被广泛应用于数据科学、Web开发、人工智能、机器学习等领域。Python语言易学易用&#xff0c;是许多初学者进入编程世界的入门选择。然而&#xff0c;学习Python并不是一件简单的事情&#xff0c;尤其是对于初学者而言。在本文中&am…

深度学习——优化器Optimizer

代码以及详细注释&#xff1a; import torch import torch.utils.data as Data import torch.nn.functional as F import matplotlib.pyplot as plt# torch.manual_seed(1) # reproducible """超参数 """ # 学习率 LR 0.01 # 批大小 BATCH_…

API测试之Postman使用完全指南

前言 Postman是一个可扩展的API开发和测试协同平台工具&#xff0c;可以快速集成到CI/CD管道中。旨在简化测试和开发中的API工作流。 Postman 工具有 Chrome 扩展和独立客户端&#xff0c;推荐安装独立客户端。 Postman 有个 workspace 的概念&#xff0c;workspace 分 pers…

16、Python读取气象数据的正确姿势

文章目录 一、气象数据格式&#xff08;常用&#xff09;二、单个文件读取1. 常规格式2. CSV格式3. NetCDF格式4. GRIB格式 一、气象数据格式&#xff08;常用&#xff09; 常规格式&#xff08;Plain Text&#xff09;&#xff1a;气象数据可以使用纯文本格式进行存储&#xf…

漏洞复现 || 某友文件上传

免责声明 技术文章仅供参考&#xff0c;任何个人和组织使用网络应当遵守宪法法律&#xff0c;遵守公共秩序&#xff0c;尊重社会公德&#xff0c;不得利用网络从事危害国家安全、荣誉和利益&#xff0c;未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此…

HarmonyOS学习路之开发篇—流转(跨端迁移 一)

跨端迁移开发 场景介绍 开发者在应用FA中通过调用流转任务管理服务、分布式任务调度的接口&#xff0c;实现跨端迁移。 1. 设备A上的应用FA向流转任务管理服务注册一个流转回调&#xff1a; Alt1-系统推荐流转&#xff1a;系统感知周边有可用设备后&#xff0c;主动为用户提…

网络版本的计算器

文章目录 1. TCP协议通讯流程2. 应用层2.1 再谈 "协议" 3. 网络版计算器3.1 服务器提供服务3.1.1 提取有效载荷3.1.2 服务器的反序列化3.1.3 计算服务3.1.4 服务器的序列化3.1.5 添加序列化后的长度 3.2 客户端发送请求3.2.1 填充客户端请求3.2.2 客户端进行序列化3.…

为什么我挖不倒sql注入啊!

为什么我挖不倒sql注入啊&#xff01; 背景一句话讲原理小白速挖注入 背景 不知道是不是初学安全的小伙伴都和我一样&#xff0c;刚开始学的时候&#xff0c;诶挺简单啊&#xff01;我咋这么聪明一学就会&#xff0c;靶场轻轻松松过关&#xff0c;到了实战根本挖不出来&#x…

【C++】float / double 与 0 值比较

【C】float / double 与 0 值比较 文章目录 【C】float / double 与 0 值比较1. 概述不同1.1 - float 与 double 实际存储1.2 - C 语言与 C 中不同 2. 比较方法2.1 - C 风格比较2.2 - 使用 limits 函数 3. 参考链接 References 1. 概述不同 当然使用普通的比较没有问题&#xf…

项目管理中,WBS与项目计划有什么区别?

为了成功完成项目并控制成本&#xff0c;我们有必要采取科学的项目管理方法。实现这一目标的工具是项目计划和工作分解结构&#xff08;WBS&#xff09;。 WBS 与项目计划是项目管理中必不可少的工具&#xff0c;但两者有不同的用途。WBS精确描述了项目工作和可交付成果&#…

前端vue入门(纯代码)26_多级路由

如果耐不住寂寞&#xff0c;你就看不到繁华。 【24.Vue Router--多级路由】 [可以去官网看看Vue Router文档](嵌套路由 | Vue Router (vuejs.org)) 在实际开发中&#xff0c;我们不单单会使用到一层路由&#xff0c;有时候会涉及到两层或两层以上的路由&#xff0c;多级路由…

带清除按钮的输入框

// index.html <!DOCTYPE html> <html> <head><meta charset"utf-8"><meta name"viewport" content"widthdevice-width, initial-scale1, maximum-scale1"><title>测试 - layui</title><link rel&…