Kaggle Feedback Prize 3比赛总结:如何高效使用hidden states输出(2)

news2024/9/20 18:44:28

比赛链接:https://www.kaggle.com/competitions/feedback-prize-english-language-learning

在Kaggle Feedback Prize 3比赛总结:如何高效使用hidden states输出(2)中介绍了针对last layer hidden state的各种pooling的方法。

在利用Transformer类的预训练模型进行下游任务的微调过程中,最后一层的输出不一定是输入文本的最佳表示。如下图所示,Transformer encoder的每一层的hidden state都是对输入文本的上下文表征。对于预训练的语言模型,包括Transformer,输入文本的最有效的上下文表征往往发生在中间层,而顶层则专门用于专一任务的建模。因此,仅仅利用最后一层输出可能会限制预训练的表示的力量。本文将介绍几种常见的layer pooling的办法。
在这里插入图片描述

Concatenate Pooling

Concatenate pooling是将不同层的输出堆叠到一起。在这里我将介绍concatenate最后4层的输出。如下图的最后一行表示。例,如果模型的hidden state的大小为768,那么每一层的输出大小为[batch size, sequence length, 768]。如果将最后四层的输出concatenate到一起,那么输出为[batch size, sequence length, 768 * 4].

在这里插入图片描述

代码实现如下:

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

concatenate_pooling = torch.cat(
    (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]),-1
)
concatenate_pooling = concatenate_pooling[:, 0]

logits = nn.Linear(config.hidden_size*4, 1)(concatenate_pooling) # regression head

print(f'Hidden States Output Shape: {all_hidden_states.detach().numpy().shape}')
print(f'Concatenate Pooling Output Shape: {concatenate_pooling.detach().numpy().shape}')
print(f'Logits Shape: {logits.detach().numpy().shape}')

输出为:
Hidden States Output Shape: (13, 16, 256, 768)
Concatenate Pooling Output Shape: (16, 3072)
Logits Shape: (16, 1)

注:这里的例子是使用最后四层的 [CLS] token。

Weighted Layer Pooling

这个方法是较为常用的方法。这个方法是利用最后几层的hidden state,同时定义一个可学习的参数来决定每一层的hidden state占得比重。代码实现如下:

class WeightedLayerPooling(nn.Module):  
    def __init__(self, last_layers = CFG.weighted_layer_pooling_num, layer_weights = None):  
        super(WeightedLayerPooling, self).__init__()  
        self.last_layers = last_layers  
        self.layer_weights = layer_weights if layer_weights is not None \  
        else nn.Parameter(torch.tensor([ 1/ last_layers] * last_layers, dtype = torch.float))  
  
    def forward(self, features):  
        all_layer_embedding = torch.stack(features)  
        all_layer_embedding = all_layer_embedding[-self.last_layers:, :, :, :] 
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())
        print(f'weight_factor: {self.layer_weights}')  
        weighted_average = (weight_factor*all_layer_embedding).sum(dim=0) / self.layer_weights.sum()  
        return weighted_average

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])

hidden_states = outputs.hidden_states
Last_layer_num = 4
WeightedLayerPooling = WeightedLayerPooling(last_layers=Last_layer_num, layer_weights=None)
weighted_layers_pooling_features = WeightedLayerPooling(hidden_states)

print(f'Hidden States Output Shape: {hidden_states.detach().numpy().shape}')
print(f'Weighted Pooling Output Shape: {weighted_layers_pooling_features.detach().numpy().shape}')

输出为:
Hidden States Output Shape: (13, 16, 256, 768)
Weighted Pooling Output Shape: (16, 256, 768)

Attention Pooling

Attention机制能够学习每一层的 [CLS] token的贡献程度。因此利用attention机制能够更好的利用不同层的[CLS]特征。这里介绍一种 dot-product attention 机制。这个过程可以被表示为:

o = W h T S o f t m a x ( q ∗ h C L S T ) h C L S o=W^T_hSoftmax(q*h^T_{CLS})h_{CLS} o=WhTSoftmax(qhCLST)hCLS
其中 W h T W^T_h WhT q q q 是可学习参数。

代码实现如下:

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

class AttentionPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_fc):
        super(AttentionPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_fc = hiddendim_fc
        self.dropout = nn.Dropout(0.1)

        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.hidden_size))
        self.q = nn.Parameter(torch.from_numpy(q_t)).float()
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.hidden_size, self.hiddendim_fc))
        self.w_h = nn.Parameter(torch.from_numpy(w_ht)).float()

    def forward(self, all_hidden_states):
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0].squeeze()
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=-1)
        hidden_states = hidden_states.view(-1, self.num_hidden_layers, self.hidden_size)
        out = self.attention(hidden_states)
        out = self.dropout(out)
        return out

    def attention(self, h):
        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
        v = F.softmax(v, -1)
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)
        return v

hiddendim_fc = 128
pooler = AttentionPooling(config.num_hidden_layers, config.hidden_size, hiddendim_fc)
attention_pooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(hiddendim_fc, 1)(attention_pooling_embeddings) # regression head

print(f'Hidden States Output Shape: {all_hidden_states.detach().numpy().shape}')
print(f'Attention Pooling Output Shape: {attention_pooling_embeddings.detach().numpy().shape}')

输出为:
Hidden States Output Shape: (13, 16, 256, 768)
Attention Pooling Output Shape: (16, 128)

总结

不同的layer pooling method 对于不同的任务有不一样的性能表现,因此对于每一个任务还是需要进行探索的。但在这一次比赛中,最后4层的WeightLayerPooling,是相对而言改进模型表现最有效的方法。同时,值得注意的是,也可以将不同的pooling方法得到的特征最后再做一次concatenate 这样会让模型学到各种不同pooling方法的feature,也有可能在一定程度上改善模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/60541.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue学习:Hello小案例

使用Vue的目的:构建用户界面(需要使用容器 摆放这个界面的内容) favicon.ico:1 GET http://127.0.0.1:5500/favicon.ico 404 (Not Found) 没有页签图标 在者服务器中 http://127.0.0.1:5500没有/favicon.ico 强制刷新网页:s…

3大经典分布式存储算法

文章目录1、背景2、算法2.1 分布存储之哈希取余算法2.2 分布式存储之一致性哈希算法2.3 分布式存储之哈希槽算法1、背景 一个经典的面试题目:1~2亿条数据需要缓存,请问如何设计这个方案? 回答:单台单机肯定不可能&…

Musical Christmas Lights——一个圣诞树灯光✨随音乐节奏改变的前端开源项目

文章目录前言视频介绍项目截图项目地址项目源码以上就是本篇文章的全部内容,将你编写好的项目分享给你的朋友们或者那个TA吧!制作不易,求个三连!❤️ 💬 ⭐️前言 今天博主在刷短视频时😐,朋友推…

VMware 虚拟机系统 与 win10 共享文件夹问题的解决

环境描述 本地:Win10 64位 VMware Workstation Pro 16 虚拟机,安装的 ubuntu 20.04 文件夹共享 win10 与 虚拟机的 ubuntu 共享文件夹,之前低版本的 VMware ,安装 VMware Tools,并且 win10 端设置好工作目录后&…

秒级使网站变灰,不改代码不上线,如何做到?

注意:文本不是讲如何将网站置灰的那个技术点,那个技术点之前汶川地震的时候说过。 本文不讲如何实现技术,而是讲如何在第一时间知道消息后,更快速的实现这个置灰需求的上线。 实现需求不是乐趣,指挥别人去实现需求才…

广域网技术——SR-MPLS隧道保护技术

目录 TI-LFA FRR保护技术 LFA FRR R-LFA FRR TI-LFA FRR Anycast FRR技术 Host-Standby技术 VPN FRR技术 SR-MPLS防微环技术 场景一 SR本地正切防微环 场景二 SR本地回切防微环 场景三 SR远端正切防微环 场景四 SR远端回切防微环 TI-LFA和防微环的对比 TI-LFA FRR…

41. set()函数:将可迭代对象转换为可变集合

41. set()函数:将可迭代对象转换为可变集合 文章目录41. set()函数:将可迭代对象转换为可变集合1. set( )函数的作用2. set( )函数的语法3. set函数创建空集合4. set函数的参数只能是可迭代对象4.1 将字符串转换为集合4.2 set( )函数的参数不能为整数4.3…

MIT 6.S081 Operating System Lecture8 (非常随意的笔记)

系列文章目录 文章目录系列文章目录Page FaultCOPY ON WRITEPage Fault eager allocation 通常,因为应用程序无法非常准确地估计自己要增加的内存有多少,所以通常申请的内存会比真实要使用的内存要多。 在XV6中,sbrk的实现默认是eager alloc…

基于粒子群算法优化的lssvm回归预测-附代码

基于粒子群算法优化的lssvm回归预测 - 附代码 文章目录基于粒子群算法优化的lssvm回归预测 - 附代码1.数据集2.lssvm模型3.基于粒子群算法优化的LSSVM4.测试结果5.Matlab代码摘要:为了提高最小二乘支持向量机(lssvm)的回归预测准确率&#xf…

【C++】stack/queue/list

文章目录注意事项1 emplace 与 push 的区别一、stack(栈)(先进后出、【头部插入、删除】、不许遍历)1 基本概念(栈是自顶向下(top在下),堆是向上)2 stack 常用接口(构造函数、赋值操…

[蓝牙 Mesh Zephyr]-[005]-Key

[蓝牙 Mesh & Zephyr]-[005]-Key 1. Keys Mesh Profile specification 定义了 2 种key:application keys (AppKey)和 network keys(NetKey)。AppKeys 用于保护 upper transport layer 的通信安全,Net…

如何手动添加NLTK data

一、问题描述 Python的自然语言处理库NLTK在安装之后需要下载一些data文件才能使用。官方比较推荐的方式是直接运行下载data的代码: import nltk nltk.download(punkt) 但是实际操作之后发现由于网络原因无法下载成功。 除了运行代码之外,官方还推荐…

分布式队列celery学习

说明:本文内容来自《python自动化运维快速入门》学习 一、介绍 Celery是由纯Python编写的,但协议可以用任何语言实现。目前,已有Ruby实现的RCelery、Node.js实现的node-celery及一个PHP客户端,语言互通也可以通过using webhooks…

[附源码]JAVA毕业设计客户台账管理(系统+LW)

[附源码]JAVA毕业设计客户台账管理(系统LW) 目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&…

Activiti7工作流(二)

流程定义相关 流程定义查询 查询流程相关信息,包含流程定义,流程部署,流程定义版本 Test public void testDefinitionQuery(){//创建ProcessEngine对象ProcessEngine processEngine ProcessEngines.getDefaultProcessEngine();//获取仓库…

自动识别验证码实现系统自动登录(可扩展实现无人自动化操作,如领取各个平台的优惠券),不依赖第三方可以支持离线识别处理,附源码可直接运行

自动识别验证码实现系统自动登录(可扩展实现无人自动化操作,如领取各个平台的优惠券),不依赖第三方可以支持离线识别处理,附源码可直接运行。 实现过程: 1、只要是图片验证码都支持识别; 2、通过百度API实现验证码识别;(依赖第三方,且需要连接互联网,内网不可用,实…

7-FreeRTOS软件定时器

1- 简介 1.1 软件定时器简述 软件定时器就是允许函数设置一定的等待时间,然后执行。定时器执行的函数被称为定时器的回调函数。定时器从启动到执行回调函数之间的时间称为定时器的周期。定时器的回调函数在定时器的时间到达时执行。 软件定时器要先创建才能使用。…

实战Docker未授权访问提权

1、fofa关键字 port“2375” && body“page not found” 2、docker -H tcp://ip:port 可查看到当前所有的实例 3、docker -H tcp://ip:port pull alpine 4、docker -H tcp://ip:port run -it --privileged alpine bin/sh 5、fdisk -l 查看其分区结构 6、创建一个…

Java安全-CC1

CC1 这里用的是组长的链子和yso好像不太一样&#xff0c;不过大体上都是差不多的。后半条的链子都是一样的&#xff0c;而且这条更短更易理解。yso的CC1过段时间再看一下。 环境 Maven依赖&#xff1a; <dependencies><dependency><groupId>commons-colle…

十四、使用 Vue Router 开发单页应用(3)

本章概要 命名路由命名视图编程式导航传递 prop 到路由组件HTML 5 history 模式 14.5 命名路由 有时通过一个名称来标识路由会更方便&#xff0c;特别是在链接到路由&#xff0c;或者执行导航时。可以在创建 Router 实例时&#xff0c;在routes 选项中为路由设置名称。 修改…