在本篇文章中,我们将实现经典的"Get to the Point"模型,该模型最初发表于 Get to the Point: Summarization with Pointer-Generator Networks。这是当时最著名的摘要生成模型之一,至今仍有很多人使用其Pointer-Generator架构作为他们模型的一部分。
1. 模型简介
"Get to the Point" 模型结合了Pointer机制和覆盖率机制,解决了生成式摘要系统中的词汇外词(OOV)问题以及重复生成问题。
该模型的核心思想是使用Pointer-Generator机制将生成器(Generator)和指针(Pointer)结合,从而在预测新词时可以选择从词汇表中生成词语,或者直接复制文章中的词。覆盖率机制则通过引入覆盖向量,帮助模型避免生成重复内容。
模型总体架构:
2. 模型参数与配置
import torch
import torch.nn.functional as F
# 配置参数
beam_size = 4
emb_dim = 128
batch_size = 16
hidden_dim = 256
max_enc_steps = 400
max_dec_steps = 100
vocab_size = 50000
lr = 0.15
cov_loss_wt = 1.0
pointer_gen = True
is_coverage = True
max_grad_norm = 2.0
3. 数据准备
我们使用CNN/DailyMail数据集,首先需要进行数据下载和处理。处理好的数据文件。
数据预处理流程(跳过分词部分)
# 已经过Tokenize的文件,可以直接跳过分词部分
train_data_path = "data/finished_files/chunked/train_*"
eval_data_path = "data/finished_files/val.bin"
decode_data_path = "data/finished_files/test.bin"
vocab_path = "data/finished_files/vocab"
4. 模型结构
4.1 基本模块
我们为所有模块定义一个基础类BasicModule
,其初始化方法采用了截断正态分布来初始化参数。
class BasicModule(nn.Module):
def __init__(self, init='uniform'):
super(BasicModule, self).__init__()
self.init = init
def init_params(self):
for param in self.parameters():
if param.requires_grad and len(param.shape) > 0:
stddev = 1 / math.sqrt(param.shape[0])
if self.init == 'uniform':
torch.nn.init.uniform_(param, a=-0.05, b=0.05)
elif self.init == 'normal':
torch.nn.init.normal_(param, std=stddev)
elif self.init == 'truncated_n