Seq2Seq - Dataset 类

news2025/4/18 11:04:27

本节代码定义了一个 CMN 类,它继承自 PyTorch 的 Dataset 类,用于处理英文和中文的平行语料库。这个类的主要作用是将文本数据转换为模型可以处理的格式,并进行必要的填充操作,以确保所有序列的长度一致。

⭐重写Dataset类是模型训练的重中之重请务必掌握!

重写时格式固定为三件套 __init__  __len__ __getitem__重点记忆! 

1. 类定义

class CMN(Dataset):
    def __init__(self, en_corpus, cn_corpus, en_tokenizer: Tokenizer, cn_tokenizer: Tokenizer, seq_len):
        self.en_corpus = en_corpus
        self.cn_corpus = cn_corpus

        self.en_tokenizer = en_tokenizer
        self.cn_tokenizer = cn_tokenizer

        self.seq_len = seq_len

        self.pad_id = self.en_tokenizer.vocab["[PAD]"]
        self.bos_id = self.en_tokenizer.vocab["[BOS]"]
        self.eos_id = self.en_tokenizer.vocab["[EOS]"]
参数
  • en_corpus:英文语料库,是一个字符串列表。

  • cn_corpus:中文语料库,是一个字符串列表。

  • en_tokenizer:英文分词器,用于将英文文本转换为索引。

  • cn_tokenizer:中文分词器,用于将中文文本转换为索引。

  • seq_len:序列的最大长度,用于填充或截断序列。

属性
  • self.pad_id:填充标记 [PAD] 的索引。

  • self.bos_id:序列开始标记 [BOS] 的索引。

  • self.eos_id:序列结束标记 [EOS] 的索引。

2. 数据集长度(__len__

def __len__(self):
    return len(self.en_corpus)
  • 功能:返回数据集的长度,即语料库中句子的数量。

  • 返回值:数据集的长度。

3. 获取数据项(__getitem__

def __getitem__(self, idx):
    en_text = self.en_corpus[idx]
    cn_text = self.cn_corpus[idx]

    en_ids = self.en_tokenizer.encode(en_text)
    cn_ids = self.cn_tokenizer.encode(cn_text)

    encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids)  # batch * seq_len
    decoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)

    labels = self.pad_to_seq_len(cn_ids + [self.eos_id])

    return {
        "encoder_input": encoder_input,
        "decoder_input": decoder_input,
        "labels": labels,
        "en_text": en_text,
        "cn_text": cn_text
    }

CMN 类的 __getitem__ 方法中,代码的主要目的是将英文和中文文本转换为模型可以处理的格式,并进行必要的填充操作,以确保所有序列的长度一致。以下是对 __getitem__ 方法中各个部分的详细解释:

1. 获取文本
en_text = self.en_corpus[idx]
cn_text = self.cn_corpus[idx]
  • 功能:从语料库中获取索引为 idx 的英文句子 en_text 和中文句子 cn_text

  • 目的:为每个索引提供一对对应的英文和中文句子,用于后续的编码和解码。

2. 文本编码
en_ids = self.en_tokenizer.encode(en_text)
cn_ids = self.cn_tokenizer.encode(cn_text)
  • 功能:将英文和中文句子分别通过对应的分词器编码为索引列表。

  • 目的:将文本转换为模型可以处理的数值形式。分词器将每个字符(或单词)映射为词汇表中的索引。

3. 构建输入序列
encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids)  # batch * seq_len
decoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)
  • 功能:构建编码器和解码器的输入序列。

  • 目的

    • 编码器输入:在英文索引列表的开头添加 [BOS] 标记,表示序列的开始。然后对序列进行填充或截断,使其长度为 seq_len

    • 解码器输入:在中文索引列表的开头添加 [BOS] 标记,表示序列的开始。同样进行填充或截断,使其长度为 seq_len

  • 为什么这样写

    • [BOS] 标记:在序列的开头添加 [BOS] 标记,是为了让模型知道序列的开始位置。这对于模型理解序列的起始点非常重要,尤其是在解码阶段。

    • 填充或截断:为了确保所有序列的长度一致,需要对序列进行填充或截断。填充是通过添加 [PAD] 标记来实现的,截断则是直接截取序列的前 seq_len 个元素。

4. 构建目标序列(标签)
labels = self.pad_to_seq_len(cn_ids + [self.eos_id])
  • 功能:构建解码器的目标序列(标签)。

  • 目的:为目标序列添加 [EOS] 标记,表示序列的结束。然后进行填充或截断,使其长度为 seq_len

  • 为什么这样写

    • [EOS] 标记:在目标序列的末尾添加 [EOS] 标记,是为了让模型知道序列的结束位置。这对于模型在解码阶段生成完整的序列非常重要。

    • 填充或截断:同样是为了确保所有序列的长度一致,需要对目标序列进行填充或截断。

5. 返回结果
return {
    "encoder_input": encoder_input,
    "decoder_input": decoder_input,
    "labels": labels,
    "en_text": en_text,
    "cn_text": cn_text
}
  • 功能:返回一个字典,包含以下内容:

    • "encoder_input":编码器的输入序列。

    • "decoder_input":解码器的输入序列。

    • "labels":解码器的目标序列。

    • "en_text":原始英文句子。

    • "cn_text":原始中文句子。

  • 目的:提供模型训练所需的所有输入和目标数据,同时保留原始文本以便后续验证和调试。

6. 填充序列(pad_to_seq_len
def pad_to_seq_len(self, x):
    pad_num = self.seq_len - len(x)
    return torch.tensor(x + [self.pad_id] * pad_num)
  • 功能:将一个索引列表填充或截断到指定的序列长度 seq_len

  • 目的:确保所有序列的长度一致,以便模型可以批量处理。

  • 为什么这样写

    • 填充:如果序列长度小于 seq_len,则在末尾添加 [PAD] 标记,直到长度达到 seq_len

    • 截断:如果序列长度大于 seq_len,则直接截取前 seq_len 个元素。

    • 转换为张量:将填充或截断后的列表转换为 PyTorch 张量,以便模型可以直接使用。

4. 填充序列(pad_to_seq_len

def pad_to_seq_len(self, x):
    pad_num = self.seq_len - len(x)
    return torch.tensor(x + [self.pad_id] * pad_num)
功能
  • 将一个索引列表填充或截断到指定的序列长度 seq_len

过程
  1. 计算填充数量

    • pad_num 是目标长度 seq_len 与当前列表长度的差值。

    • 如果 pad_num 为正数,则需要填充;如果为负数,则需要截断。

  2. 填充或截断

    • 如果 pad_num 为正数,将 [self.pad_id] 重复 pad_num 次,添加到列表的末尾。

    • 如果 pad_num 为负数,直接截断列表的末尾部分。

  3. 返回结果

    • 将填充或截断后的列表转换为 PyTorch 张量并返回。

示例

假设 seq_len=10x=[2, 3, 4],调用 pad_to_seq_len(x) 的结果:

pad_num = 10 - 3 = 7
result = [2, 3, 4] + [0, 0, 0, 0, 0, 0, 0]  # 假设 pad_id=0
torch.tensor([2, 3, 4, 0, 0, 0, 0, 0, 0, 0])

5.  CMN 类实现了以下功能:

  1. 数据读取:从语料库中读取英文和中文句子。

  2. 文本编码:将文本转换为索引列表。

  3. 序列填充:将索引列表填充或截断到指定长度。

  4. 构建输入和标签:为编码器和解码器构建输入序列和目标序列。

这些步骤是构建 Seq2Seq 模型中数据预处理的重要环节,确保了数据可以被模型有效处理。

需复现完整代码如下:

class CMN(Dataset):
    def __init__(self, en_corpus, cn_corpus, en_tokenizer: Tokenizer, cn_tokenizer: Tokenizer, seq_len):
        self.en_corpus = en_corpus
        self.cn_corpus = cn_corpus

        self.en_tokenizer = en_tokenizer
        self.cn_tokenizer = cn_tokenizer

        self.seq_len = seq_len

        self.pad_id = self.en_tokenizer.vocab["[PAD]"]
        self.bos_id = self.en_tokenizer.vocab["[BOS]"]
        self.eos_id = self.en_tokenizer.vocab["[EOS]"]


    def __len__(self):
        return len(self.en_corpus)
    
    def __getitem__(self, idx):
        en_text = self.en_corpus[idx]
        cn_text = self.cn_corpus[idx]

        en_ids = self.en_tokenizer.encode(en_text)
        cn_ids = self.cn_tokenizer.encode(cn_text)

        encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids) #batch * seq_len
        decoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)

        labels = self.pad_to_seq_len(cn_ids + [self.eos_id])

        
        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "labels": labels,
            "en_text": en_text,
            "cn_text": cn_text
        }
    
    def pad_to_seq_len(self, x):
        pad_num = self.seq_len - len(x)
        return torch.tensor(x + [self.pad_id] * pad_num)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2331588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

echarts图表相关

echarts图表相关 echarts官网折线图实际开发场景一: echarts官网 echarts官网 折线图 实际开发场景一: 只有一条折线,一半实线,一半虚线。 option {tooltip: {trigger: "axis",formatter: (params: any) > {const …

idea自动部署jar包到服务器Alibaba Cloud Toolkit

安装插件:Alibaba Cloud Toolkit 配置服务器: 服务器配置: 项目启动Shell脚本命令: projectpd-otb.jar echo 根据项目名称查询对应的pid pid$(pgrep -f $project); echo $pid echo 杀掉对应的进程,如果pid不存在,则不执行 if [ …

Element Plus 图标使用方式整理

Element Plus 图标使用方式整理 以下是 Element Plus 图标的所有使用方式&#xff0c;包含完整代码示例和总结表格&#xff1a; 1. 按需引入图标组件 适用场景&#xff1a;仅需少量图标时&#xff0c;按需导入减少打包体积 示例代码&#xff1a; <template><div>…

链路聚合+vrrp

1.链路聚合 作用注意事项将多个物理接口&#xff08;线路&#xff09;逻辑上绑定在一起形成一条逻辑链路&#xff0c;起到叠加带宽的作用1.聚合接口必须转发速率一致。2.聚合设备两端必须一致 配置命令 方法一 [Huawei]interface Eth-Trunk 0----先创建聚合接口&#xff0c;…

Dynamics 365 Business Central Register Customer Payment 客户付款登记

#Dynamics 365 BC ERP# #D365 ERP# #Navision 前言 在实施过程&#xff0c;经常给客户介绍的 给客户付款一般用Payment Journal. 在客户熟悉系统运行后&#xff0c;往往会推荐客户使用Register Customer Payment.用这个function 工作会快很多&#xff0c;但出错的机会也比较大…

Odoo免费开源ERP:企业销售过程中出现的问题

在企业未上线Odoo免费开源ERP时&#xff0c;企业销售过程中会存在失误。比如&#xff0c;许多销售订单都有如下问题&#xff1a;不当的定价、向客户过多地询问、处理订单延误、错过发货日期等。这些问题源于企业三个未集成的信息系统&#xff1a;销售管理系统、库存系统和财务系…

网络稳定性--LCA+最大生成树+bfs1/dfs1找最小边

1.最大生成树去除重边&#xff0c;只要最大的边成树 2.LCA查最近公共祖先&#xff0c;然后询问的lca(x,y)ff,分别从x,y向上找最小边 3.bfs1/dfs1就是2.中向上找的具体实现 #include<bits/stdc.h> using namespace std; #define N 100011 typedef long long ll; typede…

混合并行技术在医疗AI领域的应用分析(代码版)

混合并行技术(专家并行/张量并行/数据并行)通过多维度的计算资源分配策略,显著提升了医疗AI大模型的训练效率与推理性能。以下结合技术原理与医疗场景实践,从策略分解、技术对比、编排优化及典型案例等维度展开分析: 一、混合并行技术:突破单卡算力限制 1. 并行策略三维分…

【C++面向对象】封装(上):探寻构造函数的幽微之境

每文一诗 &#x1f4aa;&#x1f3fc; 我本将心向明月&#xff0c;奈何明月照沟渠 —— 元/高明《琵琶记》 译文&#xff1a;我本是以真诚的心来对待你&#xff0c;就像明月一样纯洁无瑕&#xff1b;然而&#xff0c;你却像沟渠里的污水一样&#xff0c;对这份心意无动于衷&a…

每日算法-250409

这是我今天的算法学习记录。 2187. 完成旅途的最少时间 题目描述 思路 二分查找 解题过程 为什么可以使用二分查找&#xff1f; 问题的关键在于寻找一个最小的时间 t&#xff0c;使得在时间 t 内所有公交车完成的总旅途次数 sum 大于等于 totalTrips。 我们可以观察到时间的单…

如何实现文本回复Ai ChatGPT DeepSeek 式文字渐显效果?前端技术详解(附完整代码)

个人开发的塔罗牌占卜小程序&#xff1a;【问问塔罗牌】 快来瞧瞧吧&#xff01; 一、核心实现原理 我们通过三步实现这个效果&#xff1a; 逐字渲染&#xff1a;通过 JavaScript 定时添加字符 透明度动画&#xff1a;CSS 实现淡入效果 光标动画&#xff1a;伪元素 CSS 动画…

linux下截图工具的选择

方案一 gnome插件Screenshot Tool&#xff08;截屏&#xff09; ksnip&#xff08;图片标注&#xff09; gnome setting设置图片的默认打开方式为ksnip就可以快捷的将Screenshot Tool截屏的图片打开进行标记了。 但是最近我发现Screenshot Tool的延迟截图功能是有问题的&…

rkmpp 解码 精简mpi_dec_test.c例程

rkmpp 解码流程&#xff08;除 MPP_VIDEO_CodingMJPEG 之外&#xff09; 源码 输入h264码流 输出nv12文件 /** Copyright 2015 Rockchip Electronics Co. LTD** Licensed under the Apache License, Version 2.0 (the "License");* you may not use this file exce…

怎么构造思维链数据?思维链提示工程的五大原则

我来为您翻译这篇关于思维链提示工程的文章&#xff0c;采用通俗易懂的中文表达&#xff1a; 思维链(CoT)提示工程是生成式AI(GenAI)中一种强大的方法&#xff0c;它能让模型通过逐步推理来解决复杂任务。通过构建引导模型思考过程的提示&#xff0c;思维链能提高输出的准确性…

网络安全之-信息收集

域名收集 域名注册信息 站长之家 https://whois.chinaz.com/ whois 查询的相关网站有:中国万网域名WHOIS信息查询地址: https://whois.aliyun.com/西部数码域名WHOIS信息查询地址: https://whois.west.cn/新网域名WHOIS信息查询地址: http://whois.xinnet.com/domain/whois/in…

JdbcTemplate基本使用

JdbcTemplate概述 它是spring框架中提供的一个对象&#xff0c;是对原始繁琐的JdbcAPI对象的简单封装。spring框架为我们提供了很多的操作模板类。例如:操作关系型数据的JdbcTemplate和MbernateTemplate&#xff0c;操作nosql数据库的RedisTemplate&#xff0c;操作消息队列的…

openEuler24.03 LTS下安装Spark

目录 安装模式介绍 下载Spark 安装Local模式 前提条件 解压安装包 简单使用 安装Standalone模式 前提条件 集群规划 解压安装包 配置Spark 配置Spark-env.sh 配置workers 分发到其他机器 启动集群 简单使用 关闭集群 安装YARN模式 前提条件 解压安装包 配…

使用 DeepSeek API 实现新闻文章地理位置检测与地图可视化

使用 DeepSeek API 实现新闻文章地理位置检测与地图可视化 | Implementing News Article Location Detection and Map Visualization with DeepSeek API 作者&#xff1a;zhutoutoutousan | Author: zhutoutoutousan 发布时间&#xff1a;2025-04-08 | Published: 2025-04-08 标…

如何精准控制大模型的推理深度

论文标题 ThinkEdit: Interpretable Weight Editing to Mitigate Overly Short Thinking in Reasoning Models 论文地址 https://arxiv.org/pdf/2503.22048 代码地址 https://github.com/Trustworthy-ML-Lab/ThinkEdit 作者背景 加州大学圣迭戈分校 动机 链式推理能显…

【力扣hot100题】(078)跳跃游戏Ⅱ

好难啊&#xff0c;我愿称之为跳崖游戏。 依旧用了两种方法&#xff0c;一种是我一开始想到的&#xff0c;一种是看答案学会的。 我自己用的方法是动态规划&#xff0c;维护一个数组记录到该位置的最少步长&#xff0c;每遍历到一个位置就嵌套循环遍历这个位置能到达的位置&a…