Coggle 30 Days of ML(23年7月)任务七:训练TextCNN模型

news2024/11/16 15:56:05

Coggle 30 Days of ML(23年7月)任务七:训练TextCNN模型

任务七:使用Word2Vec词向量,搭建TextCNN模型进行训练和预测

  • 说明:在这个任务中,你将使用Word2Vec词向量,搭建TextCNN模型进行文本分类的训练和预测,通过卷积神经网络来进行文本分类。
  • 实践步骤:
    1. 准备Word2Vec词向量模型和相应的训练数据集。
    2. 构建TextCNN模型,包括卷积层、池化层、全连接层等。
    3. 将Word2Vec词向量应用到模型中,作为词特征的输入。
    4. 使用训练数据集对TextCNN模型进行训练。
    5. 使用训练好的TextCNN模型对测试数据集进行预测

导入训练好的Word2Vec模型

由于上一部分我们已经训练好了我们的模型,所以这一部分我们直接导入即可

# 准备Word2Vec词向量模型和训练数据集
word2vec_model = Word2Vec.load("word2vec.model")

在数据分析的时候,我们已经发现,词语的数量是不一的,所以首先我们先对数据进行处理,将文本序列转化为词向量表示,并且填充为长度为200

# 获取Word2Vec词向量的维度
embedding_dim = word2vec_model.vector_size

# 转换训练数据集的文本序列为词向量表示,并进行填充
train_sequences = []
for text in train_data:
    sequence = [word2vec_model.wv[word] for word in text if word in word2vec_model.wv]
    padded_sequence = pad_sequences([sequence], maxlen=max_length, padding='post', truncating='post')[0]
    train_sequences.append(padded_sequence)

# 转换测试数据集的文本序列为词向量表示,并进行填充
test_sequences = []
for text in test_data:
    sequence = [word2vec_model.wv[word] for word in text if word in word2vec_model.wv]
    padded_sequence = pad_sequences([sequence], maxlen=max_length, padding='post', truncating='post')[0]
    test_sequences.append(padded_sequence)

构建TextCNN模型

在这里插入图片描述

接下来我们就开始构建一下TextCNN模型,包括卷积层、池化层、全连接层等,这样我们就初步得到一个非常简单的模型了

# 构建TextCNN模型
model = tf.keras.Sequential()
model.add(layers.Conv1D(128, 5, activation='relu', input_shape=(max_length, embedding_dim)))
model.add(layers.GlobalMaxPooling1D())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(num_classes, activation='softmax'))

训练模型

接下来我们就可以开始训练我们的模型了,这里使用了SGD优化器进行操作,在训练之前,我们还需要把训练数据集的标签转换为one-hot编码

# 设置优化器和学习率
optimizer = optimizers.SGD(learning_rate=0.1)  # 使用SGD优化器,并设置学习率为0.1

# 编译模型
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])


# 转换训练数据集的标签为one-hot编码
train_labels = tf.keras.utils.to_categorical(train_labels)

# 训练模型
model.fit(np.array(train_sequences), train_labels, epochs=5, batch_size=32)
Epoch 1/5
438/438 [==============================] - 5s 6ms/step - loss: 0.4385 - accuracy: 0.8454
Epoch 2/5
438/438 [==============================] - 2s 5ms/step - loss: 0.4309 - accuracy: 0.8454
Epoch 3/5
438/438 [==============================] - 3s 6ms/step - loss: 0.4308 - accuracy: 0.8454
Epoch 4/5
438/438 [==============================] - 2s 6ms/step - loss: 0.4307 - accuracy: 0.8454
Epoch 5/5
438/438 [==============================] - 2s 6ms/step - loss: 0.4309 - accuracy: 0.8454

可能是模型太简单了,所以可以看到,通过训练以后,准确率也没有较大的提升,还可以继续改进

预测与提交

最后使用训练好的TextCNN模型对测试数据集进行预测,得到csv数据以后进行提交

# 预测测试数据集的分类结果
predictions = model.predict(np.array(test_sequences))
predicted_labels = predictions.argmax(axis=1)

# 读取提交样例文件
submit = pd.read_csv('ChatGPT/sample_submit.csv')
submit = submit.sort_values(by='name')

# 将预测结果赋值给提交文件的label列
submit['label'] = predicted_labels

# 保存提交文件
submit.to_csv('ChatGPT/textcnn.csv', index=None)

总结

这个TextCNN模型太过于简单,所以以至于可能没有学习到很多的数据,接下来可以进行调参和设置合理的模型结构,以期得到更好的结果,再接再厉,加油!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738278.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode热题100】打卡第32天:最长连续序列只出现一次的数字单词拆分环形链表

文章目录 【LeetCode热题100】打卡第32天:最长连续序列&只出现一次的数字&单词拆分&环形链表⛅前言 最长连续序列🔒题目🔑题解 只出现一次的数字🔒题目🔑题解 单词拆分🔒题目🔑题解…

webAPI学习笔记5——移动端网页特效和本地存储

一、移动端网页特效 1. 触屏事件 1.1 触屏事件概述 移动端浏览器兼容性较好,我们不需要考虑以前 JS 的兼容性问题,可以放心的使用原生 JS 书写效果,但是移动端也有自己独特的地方。比如触屏事件 touch(也称触摸事件&#xff09…

联想M7605DW怎么连接WiFi网络

联想M7605DW是一款拥有WiFi功能的打印机,可以通过WiFi连接无线网络,实现打印无线传输。 首先,需要确保你的WiFi网络已经正常连接,并且知道WiFI的网络名称和密码,同时确保你的电脑或手机设备与WiFi相连接。 启动联想M76…

数组、指针练习题及解析(含笔试题目讲解)其一

目录 前言 题目列表: 题目解析 一维数组 字符数组 字符串 字符指针 二维数组 笔试题 总结 前言 前几期的博客已经将有关指针、数组的所以知识都已基本讲解完毕,那么接下来我们就做一些练习巩固,这些练习依据历年来一些公司笔试题进行…

java的ThreadLocal变量

Java的ThreadLocal变量是线程的局部变量,只能被本线程访问,不能被其它线程访问,也就是说线程间进行了隔离。每个线程访问该变量的一个独立拷贝,互相不干扰。感觉跟synchronized的作用相反,synchronized是为了保护线程间…

Kafka入门,mysql5.7 Kafka-Eagle部署(二十五)

官网 https://www.kafka-eagle.org/ 下载解压 这里使用的是2.0.8 创建mysql数据库 创建名为ke数据库,新版本会自动创建,不会创建的话,自己手动创建,不然会报查不到相关表信息错误 SET NAMES utf8; SET FOREIGN_KEY_CHECKS 0;-- ------…

从2023中国峰会,看亚马逊云科技的生成式AI战略

“生成式AI的发展就像一场马拉松比赛,当比赛刚刚开始时,如果只跑了三四步就断言某某会赢得这场比赛,显然是不合理的。我们现在还处于非常早期的阶段。” 近日,在2023亚马逊云科技中国峰会上,亚马逊云科技全球产品副总裁…

智慧农业:温室大棚物联网系统,助力实现可视化科学管理

我国传统农业的特点是靠天吃饭,而智慧农业发端于物联网设备和对应的农业信息化管理系统,是利用数字技术、数据分析和人工智能等先进技术手段,对农业生产进行精细化管理和智能化决策的一种新型农业生产模式。它可以通过实时监测、预测和调控土…

java 配置打包Spring Boot项目过程中跳过测试环节

上文 java 打包Spring Boot项目,并运行在windows系统中中 我们演示了打包 Spring Boot项目的并运行在本地的方法 但是 我们这里会看到 每次打包 他这都会有个T E S T S 测试的部分 但是 我们自己开发的程序 要上线 有没有问题我们肯定自己清楚啊 没必要它做测试 而且…

web学习笔记2

文档流 网页是一个多层的结构,设置样式也是一层一层的设置,最终我们看到的最上面的一层。 文档流是网页最底层 我们创建的元素默认情况下,都在文档流中 元素分为两种状态:在文档流中,脱离文档流 元素在文档流中的特点 …

同一段数据分别做傅里叶变化和逆变换的结果及分析

已知有公式 D F T : X [ k ] ∑ n 0 N − 1 x [ n ] e − j 2 π k n N , 0 ≤ k ≤ N − 1 DFT:Χ[k]\sum_{n0}^{N-1}x[n]e^{-\frac{j2\pi kn}{N}},0≤k≤N-1 DFT:X[k]n0∑N−1​x[n]e−Nj2πkn​,0≤k…

超详细 | 模拟退火-粒子群自适应优化算法及其实现(Matlab)

作者在前面的文章中介绍了经典的优化算法——粒子群算法(PSO),各种智能优化算法解决问题的方式和角度各不相同,都有各自的适用域和局限性,对智能优化算法自身做的改进在算法性能方面得到了一定程度的提升,但算法缺点的解决并不彻底…

学生公寓智能电表控电系统的技术要求

学生公寓电表智能控电石家庄光大远通电气有限公司模块采用高精度计量芯片,的计量计费功能。 控制路数:可输出1~4路输出,每个回路都可以设置负载识别,定时断送过载功率等控电参数。 自动断电 :具有自动断电功能,可用电量为0时,应自动切断该分路电源 支持正…

创建Spring CloudDEMO流程

创建普通的maven工程作为父工程 然后设置字符集为UTF-8 再注解生效激活 java编译版本选择8 idea文件忽略(忽略乱七八糟的文件) *.hprof;*.pyc;*.pyo;*.rbc;*.yarb;*~;.DS_Store;.git;.hg;.svn;CVS;__pycache__;_svn;vssver.scc;vssver2.scc;.idea;*.iml…

TencentOS3.1安装PHP+Nginx+redis测试系统

PHP和Nginx应用统一安装在/application下。 Nginx选用了较新的版本1.25.0 官网下载安装包,解包。执行如下命令编译: ./configure --prefix/application/nginx-1.25.0 --usernginx --groupnginx --with-http_ssl_module --with-http_stub_status_modu…

win系统电脑在线打开sketch文件的方法

自Sketch诞生以来,只有Mac版本。Windows计算机如何在线打开Sketch文件? 即时设计已经解决了你遇到的大部分问题,不占用内存也是免费的。 您可以使用此软件直接在线打开Sketch文件,完整预览并导出CSS、SVG、PNG等,还具…

解析JSON格式数据

解析JSON格式数据 比起XML,JSON的体积更小,语义性更差 传入的JSON文件如下 使用JSONObject private fun parseJSONWithJSONObject(jsonData: String) { try { val jsonArray JSONArray(jsonData) for (i in 0 until jsonArray.length()){ val j…

视频去除水印怎么弄?这几个实用方法分享给大家!

在我们观看或分享视频时,可能会遇到一些带有水印的视频。这些水印可能会影响我们的观看体验,或者在我们需要使用这些视频时造成不便。下面,我将为你介绍三种去除视频水印的方法。 方法一:使用记灵在线工具 记灵在线工具是一个非…

Leetcode:684. 冗余连接(并查集C++)

目录 684. 冗余连接 题目描述: 实现代码与解析: 并查集 原理思路: 684. 冗余连接 题目描述: 树可以看成是一个连通且 无环 的 无向 图。 给定往一棵 n 个节点 (节点值 1~n) 的树中添加一条边后的图。添加的边的…

Python安装解释器

文章目录 一、下载Python解释器二. Linux环境的安装三. pycharm创建项目四、验证安装是否成功 一、下载Python解释器 首先,您需要从官方Python网站(https://python.org)下载Python解释器。Python的当前稳定版本是3.9.x系列。网站上提供了针对…