TensorFlow2实战-系列教程11:RNN文本分类3

news2025/1/2 0:02:40

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

6、构建训练数据

  • 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
  • tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
  • tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):
    with open(f_path,encoding='utf-8') as f:
        print('Reading', f_path)
        for line in f:
            line = line.rstrip()
            label, text = line.split('\t')
            text = text.split(' ')
            x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的ID
            if len(x) >= params['max_len']:#截断操作
                x = x[:params['max_len']]
            else:
                x += [0] * (params['max_len'] - len(x))#补齐操作
            y = int(label)
            yield x, y
  1. 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
  2. 打开文件
  3. 打印文件路径
  4. 遍历每行数据
  5. 获取标签和文本
  6. 文本按照空格分离出单词
  7. 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
  8. 如果当前句子大于预设的最大句子长度
  9. 进行截断操作
  10. 如果小于
  11. 补充0
  12. 标签从str转换为int类型
  13. yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中

也就是说yield 会从上一次取得地方再接着去取数据,而return却不会

def dataset(is_training, params):
    _shapes = ([params['max_len']], ())
    _types = (tf.int32, tf.int32)
  
    if is_training:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['train_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.shuffle(params['num_samples'])
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['test_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  
    return ds
  1. 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
  2. 当前shape值
  3. 1
  4. 是否在训练,如果是:
  5. 构建一个Dataset
  6. 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
  7. 输出的shape值
  8. 输出的类型
  9. 指定shuffle
  10. 指定 batch_size
  11. 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
  12. 如果不是在训练,则:
  13. 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
  14. 最后把做好的Datasets返回回去

7、自定义网络模型

在这里插入图片描述
一条文本变成一组向量/矩阵的基本流程:

  1. 拿到一个英文句子
  2. 通过查语料表将句子变成一组索引
  3. 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了

在这里插入图片描述
在这里插入图片描述
BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。

class Model(tf.keras.Model):
    def __init__(self, params):
        super().__init__()
        self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)
        self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])

        self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))
        self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])
        self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)
        self.out_linear = tf.keras.layers.Dense(2)  
    def call(self, inputs, training=False):
        if inputs.dtype != tf.int32:
            inputs = tf.cast(inputs, tf.int32)
        batch_sz = tf.shape(inputs)[0]
        rnn_units = 2*params['rnn_units']
        x = tf.nn.embedding_lookup(self.embedding, inputs)        
        x = self.drop1(x, training=training)
        x = self.rnn1(x)
        x = self.drop2(x, training=training)
        x = self.rnn2(x)
        x = self.drop3(x, training=training)
        x = self.rnn3(x)
        x = self.drop_fc(x, training=training)
        x = self.fc(x)
        x = self.out_linear(x)
        return x
  1. 自定义一个模型,继承tf.keras.Model模块
  2. 初始化函数
  3. 初始化
  4. 词嵌入,把之前保存好的词嵌入文件向量读进来
  5. 定义一层dropout1
  6. 定义一层dropout2
  7. 定义一层dropout3
  8. 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
  9. 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
  10. 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
  11. 定义全连接层的dropout
  12. 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
  13. 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
  14. 定义前向传播函数,传进来一个batch的数据和是否是在训练
  15. 如果输入数据不是tf.int32类型
  16. 转换成tf.int32类型
  17. 取出batch_size
  18. 设置LSTM神经元个数,双向乘以2
  19. 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
  20. 数据通过第1个 Dropout 层
  21. 数据通过第1个rnn
  22. 数据通过第2个 Dropout 层
  23. 数据通过第2个rnn
  24. 数据通过第3个 Dropout 层
  25. 数据通过第3个rnn
  26. 经过全连接层对应的Dropout
  27. 数据通过一个全连接层
  28. 最后,数据通过一个输出层
  29. 返回最终的模型输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1421901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Prometheus基于pod部署

1、kube-api的自动发现:Prometheus的容器化部署(生产中都是pod部署) 2、部署export (1)创建目录 (2)创建命名空间 (3)部署node-export ①9100端口被占用,停…

useEffect的第二个参数

目录 1、第一个参数: 2、第二个参数: 2.1 不传值:无限循环 2.2 空数组作为依赖:执行一次 2.3 基本类型作为依赖:无限循环 2.4 引用类型 2.4.1 数组作为依赖:无限循环 2.4.2 函数作为依赖&#…

添加了gateway之后远程调用失败

前端提示500,后端提示[400 ] during [GET] to [http://userservice/user/1] 原因是这个,因为在请求地址写了两个参数,实际上只传了一个参数 解决方案:加上(required false)并重启所有相关服务

Redis(十)SpringBoot集成Redis

文章目录 连接单机mvnYMLController.javaRedisConfig.java 连接集群YML问题复现 RedisTemplate方式 连接单机 mvn <!--Redis--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</art…

Qt应用开发(安卓篇)——调用ioctl、socket等C函数

一、前言 在 Qt for Android 中没办法像在嵌入式linux中一样直接使用 ioctl 等底层函数&#xff0c;这是因为因为 Android 平台的安全性和权限限制。 在 Android 中&#xff0c;访问设备硬件和系统资源需要特定的权限&#xff0c;并且需要通过 Android 系统提供的 API 来进行。…

Java链表(2)

&#x1f435;本篇文章将对双向链表进行讲解&#xff0c;模拟实现双向链表的常用方法 一、什么是双向链表 双向链表在指针域上相较于单链表&#xff0c;每一个节点多了一个指向前驱节点的引用prev以及多了指向最后一个节点的引用last&#xff1a; 二、双向链表的模拟实现 首先…

深度学习:机器智能的革命性突破与未来挑战

深度学习&#xff1a;机器智能的革命性突破与未来挑战 深度学习是人工智能领域的一个重要分支&#xff0c;它利用神经网络模拟人类大脑的学习过程&#xff0c;通过大量数据训练模型&#xff0c;使其能够自动提取特征、识别模式、进行分类和预测等任务。近年来&#xff0c;深度学…

忍不了,客户让我在一个接口里兼容多种业务逻辑

分享是最有效的学习方式。 博客&#xff1a;https://blog.ktdaddy.com/ 老猫的设计模式专栏已经偷偷发车了。 不甘愿做crud boy&#xff1f;看了好几遍的设计模式还记不住&#xff1f;那就不要刻意记了&#xff0c;跟上老猫的步伐&#xff0c;在一个个有趣的职场故事中领悟设计…

一文掌握SpringBoot注解之@Component 知识文集(8)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

山体滑坡在线安全监测预警系统(解决方案)

在近年来&#xff0c;随着全球气候变化的影响&#xff0c;山体滑坡等自然灾害频发&#xff0c;给人们的生命财产安全带来了严重威胁。为了有效预防和减少山体滑坡带来的危害&#xff0c;许多地方开始在山上安装山体滑坡在线安全监测预警系统&#xff08;解决方案&#xff09;。…

【图论】【状态压缩】【树】【深度优先搜索】1617. 统计子树中城市之间最大距离

作者推荐 【动态规划】【字符串】【行程码】1531. 压缩字符串 本文涉及的知识点 图论 深度优先搜索 状态压缩 树 LeetCode1617. 统计子树中城市之间最大距离 给你 n 个城市&#xff0c;编号为从 1 到 n 。同时给你一个大小为 n-1 的数组 edges &#xff0c;其中 edges[i] …

在 Linux 中挂载新硬盘动态使用

目录 一&#xff1a;添加硬盘并且格式化 二&#xff1a;创建逻辑卷 三&#xff1a;挂载卷到目录 在 Linux 中挂载新硬盘并进行格式化的操作可以按照以下步骤进行&#xff1a; 一&#xff1a;添加硬盘并且格式化 查看现有分区状态和服务器安装的硬盘状态&#xff1a; df -…

开源项目TARZAN-NAV | 基于springboot的现代化导航网站系统

TARZAN-NAV 导航网站 一个基于 Spring Boot、MyBatis-Plus、h2database、ehcache、Docker、websocket等技术栈实现的导航网站系统&#xff0c;采用主流的互联网技术架构、全新的UI设计、支持一键源码部署&#xff0c;拥有完整的仪表板、导航管理&#xff0c;用户管理、评论管理…

操作筛选器的 1 个应用实例:自动启用事务

前言 在数据库操作过程中&#xff0c;有一个概念是绕不开的&#xff0c;那就是事务。 事务能够确保一系列数据库操作要么全部成功提交&#xff0c;要么全部失败回滚&#xff0c;保证数据的一致性和完整性。 在 Asp.Net Core Web API 中&#xff0c;我们可以使用操作筛选器给…

uni-app小程序自定义导航栏

最近在开发一个uni-app小程序&#xff0c;用到了自定义导航栏&#xff0c;在这里记录一下实现过程&#xff1a; page.json 在对应页面路由的style中设置入"navigationStyle": "custom"取消原生导航栏&#xff0c;自定义导航栏 {"path": "…

C++ //练习 3.8 分别用while循环和传统的for循环重写第一题的程序,你觉得哪种形式更好呢?为什么?

C Primer&#xff08;第5版&#xff09; 练习 3.8 练习 3.8 分别用while循环和传统的for循环重写第一题的程序&#xff0c;你觉得哪种形式更好呢&#xff1f;为什么? 环境&#xff1a;Linux Ubuntu&#xff08;云服务器&#xff09; 工具&#xff1a;vim 代码块 /********…

[晓理紫]每日论文分享(有源码或项目地址、中文摘要)--强化学习、模仿学习、机器人

专属领域论文订阅 VX 关注{晓理紫},每日更新论文,如感兴趣,请转发给有需要的同学,谢谢支持 如果你感觉对你有所帮助,请关注我,每日准时为你推送最新论文。 为了答谢各位网友的支持,从今日起免费为300名读者提供订阅主题论文服务,只需VX关注公号并回复{邮箱+论文主题}(…

深度强化学习(王树森)笔记10

深度强化学习&#xff08;DRL&#xff09; 本文是学习笔记&#xff0c;如有侵权&#xff0c;请联系删除。本文在ChatGPT辅助下完成。 参考链接 Deep Reinforcement Learning官方链接&#xff1a;https://github.com/wangshusen/DRL 源代码链接&#xff1a;https://github.c…

KAFKA高可用架构涉及常用功能整理

KAFKA高可用架构涉及常用功能整理 1. kafka的高可用系统架构和相关组件2. kafka的核心参数2.1 常规配置2.2 特殊优化配置 3. kafka常用命令3.1 常用基础命令3.1.1 创建topic3.1.2 获取集群的topic列表3.1.3 获取集群的topic详情3.1.4 删除集群的topic3.1.5 获取集群的消费组列表…

如何使用 Google 搜索引擎保姆级教程(附链接)

一、介绍 "Google语法"通常是指在 Google 搜索引擎中使用一系列特定的搜索语法和操作符来精确地定义搜索查询。这些语法和操作符允许用户过滤和调整搜索结果&#xff0c;提高搜索的准确性。 二、安装 Google 下载 Google 浏览器 Google 官网https://www.google.c…