R2GenCMN中的Encoder_Decoder结构

news2024/11/24 17:38:08

R2GenCMN中的 Encoder_Decoder 结构

Encoder_Decoder 结构直接关系到文本的生成,它结构参考的transformer的结构

在这里插入图片描述

在这里插入图片描述

我们这里主要看代码的实现,从视觉编码器的输出开始

1. 模型结构

首先介绍一下整体结构,这里的baseCMN其实就是一个包装了的Transformer的 Transformer,这个Transformer里面是有n个连续的encoder和n个连续的decoder组成的。图片的输入进入encoder进行编码,这个过程是Transformer的结构,加入了位置编码和注意力机制。(凡是框框里面有的,都是一个类)

文章中的 CMN组件是在 encoder之后起作用的,CMN如同一个字典,(这是一个虚构的字典),这个字典负责查询,正常encoder的输出直接进入decoder就可以了,但是在这里,encoder的输出,需要先经过CMN的查询响应,与响应叠加之后的输出进入decoder。它同时对文本和图像两个变量进行索引和反馈,在prepare_feature的函数中,它将 图像的特征进行查询反馈。在decode过程中对 文本特征进行查询反馈。(这里今后我打算采用 稀疏学习 的方式进行优化)

decoder的结构是标准的Transfomer的decoder结构,有两个多头注意力机制,但是在这里,第一个多头注意力进行文本内容的注意力特征提取,第二头进行跨模态的特征提取,也就是使用x对图片的特征进行特征提取。
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
最终,得到输出。

2. 模型训练

文本生成模型是一个自回归模型,模型的不是一下子输出全部,而是一点一点的输出一个token,而这里的实现就是在使用core函数。实际上core函数并不是标准内置的响应函数,这个代码中,将使用forward和sample来进行区分运行的模式,在模型的训练阶段,模型的运行模式是自回归方式。对于自回归模型,如GPT系列,在模型的训练过程中,即使模型在某个步骤中预测错误(比如预测了“海边”而不是“公园”),下一步的训练输入仍然是真实的序列中的词(“公园”),而不是模型错误的预测结果(“海边”)。这样做的目的是加速(并行处理)训练并提高模型的稳定性和性能。

训练:
1. 并行处理: 尽管模型预测下一个词是基于之前的所有词,但在训练时,这个过程是并行化的。给定一个序列,模型能够同时计算序列中每个位置的输出。这是通过使用所谓的“掩码”技术在自注意力层中实现的,它防止位置注意到它之后的任何位置,确保预测仅依赖于之前的词和当前位置的词。
2.(Teacher foring): 给定序列的当前位置,模型使用之前位置的真实词(而不是模型自己生成的词)来预测下一个词。

这实际上也能看出来GPT模型的缺点,就是内容连贯性,但是如果模型一开始是错误的,那么模型的很容易一直错下去,生成开口完全一致,内容一模一样,如同幻觉一般的句子,如果数据集的多样性非常有限,就是文本之间非常像的话,最终模型的训练会陷入一个误区就是找整个数据集中的一个内容不变的句子,这个句子对于 整个数据集来说差异最小 。导致报告的生成陷入一个由于数据集差异太小,同时样本多数一致的 训练误区

解决办法
找到一种办法就是可以 重新量化差异,让本来差异很小的数据集,在新的视角下,差异变大
3. 代码实现过程

encoder_decoder的传播函数

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        print(f"这里是encoder_decoder的forward")
        embed()
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
        out = self.model(att_feats, seq, att_masks, seq_mask)
        outputs = F.log_softmax(self.logit(out), dim=-1)
        return outputs

这里我们查看输入都是什么, 可以看到fc_feature= 2048 * 4 = 8192 的特征是我进行堆叠,图片att_feature是我进行了cat= 7* 7* 4

In [1]:  fc_feats.shape
Out[1]: torch.Size([10, 8192])

In [2]:  att_feats.shape
Out[2]: torch.Size([10, 196, 2048])

In [3]:  seq.shape
Out[3]: torch.Size([10, 284])

In [4]:  att_masks.shape
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 att_masks.shape

AttributeError: 'NoneType' object has no attribute 'shape'

经过prepare_feature,

In [5]:  att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)

In [6]:  att_feats.shape
Out[6]: torch.Size([10, 196, 512])

In [7]:  seq.shape
Out[7]: torch.Size([10, 283])

In [8]:  att_masks.shape
Out[8]: torch.Size([10, 1, 196])

In [9]:  seq_mask.shape
Out[9]: torch.Size([10, 283, 283])

在prepare的函数中,使用clip进行了特征的裁剪,如果是直接使用预训练的,我认为这样直接剪切是不合理,应该进行embedding进行映射

def _prepare_feature(self, fc_feats, att_feats, att_masks):
	att_feats, att_masks = self.clip_att(att_feats, att_masks)
	
	# embed fc and att feats
	fc_feats = self.fc_embed(fc_feats)
	att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
	
	# Project the attention feats first to reduce memory and computation comsumptions.
	p_att_feats = self.ctx2att(att_feats)

1. Transformer进行编码

1.1 Transformer架构
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, cmn, model_type):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.cmn = cmn
        self.model_type=model_type
        
    def forward(self, src, tgt, src_mask, tgt_mask, memory_matrix=None):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask, memory_matrix=memory_matrix)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask, past=None, memory_matrix=None):
        embeddings = self.tgt_embed(tgt)
        print(f"这个是transformer的decode函数")
        embed()
        #@ymy CLS
        if self.model_type=="CMN":
            # Memory querying and responding for textual features
            dummy_memory_matrix = memory_matrix.unsqueeze(0).expand(embeddings.size(0), memory_matrix.size(0), memory_matrix.size(1))
            responses = self.cmn(embeddings, dummy_memory_matrix, dummy_memory_matrix)
            embeddings = embeddings + responses
            # Memory querying and responding for textual features
        #@ymy SEP
        
        return self.decoder(embeddings, memory, src_mask, tgt_mask, past=past)
1.2 位置编码,和前馈网络 (Position-wise FFN), 注意力, 多头注意力
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        print(f"这里是PositionalEncoding的forward")
        embed()
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
        
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        print(f"这里是 PositionwiseFeedForward 的forward")
        embed()
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
    
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, layer_past=None):
        print(f"这里是多头注意力的forward")
        embed()
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1:
            query = self.linears[0](query)
            key, value = layer_past[0], layer_past[1]
            present = torch.stack([key, value])
        else:
            query, key, value = \
                [l(x) for l, x in zip(self.linears, (query, key, value))]

        if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)
            present = torch.stack([key, value])

        query, key, value = \
            [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for x in [query, key, value]]

        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        if layer_past is not None:
            return self.linears[-1](x), present
        else:
            return self.linears[-1](x)
1.3 模型框架,整体梳理
###########   Encoder:   #####################
ModuleList(
  (0-2): 3 x EncoderLayer(
    (self_attn): MultiHeadedAttention(
      (linears): ModuleList(
        (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (feed_forward): PositionwiseFeedForward(
      (w_1): Linear(in_features=512, out_features=512, bias=True)
      (w_2): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (sublayer): ModuleList(
      (0-1): 2 x SublayerConnection(
        (norm): LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

###########   Decoder:   #####################
ModuleList(
  (0-2): 3 x DecoderLayer(
    (self_attn): MultiHeadedAttention(
      (linears): ModuleList(
        (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (src_attn): MultiHeadedAttention(
      (linears): ModuleList(
        (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (feed_forward): PositionwiseFeedForward(
      (w_1): Linear(in_features=512, out_features=512, bias=True)
      (w_2): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (sublayer): ModuleList(
      (0-2): 3 x SublayerConnection(
        (norm): LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
1.4 模型的代码,这里我把类的名字改成了EncoderDecoder,实际上它原来是BaseCMN
class EncoderDecoder(AttModel):
    def __init__(self, args, tokenizer, model_type="base", core_type='fast'):
        super(EncoderDecoder, self).__init__(args, tokenizer)
        self.args = args
        self.num_layers = args.num_layers
        self.d_model = args.d_model
        self.d_ff = args.d_ff
        self.num_heads = args.num_heads
        self.dropout = args.dropout
        self.topk = args.topk

        tgt_vocab = self.vocab_size + 1

        self.cmn = MultiThreadMemory(args.num_heads, args.d_model, topk=args.topk)
        
        self.model_type = model_type
        self.core_type = core_type
        
        self.model = self.make_model(tgt_vocab, self.cmn)
        self.logit = nn.Linear(args.d_model, tgt_vocab)

        self.memory_matrix = nn.Parameter(torch.FloatTensor(args.cmm_size, args.cmm_dim))
        nn.init.normal_(self.memory_matrix, 0, 1 / args.cmm_dim)
        
    def make_model(self, tgt_vocab, cmn):
        c = copy.deepcopy
        attn = MultiHeadedAttention(self.num_heads, self.d_model)
        ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)
        position = PositionalEncoding(self.d_model, self.dropout)
        model = Transformer(
            Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers),
            Decoder(DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), self.num_layers),
            nn.Sequential(c(position)),
            nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), 
            cmn,
            self.model_type
        )
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model
        
    def init_hidden(self, bsz):
        return []

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
        memory = self.model.encode(att_feats, att_masks)

        return fc_feats[..., :1], att_feats[..., :1], memory, att_masks

    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
        #att_feats, att_masks = self.clip_att(att_feats, att_masks)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)

        #@ymy CLS
        
        # Memory querying and responding for visual features
        if self.model_type=="CMN":
            print(f"这里是prepare feature forward")
            embed()
            dummy_memory_matrix = self.memory_matrix.unsqueeze(0).expand(att_feats.size(0), self.memory_matrix.size(0), self.memory_matrix.size(1))
            responses = self.cmn(att_feats, dummy_memory_matrix, dummy_memory_matrix)
            att_feats = att_feats + responses
        # Memory querying and responding for visual features

        ##@ymy SEP
        
        att_masks = att_masks.unsqueeze(-2)
        if seq is not None:
            
            seq = seq[:, :-1]
            seq_mask = (seq.data > 0)
            seq_mask[:, 0] += True

            seq_mask = seq_mask.unsqueeze(-2)
            seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
        else:
            seq_mask = None
        
        return att_feats, seq, att_masks, seq_mask

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)

        #@ymy CLS
        if self.model_type=="CMN":
            out = self.model(att_feats, seq, att_masks, seq_mask, memory_matrix=self.memory_matrix)
        else:
            out = self.model(att_feats, seq, att_masks, seq_mask)
        #@ymy SEP
        
        outputs = F.log_softmax(self.logit(out), dim=-1)

        return outputs

    def _save_attns(self, start=False):
        if start:
            self.attention_weights = []
        self.attention_weights.append([layer.src_attn.attn.cpu().numpy() for layer in self.model.decoder.layers])

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        if len(state) == 0:
            ys = it.unsqueeze(1)
            past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model),
                    fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)]
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
            past = state[1:]

        #@ymy CLS
        if self.model_type=="CMN":
            out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past,
                                      memory_matrix=self.memory_matrix)
        else:
            out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past)
        #@ymy SEP
        
        if not self.training:
            self._save_attns(start=len(state) == 0)
        return out[:, -1], [ys.unsqueeze(0)] + past

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1559449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年noc指导教师认证测评参考试题题目5-6合集

[noc指导教师认证] 测评参考试题 说明:NOC教师指导认证考试题目是从题库里抽题,因此每位老师每次考试题目都不一样以下题目为测试考试时收集到的一些题目,作为辅助提供给各位老师,老师们可以记住题目及答案的具体内容 (选项顺序会变),以免考试时遇到。2024年的做的题目有的…

两张图片相似度匹配算法学习路线

大纲:​​​​​​目标跟踪基础:两张图片相似度算法-腾讯云开发者社区-腾讯云 (tencent.com) 目标跟踪基础:两张图片相似度算法 (qq.com) 一、传统方法 1.欧式距离(用于判断是否完全相同) [三维重建] [机器学习] 图…

RAG进阶笔记:RAG进阶

1 查询/索引部分 1.1 层次索引 创建两个索引——一个由摘要组成,另一个由文档块组成分两步进行搜索:首先通过摘要过滤出相关文档,接着只在这个相关群体内进行搜索 1.2 假设性问题 让LLM为每个块生成一个假设性问题,并将这些问…

一文看懂DDR1到DDR5的演变

DDR内存技术自问世以来,已经经历了多代的迭代和优化。每一代DDR内存都在性能、容量、功耗和功能上有所提升。以下是从DDR1到DDR5的主要区别和特点: DDR1 DDR1是第一代双倍数据速率同步动态随机存取存储器,标志着内存技术的一个重要里程碑。…

网安基础2-Sniffer的使用与防范

1. 嗅探器sniffer的工作原理 能捕获经过该网络设备的报文,通过分析网络流量,找出关键信息,解决网络问题。 不同于键盘捕获程序,如keylogger利用中断或钩子技术,Sniffer将网络接口置成适当的模式,如杂收。…

瑞吉外卖实战学习--9、mybatisPlus公共字段自动填充

mybatisPlus公共字段自动填充 前言实现步骤实体类添加注解按照框架要求编写元数据对象处理器,在此类中统一为公共字段赋值,此类需要实现MetaObjectHandler接口1、在创建和更新的时候修改创建和更新的时候自动填充时间2、如何获取到当前的id 测试结果 前言…

Vision Mamba 双向状态空间模型下的高效视觉表示学习

论文题目:Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model 双向状态空间模型下的高效视觉表示学习 论文链接:http://arxiv.org/abs/2401.09417 代码链接:https://github.com/hustvl/Vim 1、…

Java中的多线程和线程安全问题

线程 线程是操作系统进行调度的最小单位。一个进程至少包含一个主线程,而一个线程可以启动多个子线程。线程之间共享进程的资源,但也有自己的局部变量。多线程程序和普通程序的区别:每个线程都是一个独立的执行流;多个线程之间是…

大模型面试准备(九):简单透彻理解MoE

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何备战、面试常考点分享等热门话题进行了深入的讨论。 合集在这…

C++list的模拟实现

为了实现list&#xff0c;我们需要实现三个类 一、List的节点类 template<class T> struct ListNode {ListNode(const T& val T()):_pPre(nullptr),_pNext(nullptr),_val(val){}ListNode<T>* _pPre;ListNode<T>* _pNext;T _val; }; 二、List的迭代器…

双网卡环境概率出现DNS解析错误

测试环境 VMware Rocky Linux 9 虚拟机, 双网卡(eth0和eth1)配置如下&#xff1a; eth0 10.206.216.27/24 DNS 10.204.16.18 eth1 192.168.1.27/24 DNS 192.168.1.1问题描述 手动配置eth1的DNS后&#xff0c;网络不通&#xff0c;通过抓包发现是eth1的DNS server配置有误…

【JavaWeb】Day29.SpringBootWeb请求响应——请求(二)

请求响应 4.数组集合参数 数组集合参数的使用场景&#xff1a;在HTML的表单中&#xff0c;有一个表单项是支持多选的(复选框)&#xff0c;可以提交选择的多个值。 4.1 数组 数组参数&#xff1a;请求参数名与形参数组名称相同且请求参数为多个&#xff0c;定义数组类型形参即…

springboot简历系统

摘 要 随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;简历系统当然也不能排除在外。简历系统是以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;采用…

速通汇编(三)寄存器及汇编mul、div指令

一&#xff0c;寄存器及标志 AH&ALAX(accumulator)&#xff1a;累加寄存器BH&BLBX(base)&#xff1a;基址寄存器CH&CLCX(count)&#xff1a;计数寄存器DH&DLDX(data)&#xff1a;数据寄存器SP(Stack Pointer)&#xff1a;堆栈指针寄存器BP(Base Pointer)&#…

Vue3+Vite Nginx部署 跨域

打包项目 webstorm打开项目之后&#xff0c;在Terminal执行打包命令 pnpm run build:prod 复制到Nginx 打包完成之后,生成的包在根目录dist&#xff0c;把dist目录拷贝到Nginx放网站目录下&#xff1a;\nginx-1.25.2\html\divided &#xff0c;dist改名了divided 修改配置…

力扣---网络延迟时间---迪杰斯特拉,弗洛伊德floyd

首先推荐博客&#xff1a;图论最短路径专题&#xff08;力扣743、5888&#xff09;_力扣 最短路径-CSDN博客 迪杰斯特拉算法&#xff1a; 太久没有做图论的题了&#xff0c;&#xff0c;临时抱佛脚。。 这道题可以转化为max{点x到点k的距离}。因为带权图&#xff08;权值为正…

[超详细]3种方法判断一个数是否为质数(Python)

(发现好多博客对第三种进阶方法说的不明白&#xff0c;至少我是没完全看明白。后面结合自己的理解应该算是弄懂了&#xff0c;供大家参考&#xff0c;欢迎纠正。) 方法一&#xff1a;最暴力&#xff0c;最简单&#xff0c;也最耗时O(n) 思想&#xff1a;由素数的定义&#xf…

arcgis 无法编辑元素的解决办法(无法删除元素或者缺失值替换)

打开“编辑器”中&#xff0c;“开始编辑”即可进行元素编辑&#xff0c;也可进行缺失值替换 &#xff08;其他方式&#xff1a;选中图层&#xff0c;右击点击开始编辑&#xff09; 在元素编辑状态下无法删除变量&#xff0c;可以删除元素 元素编辑结束后 点击“编辑器”&…

深入剖析Spring WebFlux:从MethodHandler到反射获取请求信息的源码之旅

文章目录 前言一、获取请求执行的类、方法信息二、获取请求url变量三、获取请求处理数据总结 前言 最近想写一个代办事项后台服务&#xff0c;底层&#xff0c;选型WebFlux。在操作层面上&#xff0c;针对部分操作&#xff0c;想在不侵入业务代码的前提下&#xff0c;记录操作…

使用 Seq2Seq 模型进行文本摘要

目录 引言 1 导入数据集 2 清洗数据集 3 确定允许的最大序列长度 4 选择合理的文本和摘要 5 对文本进行标记 6 删除空文本和摘要 7 构建模型 7.1 编码器 7.2 解码器 8 训练模型 9 测试模型 10 注意 11 整体代码 引言 文本摘要是指在捕捉其本质的同时缩短长文本的…