深度学习--复制机制:CopyNet 模型在序列到序列模型中的应用以及代码实现

news2024/9/19 15:06:16

CopyNet 是一种特别设计的序列到序列(Seq2Seq)模型,旨在更好地处理那些在输出序列中需要直接复制输入序列中的部分或全部内容的任务。它在机器翻译、摘要生成、文本复述等任务中有广泛的应用,尤其是在输入和输出有显著重叠的场景。

以下是 CopyNet 模型在序列到序列学习中的具体应用:

1. 任务背景

在许多自然语言处理任务中,生成的输出序列常常需要复制输入序列中的一些词语或短语。例如,在机器翻译中,专有名词通常直接从源语言复制到目标语言中;在摘要生成中,一些关键短语可能需要直接引用。

2. CopyNet 模型的结构

CopyNet 的结构可以看作是对传统的 Seq2Seq 模型的一种扩展,主要包括以下几个部分:

  • 编码器(Encoder):与传统的 Seq2Seq 模型类似,CopyNet 使用编码器将输入序列编码为一系列隐藏状态。

  • 解码器(Decoder):解码器在生成每一个输出词时,既考虑生成新词的概率分布,也考虑直接复制输入序列中某个词的概率分布。

  • 生成机制(Generation Mode):生成新词时,解码器基于传统的 Seq2Seq 方法,使用词汇表中的词语来生成输出。

  • 复制机制(Copy Mode):在复制模式下,模型基于注意力机制(Attention Mechanism)从输入序列中选择一个词,直接将其复制到输出序列中。

3. 训练过程

在训练过程中,CopyNet 的目标函数是生成和复制的加权和。模型通过最大化生成词的概率和复制词的概率的和来学习。

4. 应用场景

  • 文本摘要:在生成摘要时,某些句子片段或重要信息可以直接从原文中复制,以确保摘要的准确性。
  • 机器翻译:在处理专有名词或具有特定结构的句子时,可以直接将输入中的一些词语复制到翻译结果中。
  • 问答生成:当生成回答时,问题中的一些关键术语可能需要直接复制到答案中。

5. 实现细节

在实现方面,CopyNet 可以基于现有的 Seq2Seq 框架(如 Transformer、LSTM 等)进行扩展。典型的步骤包括:

  • 使用标准的编码器-解码器架构。
  • 在解码阶段引入两个分支,一个用于词汇表生成,一个用于输入序列的复制。
  • 将这两者结合,通过注意力机制选择是否生成或复制。

6. 优势

  • 增强生成能力:CopyNet 能够平衡生成新词和直接复制输入词,使得模型在处理需要高度忠实于输入内容的任务时更加有效。
  • 提高输出质量:尤其在需要保留输入中关键信息的任务中,CopyNet 生成的输出更具准确性和自然性。

7. 代码实现

实现 CopyNet 模型的 Python 代码需要使用深度学习框架(如 TensorFlow 或 PyTorch)。下面以 PyTorch 为例,展示如何实现一个简单的 CopyNet 模型。

1). 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

2). 编码器(Encoder)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

3). 解码器(Decoder)

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden

4). CopyNet 模型

class CopyNet(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def create_mask(self, src):
        mask = (src != self.src_pad_idx).permute(1, 0)
        return mask
    
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        trg_len = trg.shape[0]
        batch_size = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        encoder_outputs, hidden = self.encoder(src)
        
        input = trg[0, :]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden)
            
            # Copy mechanism
            copy_prob = F.softmax(output, dim=1)
            gen_prob = F.softmax(self.decoder.fc_out(hidden[-1]), dim=1)
            final_prob = torch.log(gen_prob + copy_prob)
            
            outputs[t] = final_prob
            
            teacher_force = np.random.random() < teacher_forcing_ratio
            
            top1 = output.argmax(1)
            
            input = trg[t] if teacher_force else top1
        
        return outputs

5). 训练函数

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

6). 评估函数

def evaluate(model, iterator, criterion):
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            
            output = model(src, trg, 0)
            
            output_dim = output.shape[-1]
            
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            
            loss = criterion(output, trg)
            
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

7). 模型实例化和训练

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)

model = CopyNet(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

optimizer = torch.optim.Adam(model.parameters())

TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

N_EPOCHS = 10
CLIP = 1

for epoch in range(N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {np.exp(valid_loss):7.3f}')

8). 推断

def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    model.eval()
    
    tokens = [token.lower() for token in sentence]
    
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)
    
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)
    
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    
    for i in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        
        with torch.no_grad():
            output, hidden = model.decoder(trg_tensor, hidden)
            copy_prob = F.softmax(output, dim=1)
            gen_prob = F.softmax(model.decoder.fc_out(hidden[-1]), dim=1)
            final_prob = torch.log(gen_prob + copy_prob)
            
        pred_token = final_prob.argmax(1).item()
        
        trg_indexes.append(pred_token)
        
        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:]

9). 结论

上面的代码展示了如何用 PyTorch 实现 CopyNet 模型。这个实现简单展示了 CopyNet 的核心思想,包括编码器、解码器以及生成和复制机制的结合。在实际应用中,模型可能需要进一步优化和调整以适应具体任务。

通过引入复制机制,CopyNet 大大提升了模型处理具有高重叠率的序列生成任务的能力,成为自然语言处理领域中一项有力的工具。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2072682.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring--三级缓存机制

一、什么是三级缓存 就是在Bean生成流程中保存Bean对象三种形态的三个Map集合&#xff0c;如下&#xff1a; // 一级缓存Map 存放完整的Bean&#xff08;流程跑完的&#xff09; private final Map<String, Object> singletonObjects new ConcurrentHashMap(256);// 二…

51单片机——LED灯控制

1、LED介绍 中文名&#xff1a;发光二极管 外文名&#xff1a;Light Emitting Diode 简称&#xff1a;LED 用途&#xff1a;照明、广告灯、指引灯、屏幕 2、LED原理图 电阻在原理图上标注为1k&#xff0c;表示这是1千欧的电阻&#xff0c;实际在电路板上的表示是102 102解…

HarmonyOs应用权限申请,system_grant和user_grant区别。本文附头像上传申请user-grant权限代码示例

HarmonyOs应用权限申请&#xff0c;system_grant和user_grant区别。本文附头像上传申请user-grant权限代码示例 system_grant&#xff08;系统授权&#xff09; system_grant指的是系统授权类型&#xff0c;在该类型的权限许可下&#xff0c;应用被允许访问的数据不会涉及到用户…

【大数据算法】一文掌握大数据算法之:排序链表搜索的亚线性算法。

排序链表搜索的亚线性算法 1、引言2、平面图直径问题的亚线性算法2.1 定义2.2 核心原理2.2.1 跳表2.2.2 跳跃搜索2.2.3 分块搜索 2.3 应用场景2.4 算法公式2.5 代码示例 3、总结 1、引言 小屌丝&#xff1a;鱼哥&#xff0c;这茶味道怎么样&#xff1f; 小鱼&#xff1a;嗯&am…

计算机毕业设计选题推荐-保险业务管理系统-Java/Python项目实战

✨作者主页&#xff1a;IT研究室✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

[CUDA编程] --- cuda线程模型

1 核函数 先看一个cuda版本的hello world #include <stdio.h>__global__ void helloworld() {printf("hello world\n"); }int main() {helloworld()<<<1, 1>>>();cudaDeviceSynchronize();return 0; }这里helloworld()<<<1, 1>…

旅行达人必备!有道翻译和这三款神器,轻松走遍世界

在如今的全球化和科技迅猛发展的时代&#xff0c;翻译工具在我们的日常生活中发挥着越来越重要的作用。在各种格式数据的翻译当中&#xff0c;我们就可以发现各种类型的翻译工具纷纷崭露头角。今天就分享三款除了有道翻译外的好用翻译工具&#xff0c;希望可以解决大家翻译的需…

虚幻5|暴击攻击和释放技能,造成伤害

玩家数据的Actor组件制作&#xff1a;虚幻5|制作玩家血量&#xff0c;体力-CSDN博客 造成伤害时&#xff0c;显示暴击及暴击字体颜色和未暴击的字体颜色&#xff0c;还有释放技能连击 一.编辑暴击数据 1.打开之前创建的玩家数据Actor组件 创建一个浮点变量&#xff0c;命名…

从法律风险的角度来看,项目经理遇到不清楚或不明确问题时的处理

大家好&#xff0c;我是不会魔法的兔子&#xff0c;在北京从事律师工作&#xff0c;日常分享项目管理风险预防方面的内容。 序言 在项目开展过程中&#xff0c;有时候会遇到一些不清楚或不明确的状况&#xff0c;但碍于项目进度的紧迫性&#xff0c;不得不硬着头皮做决策&…

喜羊羊教你(如何应对突发的技术故障和危机?)

开发团队如何应对突发的技术故障和危机&#xff1f; 在数字化时代&#xff0c;软件服务的稳定性至关重要。、8月19日下午&#xff0c;网易云音乐疑似出现服务器故障&#xff0c;网页端出现502 Bad Gateway 报错&#xff0c;且App也无法正常使用。 怀疑了自己的电脑、自己的手…

OpenStack 常见模块(二)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

MFC工控项目实例之七点击下拉菜单弹出对话框

承接专栏《MFC工控项目实例之六CFile添加菜单栏》 1、在SEAL_PRESSUREDlg.h文件中添加代码 class CSEAL_PRESSUREDlg : public CDialog { ...afx_msg void OnTypeManage(); ... } 2、在SEAL_PRESSUREDlg.cpp文件中添加代码 BEGIN_MESSAGE_MAP(CSEAL_PRESSUREDlg, CDialog)//…

如何使用ssm实现基于Java的学生信息管理系统的设计与实现

TOC ssm165基于Java的学生信息管理系统的设计与实现jsp 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大&#xff0c;随着当前时代的信息化&#xff0c;科学化发展&#xff0c;让社会各行业领域都争相使用新的信息技术&#xff0c;对行业内的各种相关数据进行科学化&a…

python-随机序列(赛氪OJ)

[题目描述] 小理的作业太多了&#xff0c;怎么也做不完。 小理的数学作业由 T 张试卷组成&#xff0c;每张试卷上有 n 个数 a1..n​ &#xff0c;小理需要算出这些数的极差和方差。极差是一个整数&#xff0c;方差是一个浮点数&#xff0c;要求保留到小数点后 3 位。虽然题目很…

iPhone 手机使用技巧:iPhone 数据恢复软件

无论是由于意外删除、系统崩溃还是软件更新&#xff0c;丢失 iPhone 上的数据都是一场噩梦。从珍贵的照片到重要的工作文件&#xff0c;这种损失可能会让人感到毁灭性。值得庆幸的是&#xff0c;几个 iPhone 数据恢复软件选项可以帮助您找回丢失的文件。这些工具提供不同的功能…

大学数据库系统原理 Mysql数据库实验记录

软件版本说明&#xff1a; 1.Mysql数据库&#xff1a;sql server8.0 2.命令实现使用以及数据库可视化查看&#xff1a;Navicat 16 #不用Mysql Command Line 的原因是不喜欢那个黑框&#xff0c;也不常用&#xff0c;使用Navicat的MYSQL命令列界面是一样的 另外说明 实现相同…

Junit单元测试笔记

常用mock类框架 在软件测试和开发过程中&#xff0c;Mock框架扮演着至关重要的角色&#xff0c;它们允许开发者模拟对象的行为&#xff0c;以便在不需要实际依赖的情况下进行测试。以下是一些常用的Mock框架&#xff1a; MockitoPowerMockEasyMockJMockSpock 初始化mock/spy…

解决ONENOTE复制文字到外部为图片(Ditto)

默认情况下&#xff0c;在ONENOTE中记录的文字&#xff0c;在复制粘贴到外部时&#xff0c;会成为一张图片格式 如下图这段文字&#xff0c;粘贴到QQ中变为了图片 解决办法&#xff1a;安装Ditto Ditto下载链接 点击Download下载 双击安装.exe&#xff0c;选择安装路径后&…

JVM上篇:内存与垃圾回收篇-07-方法区

笔记来源&#xff1a;尚硅谷 JVM 全套教程&#xff0c;百万播放&#xff0c;全网巅峰&#xff08;宋红康详解 java 虚拟机&#xff09; 文章目录 7. 方法区7.1. 栈、堆、方法区的交互关系7.2. 方法区的理解7.2.1. 方法区在哪里&#xff1f;7.2.2. 方法区的基本理解7.2.3. HotSp…

编译 wolfssl 库

wolfssl github: https://github.com/wolfSSL/wolfssl 编译 .lib 或者 .dll wolfssl 很好的提供了 win32 的工程》sln 文件 这样就不用折腾 CMakeLists 文件了&#xff0c;使用 Visual Studio 打开 sln 文件后&#xff0c;设置好 Static 编译库即可&#xff0c;开箱即用 编译 .…