从预训练的BERT中提取Embedding

news2025/1/11 12:44:00

文章目录

    • 背景
    • 前置准备
    • 思路
    • 利用Transformer 库实现

背景

假设要执行一项情感分析任务,样本数据如下
在这里插入图片描述
可以看到几个句子及其对应的标签,其中1表示正面情绪,0表示负面情绪。我们可以利用给定的数据集训练一个分类器,对句子所表达的情感进行分类。

前置准备

# 安装modelscope包
pip install modelscope
# 下载 bert-base-uncased 模型
modelscope download --model AI-ModelScope/bert-base-uncased

思路

  1. 分词:以第一句为例,我们使用WordPiece对句子进行分词,并得到标记(单词),如下所示。

    tokens = [I, love, Paris]

  2. 添加标记:在开头添加[CLS]标记,在结尾添加[SEP]标记,如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP] ]

  3. 填充:为了保持所有标记的长度一致,我们将数据集中的所有句子的标记长度设为7。句子I loveParis的标记长度是5,为了使其长度为7,需要添加两个标记来填充,即[PAD]。因此,新标记如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP], [PAD], [PAD] ]

    添加两个[PAD]标记后,标记的长度达到所要求的7。

  4. 注意力掩码:下一步,要让模型理解[PAD]标记只是为了匹配标记的长度,而不是实际标记的一部分。为了做到这一点,我们需要引入一个注意力掩码。我们将所有位置的注意力掩码值设置为1,将[PAD]标记的位置设置为0,如下所示。

    attention_mask = [ 1, 1, 1, 1, 1, 0, 0]

  5. 映射到token id:然后,将所有的标记映射到一个唯一的标记ID。假设映射的标记ID如下所示。

    token_ids = [101, 1045, 2293, 3000, 102, 0, 0]

    ID 101表示标记[CLS],1045表示标记I,2293表示标记love,以此类推。

    现在,我们把token_ids和attention_mask一起输入预训练的BERT模型,并获得每个标记的特征向量(嵌入)。通过代码,我们可以进一步理解以上步骤。下图显示的标记+单词而不是id,但实际传入的是id
    在这里插入图片描述

以上,可以得到每个单词的Embedding,整个句子的Embedding是 R [ C L S ] R_{[CLS]} R[CLS]

利用Transformer 库实现

from transformers import BertModel, BertTokenizer
import torch
# 下载并加载预训练的模型
model = BertModel.from_pretrained('bert-base-uncased')
# 下载并加载用于预训练模型的词元分析器。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 下面,让我们看看如何对输入进行预处理。
# 0. 对输入进行预处理假设输入句如下所示。
sentence = 'I love Paris'
# 1. 分词
tokens = tokenizer.tokenize(sentence)
print(tokens) # ['i', 'love', 'paris']

# 2. 添加标记
tokens = ['[CLS]'] + tokens + ['[SEP]']
print(tokens) # ['[CLS]', 'i', 'love', 'paris', '[SEP]']

# 3. 填充
tokens = tokens + ['[PAD]'] + ['[PAD]']
print(tokens) #['[CLS]', 'i', 'love', 'paris', '[SEP]', '[PAD]', '[PAD]' ]

# 4. 注意力掩码
attention_mask = [1 if i!= '[PAD]' else 0 for i in tokens]
print(attention_mask) # [1, 1, 1, 1, 1, 0, 0]

# 5. 将所有标记转换为它们的标记ID
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids) # [101, 1045, 2293, 3000, 102, 0, 0]

# 6. 将token_ids和attention_mask转换为张量
token_ids = torch.tensor(token_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)


# 7. 将token_ids和atten-tion_mask送入模型,并得到嵌入向量。
# 需要注意,model返回的输出是一个有两个值的元组。第1个值hidden_rep表示隐藏状态的特征,它包括从顶层编码器(编码器12)获得的所有标记的特征。第2个值cls_head表示[CLS]标记的特征。
hidden_rep, cls_head = model(token_ids, attention_mask = attention_mask)
print(hidden_rep.shape) # torch.Size([1, 7, 768])

'''
数组[1, 7, 768]表示[batch_size, se-quence_length, hidden_size],也就是说,批量大小设为1,序列长度等于标记长度,即7。因为有7个标记,所以序列长度为7。隐藏层的大小等于特征向量(嵌入向量)的大小,在BERT-base模型中,其为768。
* hidden_rep[0][0]给出了第1个标记[CLS]的特征。   
* hidden_rep[0][1]给出了第2个标记I的特征。   
* hidden_rep[0][2]给出了第3个标记love的特征
'''

print(cls_head.shape) # torch.Size([1, 768])

'''
大小[1, 768]表示[batch_size, hid-den_size]。我们知道cls_head持有句子的总特征,所以,可以用cls_head作为句子I love Paris的整句特征。
'''


以上获得的是从顶层编码器(编码器12)获得的特征,如果要获取所有编码器的特征,需要修改以下两个地方。

# 下载并加载预训练的模型时,设置output_hidden_states = True
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 调用模型时,产生的是三元组
last_hidden_state, pooler_output, hidden_states = model(token_ids, attention_mask = attention_mask)

'''
* last_hidden_state,它仅有从最后的编码器(编码器12)中获得的所有标记的特征
* pooler_output表示来自最后的编码器的[CLS]标记的特征,它被一个线性激活函数和tanh激活函数进一步处理。
* hidden_states包含从所有编码器层获得的所有标记的特征。它是一个包含13个值的元组,含有所有编码器层(隐藏层)的特征,即从输入嵌入层h到最后的编码器层h。
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2274905.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从CentOS到龙蜥:企业级Linux迁移实践记录(系统安装)

引言: 随着CentOS项目宣布停止维护CentOS 8并转向CentOS Stream,许多企业和组织面临着寻找可靠替代方案的挑战。在这个背景下,龙蜥操作系统(OpenAnolis)作为一个稳定、高性能且完全兼容的企业级Linux发行版&#xff0…

车联网安全--TLS握手过程详解

目录 1. TLS协议概述 2. 为什么要握手 2.1 Hello 2.2 协商 2.3 同意 3.总共握了几次手? 1. TLS协议概述 车内各ECU间基于CAN的安全通讯--SecOC,想必现目前多数通信工程师们都已经搞的差不多了(不要再问FvM了);…

iOS实际开发中使用Alamofire实现多文件上传(以个人相册为例)

引言 在移动应用中,图片上传是一个常见的功能,尤其是在个人中心或社交平台场景中,用户经常需要上传图片到服务器,用以展示个人风采或记录美好瞬间。然而,实现多图片上传的过程中,如何设计高效的上传逻辑并…

基于phpstudy快速搭建本地php环境(Windows)

好好生活,别睡太晚,别爱太满,别想太多。 2025.1.07 声明 仅作为个人学习使用,仅供参考 对于CTF-Web手而言,本地PHP环境必不可少,但对于新手来说从下载PHP安装包到配置PHP环境是个非常繁琐的事情&#xff0…

ffmpeg 编译遇到的坑

makeinfo: error parsing ./doc/t2h.pm: Undefined subroutine &Texinfo::Config::set_from_init_file called at ./doc/t2h.pm line 24. 编译选项添加: --disable-htmlpages

Git:merge合并、冲突解决、强行回退的终极解决方案

首先还是得避免冲突的发生,无法避免时再去解决冲突,避免冲突方法: 时常做pull、fatch操作,不要让自己本地仓库落后太多版本;在分支操作,如切换分支、合并分支、拉取分支前,及时清理Change&#…

国内外网络安全政策动态(2024年12月)

▶︎ 1.2项网络安全国家标准获批发布 2024年12月6日,根据2024年11月28日国家市场监督管理总局、国家标准化管理委员会发布的中华人民共和国国家标准公告(2024年第29号),全国网络安全标准化技术委员会归口的2项网络安全国家标准正…

新兴的开源 AI Agent 智能体全景技术栈

新兴的开源 AI Agent 智能体全景技术栈 LLMs:开源大模型嵌入模型:开源嵌入模型模型的访问和部署:Ollama数据存储和检索:PostgreSQL, pgvector 和 pgai后端:FastAPI前端:NextJS缺失的一环:评估和…

通过一个含多个包且引用外部jar包的项目实例感受Maven的便利性

目录 1 引言2 手工构建3 基于Maven的构建4 总结 1 引言 最近在阅读一本Java Web的书籍1时,手工实现书上的一个含多个Packages的例子,手工进行编译、运行,最终实现了效果。但感觉到整个构建过程非常繁琐,不仅要手写各个源文件的编…

信息科技伦理与道德3:智能决策

1 概述 1.1 发展历史 1950s-1980s:人工智能的诞生与早期发展热潮 1950年:图灵发表了一篇划时代的论文,并提出了著名的“图灵测试”;1956年:达特茅斯会议首次提出“人工智能”概念;1956年-20世纪70年代&a…

Sql 创建用户

Sql server 创建用户 Sql server 创建用户SQL MI 创建用户修改其他用户密码 Sql server 创建用户 在对应的数据库执行,该用户得到该库的所有权限 test.database.chinacloudapi.cn DB–01 DB–02 创建服务器登录用户 CREATE LOGIN test WITH PASSWORD zDgXI7rsafkak…

【再谈设计模式】观察者模式~对象间依赖关系的信使

一、引言 在软件工程、软件开发的世界里,设计模式如同建筑蓝图中的经典结构,帮助开发者构建更加灵活、可维护和可扩展的软件系统。观察者模式就是其中一种极为重要的行为型设计模式,它在处理对象间的一对多关系时展现出独特的魅力。 二、定义…

如何设计一个注册中心?以Zookeeper为例

这是小卷对分布式系统架构学习的第8篇文章,在写第2篇文章已经讲过服务发现了,现在就从组件工作原理入手,讲讲注册中心 以下是面试题: 某团面试官:你来说说怎么设计一个注册中心? 我:注册中心嘛&…

【Unity3D】导出Android项目以及Java混淆

Android Studio 下载文件归档 | Android Developers Android--混淆配置(比较详细的混淆规则)_android 混淆规则-CSDN博客 Unity版本:2019.4.0f1 Gradle版本:5.6.4(或5.1.1) Gradle Plugin版本&#xff…

2024 China Collegiate Programming Contest (CCPC) Zhengzhou Onsite 基础题题解

今天先发布基础题的题解,明天再发布铜牌题和银牌题的题解 L. Z-order Curve 思路:这题目说了,上面那一行,只有在偶数位才有可能存在1,那么一定存在这样的数,0 ,1,100, 10000,那么反之,我们的数…

【FlutterDart】tolyui_feedback组件例子效果(23 /100)

上效果图 有12种位置展示效果;很能满足大部分需要 代码如下: import package:flutter/material.dart; import package:tolyui_feedback/tolyui_feedback.dart;class TolyTooltipDemo extends StatelessWidget {const TolyTooltipDemo({super.key});ove…

服务器攻击方式有哪几种?

随着互联网的快速发展,网络攻击事件频发,已泛滥成互联网行业的重病,受到了各个行业的关注与重视,因为它对网络安全乃至国家安全都形成了严重的威胁。面对复杂多样的网络攻击,想要有效防御就必须了解网络攻击的相关内容…

Mermaid 使用教程之流程图 - 从入门到精通

本文由 Mermaid中文文档 整理而来,并且它同时提供了一个Mermaid在线编辑器。 Mermaid 流程图 - 基本语法​ 流程图由节点(几何形状)和边(箭头或线)组成。Mermaid代码定义了如何创建节点和边,并适应不同的…

Flink系统知识讲解之:如何识别反压的源头

Flink系统知识之:如何识别反压的源头 什么是反压 Ufuk Celebi 在一篇古老但仍然准确的文章中对此做了很好的解释。如果您不熟悉这个概念,强烈推荐您阅读这篇文章。如果想更深入、更低层次地了解该主题以及 Flink 网络协议栈的工作原理,这里有…

网络-ping包分析

-a:使 ping 在收到响应时发出声音(适用于某些操作系统)。-b:允许向广播地址发送 ping。-c count:指定发送的 ping 请求的数量。例如,ping -c 5 google.com 只发送 5 个请求。-i interval:指定两…