Encoder-Decoder:Seq2seq

news2025/1/13 13:13:11

目录

  • 一、编码器解码器架构:
    • 1.定义:
    • 2.在CNN中的体现:
    • 3.在RNN中的体现:
    • 4.代码:
  • 二、Seq2seq:
    • 1.模型架构:
      • 1.1编码器:
      • 1.2解码器:
    • 2.架构细节:
    • 3.模型评估指标BLEU:
    • 4.代码:
  • 三、束搜索:
    • 1.贪心搜索:
    • 2.束搜索:

一、编码器解码器架构:

1.定义:

在这里插入图片描述

Encoder负责对Input进行特征提取,输出特征矩阵State
Decoder接收State,负责进行预测并输出

2.在CNN中的体现:

在这里插入图片描述

3.在RNN中的体现:

在这里插入图片描述

4.代码:

from torch import nn

class Encoder(nn.Module):
    """编码器-解码器结构的基本编码器接口。"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
    
class Decoder(nn.Module):
    """编码器-解码器结构的基本解码器接口。"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
    
class EncoderDecoder(nn.Module):
    """编码器-解码器结构的基类。"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)    

二、Seq2seq:

1.模型架构:

这里以机器翻译任务为例:
在这里插入图片描述

1.1编码器:

编码器不管在训练阶段还是预测阶段都是用于提取特征,可以是单层RNN、多层RNN、双向RNN(双向RNN不仅可以提取上文序列特征,还可以提取下文的序列特征)

1.2解码器:

解码器在不同阶段作用不同,只能是单层RNN或多层RNN,不能是双向RNN(解码器用于预测,双向RNN不能预测)

  • 训练阶段,解码器主要是为了特征提取,通过接收编码器的输出隐藏状态作为h0并接收预测的真实值Input,每个时间步使用隐藏状态ht-1进行特征提取并更新隐藏状态ht,然后将ht和当前时间步的真实值token(而非预测值,因为是要更好的学习)作为下一个时间步的输入,不断更新可学习参数。
  • 预测阶段,解码器主要是为了执行预测任务,不再接收预测的真实值(因为不知道),仅接收编码器的输出隐藏状态作为h0,每个时间步使用隐藏状态ht-1进行预测并更新隐藏状态ht,然后将ht和当前时间步的预测值token作为下一个时间步的输入,进行下一个token的预测。

2.架构细节:

Seq2seq的编码器和解码器都是RNN
在这里插入图片描述

3.模型评估指标BLEU:

在这里插入图片描述

4.代码:

import collections
import math
import torch
from torch import nn
from d2l import torch as d2l

# 使用GRU作为编码器
class Seq2SeqEncoder(d2l.Encoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 1
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 2 
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
                          dropout=dropout)

    def forward(self, X, *args):
        X = self.embedding(X)
        X = X.permute(1, 0, 2)
        output, state = self.rnn(X)
        return output, state
    
# 使用GRU作为解码器
class Seq2SeqDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        # 1
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 2
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        # 3
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        X = self.embedding(X).permute(1, 0, 2)
        context = state[-1].repeat(X.shape[0], 1, 1)
        X_and_context = torch.cat((X, context), 2)
        output, state = self.rnn(X_and_context, state)
        output = self.dense(output).permute(1, 0, 2)
        return output, state
    
# 在序列中屏蔽不相关的项,即屏蔽序列中之前使用<pad>填充的无效值   
def sequence_mask(X, valid_len, value=0):
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

# 填充的无效值不参与损失值的计算,因为这些值的预测对错没有意义
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = 'none'
        unweighted_loss = super(MaskedSoftmaxCELoss,
                                self).forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss
    
# 训练过程
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)
        for batch in data_iter:
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
    
# 预测过程    
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                    device, save_attention_weights=False):
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
        src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    dec_X = torch.unsqueeze(
        torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device),
        dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq    
    
# 模型评估指标
def bleu(pred_seq, label_seq, k):  
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[''.join(label_tokens[i:i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[''.join(pred_tokens[i:i + n])] > 0:
                num_matches += 1
                label_subs[''.join(pred_tokens[i:i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

# 训练    
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)    
    
# 预测    
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, attention_weight_seq = predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')

三、束搜索:

1.贪心搜索:

在这里插入图片描述

2.束搜索:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1971056.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# Unity 补全计划 泛型

本文仅作学习笔记与交流&#xff0c;不作任何商业用途&#xff0c;作者能力有限&#xff0c;如有不足还请斧正 1.什么是泛型 泛型&#xff08;Generics&#xff09;是C#中的一个强大特性&#xff0c;允许你编写可以适用于多种数据类型的可重用代码&#xff0c;而不需要重复编写…

第二证券:刚刚!亚太股市,跌麻了!

今天早盘&#xff0c;亚太股市全线崩跌。日经225指数在大幅低开之后快速下行&#xff0c;最大跌幅近5%&#xff1b;韩国、澳大利亚股指亦迎来逾越2%以上的暴降。那么&#xff0c;毕竟发生了什么&#xff1f; 剖析人士认为&#xff0c;或许仍是与日元套息有关。从前史来看&…

Duplicate class kotlin.collections.jdk8.CollectionsJDK8Kt found in modules。Android studio纯java代码报错

我使用java代码 构建项目&#xff0c;初始代码运行就会报错。我使用的是Android Studio Giraffe&#xff08;Adroid-studio-2022.3.1.18-windows&#xff09;。我在网上找的解决办法是删除重复的类&#xff0c;但这操作起来真的太麻烦了。 这是全部报错代码&#xff1a; Dupli…

mysql环境的部署安装及数据库的操作(twenty day)

一、centos7 中安装 mysql 8.x 1、下载安装包 wget https://downloads.mysql.com/archives/get/p/23/file/mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar 2、解压 tar -zxvf mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar 3、卸载mariodb yum remove -y *mariadb* 4、依次安装依赖包…

SC215TA是C型/ PD和DPDM快速充电控制器,集成了内部反馈补偿PD3.0快充

SC215TA是C型/ PD和DPDM快速充电控制器&#xff0c;集成了内部反馈补偿。它符合最新的C型和PD 3.0标准&#xff0c;并支持专有的高压快速充电协议与DPDM接口。它的目标是旅行适配器的应用程序。SC215TA通过集成USB PD基带PHY、Type-C检测、DPDM PHY、VBUS放电路径、VCONN电源、…

旧衣回收小程序,旧衣回收行业新态势

进入网络时代后&#xff0c;互联网改变了大众的生活&#xff0c;传统的回收模式逐渐被淘汰&#xff0c;新兴的互联网旧衣回收受到了大众的关注&#xff01;通过技术创新为行业带来新模式&#xff0c;不断优化回收流程&#xff0c;提高回收效率&#xff0c;提升居民的回收体验&a…

Java编程达人:每日一练,提升自我

目录 题目1.以下哪个单词不是 Java 的关键字&#xff1f;2.boolean 类型的默认值为&#xff1f;3.以下代码输出正确的是&#xff1f;4.以下代码&#xff0c;输出结果为&#xff1a;5.以下代码输出结果为&#xff1a;6.以下代码输出结果为&#xff1f;7.float 变量的默认值为&am…

Three.js WebGPU 节点材质系统 控制instances的某个实例单独的透明度,颜色等属性

文章目录 1. 声明一个实例必要的属性instanceMatrix同级别的属性2. 在设置位置矩阵的时候填充这个数组3. 在shader中获取当前的索引4. 增加uniform5. 对比当前着色的实例是否是选中的实例6. 如果是选中的实例7. 影响片元着色器透明度参数 8.源码 写在前面 本文环境是 原生js 没…

EV代码签名证书申请流程

EV代码签名证书可以有效提高用户信赖。可以用于任何软件&#xff0c;支持Microsoft SmartScreen应用程序信誉功能以及对Windows 10内核驱动程序进行签名。 下面是EV代码签名证书的申请流程 代码签名证书_代码签名证书申请购买-JoySSL代码签名证书是对可执行脚本、软件代码和内容…

500+伙伴齐聚上海:纷享销客生态伙伴大会·上海站成功举办

近日&#xff0c;纷享销客生态伙伴大会上海站成功举办&#xff0c;此次会议汇聚了500余位来自各行各业的伙伴&#xff0c;齐聚一堂&#xff0c;共同探讨行业的未来发展趋势。 01、展望CRM市场 国内外双轮驱动&#xff0c;SaaS巅峰在价值创造与效率运营 纷享销客创始人兼CEO罗…

vulhub:nginx解析漏洞CVE-2013-4547

此漏洞为文件名逻辑漏洞&#xff0c;该漏洞在上传图片时&#xff0c;修改其16进制编码可使其绕过策略&#xff0c;导致解析为 php。当Nginx 得到一个用户请求时&#xff0c;首先对 url 进行解析&#xff0c;进行正则匹配&#xff0c;如果匹配到以.php后缀结尾的文件名&#xff…

零售门店客流统计系统支持回头客识别,更好维护老客户

随着市场竞争日益激烈&#xff0c;零售业面临着诸多挑战&#xff0c;尤其是如何吸引新客户的同时留住老客户。客流统计系统作为一项关键的技术手段&#xff0c;正在帮助零售门店解决这一难题。 一、零售门店客流统计痛点 1.数据准确性低&#xff1a;传统的人工统计方法往往存在…

MATLAB(10)分类算法

前言 MATLAB中实现分类算法的代码可以非常多样&#xff0c;取决于你具体想要使用的分类算法类型&#xff08;如决策树、逻辑回归、支持向量机、K近邻等&#xff09;。以下是一些常见分类算法的基本MATLAB实现示例。 一、逻辑回归 逻辑回归是分类问题中的一种基础算法&#xff0…

第十六天学习笔记2024.7.29

web yum -y install httpd systemctl start httpd.service systemctl stop firewalld systemctl disable firewalld 2、动态⻚⾯与静态⻚⾯的差别 &#xff08;1&#xff09;URL不同 静态⻚⾯链接⾥没有“&#xff1f;” 动态⻚⾯链接⾥包含“&#xff1f;” &#xff08…

第一 二章 小车硬件介绍-(全网最详细)基于STM32智能小车-蓝牙遥控、避障、循迹、跟随、PID速度控制、视觉循迹、openmv与STM32通信、openmv图像处理、smt32f103c8t6

第一篇-STM32智能小车硬件介绍 后续章节也放这里 持续更新中&#xff0c;视频发布在小B站 里面。这边也会更新。 B站视频合集: STM32智能小车V3-STM32入门教程-openmv与STM32循迹小车-stm32f103c8t6-电赛 嵌入式学习 PID控制算法 编码器电机 跟随 小B站链接:https://www.bilib…

贪心算法—股票交易时机Ⅱ

在此前我们已经介绍过贪心算法以及股票交易时机Ⅰ&#xff0c;有需要的话可以移步至贪心算法_Yuan_Source的博客-CSDN博客 题目介绍 122. 买卖股票的最佳时机 II - 力扣&#xff08;LeetCode&#xff09; 给你一个整数数组 prices &#xff0c;其中 prices[i] 表示某支股票第…

【Linux】问题解决:yum repolist出现“!”号

问题描述&#xff1a;在运行 yum repolist 时&#xff0c;出现以下状况&#xff1a; 原因&#xff1a;表示仓库里有过期的元数据&#xff0c;并不是最新版本。 解决方法&#xff1a; 清楚过期缓存 yum clean all 快速创建新yum缓存 yum makecache fast 结果&#xff1a;…

Qt——QTCreater ui界面如何统一设置字体

第一步&#xff1a;来到 ui 设计界面&#xff0c;鼠标右键点击 改变样式表 第二步&#xff1a;选择添加字体 第三步&#xff1a;选择字体样式和大小&#xff0c;点击 ok 第四步&#xff1a;点击ok或apply&#xff0c;完成设置

基于100G-PAM4技术的LinkX 线缆

LinkX线缆专注于加速数据中心和人工智能计算系统&#xff0c;这些产品不仅提供了高数据传输速率&#xff0c;还在设计上特别优化了低延迟性能&#xff0c;以满足现代计算系统对速度和效率的高要求。 一、主要特点与技术规格 1、传输距离与速率 数据中心应用&#xff1a;支持…

用 Bytebase 实现批量、多环境、多租户数据库的丝滑变更

Bytebase 提供了多种功能来简化批量变更管理&#xff0c;适用于多环境或多租户情况。本教程将指导您如何使用 部署配置 和 数据库组 在不同场景下进行数据库批量变更。 默认流水线 vs 部署配置 图片数据库 vs 数据库组 1. 准备 请确保已安装 Docker&#xff0c;如果本地没有重…