last_hidden_state vs pooler_output的区别

news2024/10/5 15:31:33

一、问题来源:

from transformers import AutoTokenizer, AutoModel
import torch
# Load model from HuggingFace Hub
MODEL_NAME_PATH = 'xxxx/model/bge-large-zh'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_PATH)
model = AutoModel.from_pretrained(MODEL_NAME_PATH)

模型结构如下:

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=1024, out_features=4096, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )
)

Q1、cls的值和pooler的值是一样的吗?
Q2、最后的pooler层和hidden层是什么关系?

二、实验证明:

Q1、cls的值和pooler的值是一样的吗?

# Sentences we want sentence embeddings for
sentences = ["开心", "快乐", "难过", "天气", "今天会有大大的台风吗?"]
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length=200)
# for retrieval task, add an instruction to query
# encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
    # Perform pooling. In this case, cls pooling.
    sentence_embeddings = model_output[0][:, 0]
# normalize embeddings
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

print(‘cls:’, model_output[0][:, 0, :])

cls: tensor([[ 0.3269, -0.6412, -0.2382,  ...,  0.0255, -0.1801, -0.3025],
        [ 0.1351, -0.5155, -0.1700,  ...,  0.1093, -0.3750, -0.1323],
        [ 0.2752, -0.1703, -0.2730,  ...,  0.0376, -0.0339, -0.3541],
        [ 0.1346, -0.0378, -0.5070,  ...,  0.0078,  0.0472, -0.1815],
        [-0.4051,  0.1123, -0.3873,  ...,  0.3585,  0.4913,  0.3192]])

print(‘pooler:’, model_output[1])

pooler: tensor([[ 0.3888, -0.2329, -0.1749,  ...,  0.1678,  0.3938, -0.3191],
        [ 0.3949, -0.2882, -0.0945,  ...,  0.1802,  0.2705, -0.1891],
        [ 0.4765, -0.1235, -0.2330,  ...,  0.3005,  0.3487, -0.1290],
        [ 0.3851, -0.1853, -0.3189,  ...,  0.2757,  0.3601, -0.3220],
        [ 0.3008, -0.3742, -0.4550,  ...,  0.4318,  0.2130, -0.1575]])

cls的值和pooler的值不一样

Q2、最后的pooler层和hidden层是什么关系?

理论层面:

transformers.models.bert.modeling_bert.BertModel.forward方法中这么一行代码:

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

pooler的定义:

self.pooler = BertPooler(config) if add_pooling_layer else None

BertPooler的定义:

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

从上面的源码可以看出,pooler_output 就是[CLS]embedding又经历了一次全连接层的输出

数据层面:
model.pooler(model_output[0])
tensor([[ 0.3888, -0.2329, -0.1749,  ...,  0.1678,  0.3938, -0.3191],
        [ 0.3949, -0.2882, -0.0945,  ...,  0.1802,  0.2705, -0.1891],
        [ 0.4765, -0.1235, -0.2330,  ...,  0.3005,  0.3487, -0.1290],
        [ 0.3851, -0.1853, -0.3189,  ...,  0.2757,  0.3601, -0.3220],
        [ 0.3008, -0.3742, -0.4550,  ...,  0.4318,  0.2130, -0.1575]],
       grad_fn=<TanhBackward0>)

在这里插入图片描述
pooler_output 就是[CLS]embedding又经历了一次全连接层的输出

三、结论:

pooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

四、Bert的Pooler_output的由来

我们知道,BERT的训练包含两个任务:MLM和NSP任务(Next Sentence Prediction)。 对这两个任务不熟悉的朋友可以参考:BERT源码实现与解读(Pytorch) 和 【论文阅读】BERT 两篇文章。

其中MLM就是挖空,然后让bert预测这个空是什么。做该任务是使用token embedding进行预测。

而Next Sentence Prediction就是预测bert接受的两句话是否为一对。例如:窗前明月光,疑是地上霜 为 True,窗前明月光,李白打开窗为False。

所以,NSP任务需要句子的语义信息来预测,但是我们看下源码是怎么做的。

class BertForNextSentencePrediction(BertPreTrainedModel):
	
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)	# 这个就是一个 nn.Linear(config.hidden_size, 2)
		...
	
	def forward(...):
		...
		outputs = self.bert(...)
		pooled_output = outputs[1] # 取pooler_output
		seq_relationship_scores = self.cls(pooled_output)	# 使用pooler_ouput送给后续的全连接层进行预测
		...

从上面的源码可以看出,在NSP任务训练时,并不是直接使用[CLS]token的embedding作为句子特征传给后续分类头的,而是使用的是pooler_output。个人原因可能是因为直接使用[CLS]的embedding效果不够好。
但在MLM任务时,是直接使用的是last_hidden_state,有兴趣可以看一下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/857831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MongoDB 备份与恢复

1.1 MongoDB的常用命令 mongoexport / mongoimport mongodump / mongorestore 有以上两组命令在备份与恢复中进行使用。 1.1.1 导出工具mongoexport Mongodb中的mongoexport工具可以把一个collection导出成JSON格式或CSV格式的文件。可以通过参数指定导出的数据项&#xff0c…

公检系统升级:校对软件提升司法办案水平

公检系统升级中引入校对软件可以显著提升司法办案水平&#xff0c;为司法工作提供更高效和准确的支持。以下是校对软件在提升司法办案水平方面的作用&#xff1a; 1.提高文书准确性&#xff1a;校对软件可以自动检测和修正法律文书中的语法、拼写和标点等错误。它可以捕捉到人眼…

nginx自定义负载均衡及根据cpu运行自定义负载均衡

1.nginx如何自定义负载均衡 在Nginx中&#xff0c;可以通过配置文件自定义负载均衡策略。具体步骤如下&#xff1a; 首先&#xff0c;在Nginx配置文件中定义一个upstream模块&#xff0c;并设置负载均衡策略和后端服务器列表&#xff0c;例如&#xff1a; upstream myapp {ser…

数字孪生轨道交通,地铁视频孪生三维可视化管控平台

为促进数字孪生城市领域高质量发展&#xff0c;延续《数字孪生城市应用案例汇编&#xff08;2022年&#xff09;》已有研究成果&#xff0c;宣传推广一批创新性强、具有示范效应的优秀案例&#xff0c;为各部委及地方政府推动数字孪生城市建设提供有力支撑&#xff0c;中国信息…

使用 Visual Studio Code 调试 CMake 脚本

之前被引入到 Visual Studio 中的 CMake 调试器&#xff0c;现已在 Visual Studio Code 中可用。 也就是说&#xff0c;现在你可以通过在 VS Code 中安装 CMake 工具扩展&#xff0c;来调试你的 CMakeLists.txt 脚本了。是不是很棒? 背景知识 Visual C 开发团队和 CMake 的维…

【数据库】P0 创建数据库环境 MySQL + DataGrip

创建数据库环境 下载安装 MySQL下载安装 DataGrip 下载安装 MySQL Windows版本_MySQL 下载地址&#xff1a; https://dev.mysql.com/downloads/mysql/ 下载后依照默认顺序安装即可&#xff0c;本博文将讲述简约安装步骤&#xff1b; 如需详细安装步骤可见&#xff1a;https:/…

SharePoint 管理

SharePoint平台使您能够以在线方式和本地方式轻松地管理和协调业务数据。因为其灵活性和易使用性&#xff0c;公司可以快速采用SharePoint来管理其业务数据。但是&#xff0c;SharePoint管理员在管理及审核SharePoint服务器时&#xff0c;内容的高级别协作和动态性质会导致问题…

Is a directory: ‘outs//.ipynb_checkpoints‘

提示out/文件夹的.ipynp_chechpoints是一个文件夹&#xff0c;但是打开文件夹却没有看到&#xff0c;可以得知他是一个隐藏文件夹&#xff0c;进入outs/文件夹&#xff0c;使用 ls -a可以看到所有文件 果然出现这个文件夹&#xff0c;但是我们这个outs/文件夹存放的是图片&am…

Django实现音乐网站 ⑼

使用Python Django框架制作一个音乐网站&#xff0c; 本篇主要是后台对专辑、首页轮播图原有功能的基础上进行部分功能实现和显示优化。 目录 专辑功能优化 新增编辑 专辑语种改为下拉选项 添加单曲优化显示 新增单曲多选 更新歌手专辑数、专辑单曲数 获取歌手专辑数 保…

实践指南-前端性能提升 270% | 京东云技术团队

一、背景 当我们疲于开发一个接一个的需求时&#xff0c;很容易忘记去关注网站的性能&#xff0c;到了某一个节点&#xff0c;猛地发现&#xff0c;随着越来越多代码的堆积&#xff0c;网站变得越来越慢。 本文就是从这样的一个背景出发&#xff0c;着手优化网站的前端性能&a…

java输出GB2312、GBK、GB18030、UTF-8所有的字符

相关文章 【转】彻底搞明白 GB2312、GBK 、GB18030和UTF-8 1.ASICII、GB2312、GBK、GB18030 以及 UTF8 的关系 2.编写代码 引入依赖&#xff1a;hutool工具类 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId>&…

企业内部FAQ常见问题展示分享的价值

企业内部FAQ&#xff08;常见问题&#xff09;展示分享是一种将常见问题和解决方案以问答形式呈现给员工的方式。这种方式可以帮助企业提高工作效率、提供一致的解决方案、提升员工满意度和减少重复工作。 企业内部FAQ常见问题展示分享的价值&#xff1a; 1. 提高工作效率 企…

恒盛策略:内盘是买入还是卖出?

内盘&#xff0c;又称成交明细&#xff0c;是指交易所实时发布的每一笔成交价格、交易时刻、成交量等具体数据。关于股民来说&#xff0c;每天重视内盘数据已成为必修课。可是&#xff0c;当看到内盘数据时&#xff0c;很多人都会困惑&#xff0c;到底内盘是买入还是卖出呢&…

PCI 简易通讯控制器有黄色感叹号

一、问题描述 设备管理器中&#xff0c;其他设备中显示 “PCI 简易通讯控制器”驱动未安装&#xff0c;显示黄色感叹号&#xff1a; 二、原因分析 右键该驱动&#xff0c;查看属性ID&#xff0c;显示为&#xff1a; PCI \ VEN_8086&#xff06;DEV_1C3A&#xff06;SUBSYS…

学习51单片机怎么开始?

学习的过程不总是先打好基础&#xff0c;然后再盖上层建筑&#xff0c;尤其是实践性的、工程性很强的东西。如果你一定要先全面打好基础&#xff0c;再学习单片机&#xff0c;我觉得你一定学不好&#xff0c;因为你的基础永远打不好&#xff0c;因为基础太庞大了&#xff0c;基…

3.多线程(进阶)

文章目录 1.常见的锁策略1.1乐观锁 vs 悲观锁1.2互斥锁 vs 读写锁1.3重量级锁 vs 轻量级锁1.4自旋锁vs 挂起等待锁&#xff08;Spin Lock&#xff09;1.5公平锁 vs 非公平锁1.6可重入锁 vs 不可重入锁1.7相关面试题 2.CAS2.1什么是 CAS2.2CAS 是怎么实现的2.3CAS 有哪些应用2.3…

这50幅画让你看清世界真相,犀利深刻,值得一读!

让你看清这个世界的真相&#xff01; 01 自弃者扶不起 自强者打不倒 02 人人都活在假象里 03 宁可有病再治&#xff0c;也不愿意未雨绸缪 04 一个人成熟的表现 是具备了太极思维 05 最大的监狱是人的思维监狱 06 认知太浅&#xff0c;放弃学习 这就是焦虑和绝望的根本原因 0…

智能型静电消除器的优势有哪些?

智能型静电消除器是一种使用先进技术和智能控制系统来消除静电问题的设备。静电是由于电荷不平衡而引起的现象&#xff0c;常见于工业生产、医疗设备、办公环境等场合。静电的存在可能导致电子设备故障、火灾、等问题。 智能型静电消除器与传统静电消除器相比&#xff0c;具有以…

Java gradle创建聚合项目(父子级项目)

放个简单的demo给大家下载吧。 参考大神gitee地址&#xff1a;demo-gradle-multi: Gradle构建SpringBoot多模块项目示例 (gitee.com)

近期学习练习

练习&#xff1a; 或&#xff1a; //sqrt是数学库函数&#xff0c;开平方 //头文件为math.h 或