【论文精读2】用于多文档摘要生成的层次Transformer方法

news2024/11/18 7:51:18

前言

论文分享 来自2019ACL的多文档摘要生成方法论文,作者来自英国爱丁堡大学,引用数310
Hierarchical Transformers for Multi-Document Summarization 代码地址hiersumm

多文档摘要抽取的难点在于没有合适的数据集,同时过长的文档文本也导致现在硬件水平无法支撑模型的训练,Generating Wikipedia by summarizing long sequences.提出了WikiSum数据集,以维基百科的第一段内容作为摘要,以标题+引用的文章或根据标题在网页搜索的前10篇文章(除去wikipedia本文)作为相同主题的文章集(document cluster)输入,提出WikiSum的作者采用二步结构,先抽取式选择一部分重要的段落,然后将这些重要的段落拼接成一个段落后,采用生成式模型进行生成

作者指出,这种将所有重要段落拼接成一个段落的方式忽略了段落之间的层次信息,作者的创新点在于设计了学习文章间层次结构的transformer,并发现对输入文章进行排序能够增强模型效果

模型

首先作者构建了一个段落排序模型,采用基于LSTM的回归模型,以标题和段落作为输入,以段落与目标摘要的ROUGE-2分数作为输出进行拟合,得到一个可以根据标题和段落生成分数的预测模型,在测试时,段落会根据回归模型生成的分数进行排序,最后筛选出 L ′ L' L个段落作为最终摘要模型的输入,这里其实是摘要式方法的一种变种,可以有很多选择,整体模型的过程如下图所示,将排序后的 L ′ L' L个段落输入到encoder-decoder中最终输出生成式摘要
在这里插入图片描述
编码层: 编码层将字转化成向量,并加入位置信息,位置信息为三角函数编码
在这里插入图片描述
其中由于输入的位置包括字的位置信息和段落的位置信息,因此作者将两个三角函数位置信息通过拼接的方式,拼接到每一个字向量后,i为段落位置,j为字位置,拼接后得到位置编码 p e i j pe_{ij} peij, w i j w_{ij} wij为字向量, x i j 0 x_{ij}^0 xij0代表第0层,第i个段落第j个字符的向量
在这里插入图片描述
局部Transformer层: 和普通的Transformer一样,由多头注意力层和前向连接层构成
在这里插入图片描述
全局Transformer层: 全局Transformer层通过self-attention让每一个段落收集其他段落的信息,得到一个获取上下文信息的全局向量,其由多头池化(Multi head pooling),段落间注意力机制(inter paragraph attention mechanism)和前向连接层构成
多头池化: 是论文的创新点之一,其计算每一个字的分布,根据不同字的权重编码整个句子,是将序列字编码成单一向量的好方法,公式如下
在这里插入图片描述
由局部Transformer编码的向量维度为 [ b a t c h , s e q l e n , d i m ] [batch, seqlen, dim] [batch,seqlen,dim],通过两个全连接层分别映射到注意力值 a a a,维度为 [ b a t c h , n h e a d , s e q l e n , 1 ] [batch,n_{head},seqlen,1] [batch,nhead,seqlen,1]和值 b b b,维度为 [ b a t c h , n h e a d , s e q l e n , , d h e a d ] [batch,n_{head},seqlen,,d_{head}] [batch,nhead,seqlen,,dhead],注意力分数通过对注意力值进行softmax得到
在这里插入图片描述
将注意力分数和字向量进行相乘后,在序列长度维度进行相加和归一化后,得到每一个段落不同head的表征,多头池化代码如下,输入为 [ b a t c h , s e q l e n , d i m ] [batch, seqlen, dim] [batch,seqlen,dim],输出为 [ b a t c h , h e a d , d i m ] [batch,head,dim] [batch,head,dim]

class MultiHeadedPooling(nn.Module):
    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        super(MultiHeadedPooling, self).__init__()
        self.head_count = head_count
        self.linear_keys = nn.Linear(model_dim,
                                     head_count)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        if (use_final_linear):
            self.final_linear = nn.Linear(model_dim, model_dim)
        self.use_final_linear = use_final_linear

    def forward(self, key, value, mask=None):
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x, dim=dim_per_head):
            """  projection """
            return x.view(batch_size, -1, head_count, dim) \
                .transpose(1, 2)

        def unshape(x, dim=dim_per_head):
            """  compute context """
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim)

        scores = self.linear_keys(key)
        value = self.linear_values(value)

        scores = shape(scores, 1).squeeze(-1)
        value = shape(value)
        # key_len = key.size(2)
        # query_len = query.size(2)
        #
        # scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores)
        drop_attn = self.dropout(attn)
        context = torch.sum((drop_attn.unsqueeze(-1) * value), -2)
        if (self.use_final_linear):
            context = unshape(context).squeeze(1)
            output = self.final_linear(context)
            return output
        else:
            return context

不输出head维度,直接输出单一的段落向量方法,输入为 [ b a t c h , s e q l e n , d i m ] [batch, seqlen, dim] [batch,seqlen,dim],输出为 [ b a t c h , d i m ] [batch,dim] [batch,dim]

class MultiHeadPoolingLayer( nn.Module ):
    def __init__( self, embed_dim, num_heads  ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_per_head = int( embed_dim/num_heads )
        self.ln_attention_score = nn.Linear( embed_dim, num_heads )
        self.ln_value = nn.Linear( embed_dim,  num_heads * self.dim_per_head )
        self.ln_out = nn.Linear( num_heads * self.dim_per_head , embed_dim )
    def forward(self, input_embedding , mask=None):
        a = self.ln_attention_score( input_embedding )
        v = self.ln_value( input_embedding )
        
        a = a.view( a.size(0), a.size(1), self.num_heads, 1 ).transpose(1,2)
        v = v.view( v.size(0), v.size(1),  self.num_heads, self.dim_per_head  ).transpose(1,2)
        a = a.transpose(2,3)
        if mask is not None:
            a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 ) 
        a = F.softmax(a , dim = -1 )

        new_v = a.matmul(v)
        new_v = new_v.transpose( 1,2 ).contiguous()
        new_v = new_v.view( new_v.size(0), new_v.size(1) ,-1 ).squeeze(1)
        new_v = self.ln_out( new_v )
        return new_v

段落间注意力: 段落间注意力是采用自注意力机制(scale dot product)对每一个段落的head表征进行学习,目的是让每一个段落学习到其他段落的关联信息,最终得到段落i,注意力头z的上下文表征 c o n t e x t i z context_i^z contextiz在这里插入图片描述
前向连接层: 前向连接层的输入为一个段落多个head的拼接, 最终输出和全局Transformer层输入一样维度的输出
在这里插入图片描述在这里插入图片描述
段落间注意力的注意力分数其实是段落间的关系系数,因此可以采用已学习好的图表征来代替,例如句法关系图(Lexical Relation Graph)和(近似对画图)Approximate Discourse Graph,有兴趣的读者可以看论文附录,替代方式如下
在这里插入图片描述

实验

在这里插入图片描述
作者比较了不同数量段落的情况下,tf-idf cosine similarity和使用回归模型排序选择方式下,抽取段落和真实摘要之间的ROUGE-L 召回值,可以发现回归模型能够选择更好的段落
在这里插入图片描述
作者比较了层次Transformer模型与Flat Transformer和Transformer Decoder with Memory Compression attention之间的性能,HT相比T-DMCA在更少的字符数下能够有些微的提升,同时,在1600长度文本训练的模型,在3000字输入的情况下能够获得更好的结果,近1个点的ROUGE提升
在这里插入图片描述
作者还做了消融实验,表明位置编码,多头池化和全局Transformer都是有用的

结论

论文中采用多头池化的方式获得多头上下文表征,并将不同段落的多头上下文表征进行相互学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/826729.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

剑指 Offer 55 - II. ! 平衡二叉树

剑指 Offer 55 - II. 平衡二叉树 输入一棵二叉树的根节点,判断该树是不是平衡二叉树。如果某二叉树中任意节点的左右子树的深度相差不超过1,那么它就是一棵平衡二叉树。 来自力扣K神的解法1,真的是太巧妙了! 方法recur检查以nod…

什么是自动化测试框架?自动化测试框架有哪些?

一、自动化测试 1、为什么要做自动化测试? 自动化测试就是把以人为驱动的测试行为转化为机器执行的一种过程,即模拟手工测试的步骤,通过执行测试脚本自动地测试软件自动化测试就是程序(脚本)测试程序,使用…

LeNet卷积神经网络-笔记

LeNet卷积神经网络-笔记 手写分析LeNet网三卷积运算和两池化加两全连接层计算分析 基于paddle飞桨框架构建测试代码 #输出结果为: #[validation] accuracy/loss: 0.9530/0.1516 #这里准确率为95.3% #通过运行结果可以看出,LeNet在手写数字识别MNIST验证…

如何开启一个java微服务工程

安装idea IDEA常用配置和插件(包括导入导出) https://blog.csdn.net/qq_38586496/article/details/109382560安装配置maven 导入source创建项目 修改项目编码utf-8 File->Settings->Editor->File Encodings 修改项目的jdk maven import引入…

【C++】类和对象——拷贝构造函数、运算符重载、日期类实现、const成员、取地址操作符重载

目录 拷贝构造函数运算符重载日期类实现const成员取地址及const取地址操作符重载 拷贝构造函数 拷贝构造函数:只有单个形参,该形参是对本类类型对象的引用(一般常用const修饰),在用已存在的类类型对象创建新对象时由编译器自动调用。 拷贝构…

SOLIDWORKS 钣金零件怎么画?

一、SOLIDWORKS 钣金功能介绍 SOLIDWORKS 是一款广泛应用于机械设计领域的 CAD 软件,其钣金功能可以帮助用户快速创建钣金件的 3D 模型。钣金折弯是一种常见的加工方式,可以将平面材料通过弯曲变形成为所需形状。 二、如何使用 SOLIDWORKS 钣金功能 步骤…

shell清理redis模糊匹配的多个key

#!/bin/bash# 定义Redis服务器地址和端口 REDIS_HOST"localhost" REDIS_PORT6380# 获取匹配键的数量 function get_matching_keys() {local key_pattern"$1"redis-cli -h $REDIS_HOST -p $REDIS_PORT -n 0 KEYS "$key_pattern" }# 删除匹配的键 …

一文带你详细了解Open API设计规范

写在前面: OpenAPI 规范(OAS)定义了一个标准的、语言无关的 RESTful API 接口规范,它可以同时允许开发人员和操作系统查看并理解某个服务的功能,而无需访问源代码,文档或网络流量检查(既方便人…

Atlas200DK A2联网实战

文章目录 1.Atlas原始网络信息2. 开发板联网2.1 使用Type-c 连接开发板2.2 修改本地网络适配器2.3 修改开发板网络信息2.4 测试外网连接 1.Atlas原始网络信息 Type-C 网口 ETH0 网口 ETH1 网口 2. 开发板联网 2.1 使用Type-c 连接开发板 使用xshell 等ssh终端登录开发板&…

【C++从0到王者】第十五站:list源码分析及手把手教你写一个list

文章目录 一、list源码分析1.分析构造函数2.分析尾插等 二、手把手教你写一个list1.结点声明2.list类的成员变量3.list类的默认构造函数4.list类的尾插5.结点的默认构造函数6.list类的迭代器7.设计const迭代器8.list的insert、erase等接口9.size10.list的clear11.list的析构函数…

【java安全】CommonsBeanUtils1

文章目录 【java安全】CommonsBeanUtils1前言Apache Commons BeanutilsBeanComparator如何调用BeanComparator#compare()方法?构造POC完整POC 调用链 【java安全】CommonsBeanUtils1 前言 在之前我们学习了java.util.PriorityQueue,它是java中的一个优…

2.2 身份鉴别与访问控制

数据参考:CISP官方 目录 身份鉴别基础基于实体所知的鉴别基于实体所有的鉴别基于实体特征的鉴别访问控制基础访问控制模型 一、身份鉴别基础 1、身份鉴别的概念 标识 实体身份的一种计算机表达每个实体与计算机内部的一个身份表达绑定信息系统在执行操作时&a…

3、详解桶排序及排序内容总结

堆 满二叉树可以用一个数组中从0开始的连续一段来记录 i i i位置左孩子: 2 ∗ i + 1 2*i+1 2∗i+1,右孩子: 2 ∗ i + 2 2*i+2 2∗i+2,父: ( i − 1 ) / 2 (i-1)/2 (i−1)/2 大根堆 每一棵子树的根为最大值 小根堆 每一棵子树的根为最小值 建大根堆 不断地根据公…

配置HDFS单机版,打造数据存储的强大解决方案

目录 简介:步骤:安装java下载安装hadoop配置hadoop-env.sh配置 core-site.xml配置hdfs-site.xml初始化hdfs文件系统启动hdfs服务验证hdfs 结论: 简介: Hadoop分布式文件系统(HDFS)是Hadoop生态系统中的一个…

【硬件设计】模拟电子基础二--放大电路

模拟电子基础二--放大电路 一、基本放大电路1.1 初始电路1.2 静态工作点1.3 分压偏置电路 二、负反馈放大电路三、直流稳压电路 前言:本章为知识的简单复习,适合于硬件设计学习前的知识回顾,不适合运用于考试。 一、基本放大电路 1.1 初始电…

数学建模-爬虫入门

Python快速入门 简单易懂Python入门 爬虫流程 获取网页内容:HTTP请求解析网页内容:Requst库、HTML结果、Beautiful Soup库储存和分析数据 什么是HTTP请求和响应 如何用Python Requests发送请求 下载pip macos系统下载:pip3 install req…

VactorCast自动化单元测试

VectorCAST软件自动化测试方案 VectorCAST软件自动化测试方案 博客园 软件测试面临的问题 有一句格言是这样说的,“如果没有事先做好准备,就意味着做好了 失败的准备。”如果把这个隐喻应用在软件测试方面,就可以这样说“没有测试到&#xf…

Tomcat虚拟主机

Tomcat虚拟主机 部署 [rootlocalhost webapps]# cd ../conf [rootlocalhost conf]# pwd /usr/local/tomcat/conf [rootlocalhost conf]# vim server.xml #增加虚拟主机配置&#xff0c;添加以下&#xff1a; <Host name"www.a.com" appBase"webapps"u…

react-redux的理解与使用

一、react-redux作用 和redux和flux功能一样都是管理各个组件的状态&#xff0c;是redux的升级版。 二、为什么要用reac-redux&#xff1f; 那么我们既然有了redux&#xff0c;为什么还要用react-redux呢&#xff1f;原因如下&#xff1a; 1&#xff0c;解决了每个组件用数…

怎么才能远程控制笔记本电脑?

为什么选择AnyViewer远程控制软件&#xff1f; 为什么AnyViewer是远程控制笔记本电脑软件的首选&#xff1f;以下是选择AnyViewer成为笔记本电脑远程控制软件的主要因素。 跨平台能力 AnyViewer作为一款跨平台远程控制软件&#xff0c;不仅可以用于从一台Windows电…