基于 LSTM 进行多类文本分类(附源码)

news2025/1/24 6:20:55

NLP 的许多创新是如何将上下文添加到词向量中。一种常见的方法是使用循环神经网络。以下是循环神经网络的概念:

  • 他们利用顺序信息。

  • 他们可以捕捉到到目前为止已经计算过的内容,即:我最后说的内容会影响我接下来要说的内容。

  • RNNs 是文本和语音分析的理想选择。

  • 最常用的 RNNs 是 LSTM。

图片

来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/

以上是循环神经网络的架构:

  • “A”是一层前馈神经网络。

  • 如果我们只看右边,它确实会循环通过每个序列的元素。

  • 如果我们打开左边,它看起来就像右边。

图片

来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs

假设我们正在解决新闻文章数据集的文档分类问题:

  • 我们输入每个单词,单词以某种方式相互关联。

  • 当我们看到那篇文章中的所有单词时,我们会在文章末尾做出预测。

  • RNNs 通过传递来自最后输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。

图片

https://colah.github.io/posts/2015-08-Understanding-LSTMs

  • 这适用于短句,当我们处理长文章时,会出现长期依赖问题。

因此,我们一般不使用 vanilla RNN,而是使用 Long Short Term Memory。LSTM 是一种可以解决这种长期依赖问题的 RNN。

图片

在我们的新闻文章文档分类示例中,我们有这种多对一的关系。输入是单词序列,输出是单个类或标签。

技术交流&源码

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

相关资料、数据、技术交流提升,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:dkl88194,备注:来自CSDN + 技术资料
方式②、微信搜索公众号:Python学习与数据挖掘,后台回复:LSTM

现在我们将使用 TensorFlow 2.0 和 Keras,基于 LSTM 解决 BBC 新闻文档分类问题。

数据链接:https://raw.githubusercontent.com/susanli2016/PyCon-Canada-2019-NLP-Tutorial/master/bbc-text.csv

  • 首先,我们导入相关库并确保我们的 TensorFlow 是正确的版本。
import csv
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from nltk.corpus import stopwords
STOPWORDS = set(stopwords.words('english'))

print(tf.__version__)

2.0.0

  • 像这样将超参数放在顶部,以便更容易更改和编辑。

  • 当我们用到时,再解释每个超参数是如何工作的。

vocab_size = 5000
embedding_dim = 64
max_length = 200
trunc_type = 'post'
padding_type = 'post'
oov_tok = '<OOV>'
training_portion = .8

hyperparamenter.py

  • 定义两个包含文章和标签的列表。,与此同时,我们过滤掉了停用词。
articles = []
labels = []

with open("bbc-text.csv", 'r') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    next(reader)
    for row in reader:
        labels.append(row[0])
        article = row[1]
        for word in STOPWORDS:
            token = ' ' + word + ' '
            article = article.replace(token, ' ')
            article = article.replace(' ', ' ')
        articles.append(article)
print(len(labels))
print(len(articles))

articles_labels.py

2225

2225

数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。

train_size = int(len(articles) * training_portion)

train_articles = articles[0: train_size]
train_labels = labels[0: train_size]

validation_articles = articles[train_size:]
validation_labels = labels[train_size:]

print(train_size)
print(len(train_articles))
print(len(train_labels))
print(len(validation_articles))
print(len(validation_labels))

train_valid.py

1780

1780

1780

445

445

Tokenizer 为我们完成了所有繁重的工作。在我们对其进行标记的文章中,它将使用 5,000 个最常用的单词。oov_token 是在遇到看不见的单词时放入一个特殊值。这意味着我们希望用于不在 word_index 中的单词。fit_on_text 将遍历所有文本并创建如下字典:

tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)
tokenizer.fit_on_texts(train_articles)
word_index = tokenizer.word_index
dict(list(word_index.items())[0:10])

tokenize.py

图片

我们可以看到“”是我们语料库中最常见的标记,其次是“said”,然后是“mr”等等。

标记化之后,下一步是将这些标记转换为序列列表。以下是已转为序列的训练数据中的第 11 篇文章。

train_sequences = tokenizer.texts_to_sequences(train_articles)
print(train_sequences[10])

图片

当我们为 NLP 训练神经网络时,我们需要数据长度大小相同,这就是我们使用填充的原因。如果你查一下,我们的 max_length 是 200,所以我们使用 pad_sequences 使我们所有文章的长度都相同,即 200。结果,你会看到第 1 篇文章的长度是 426,它变成了 200,第 2 篇文章 长度为 192,变为 200,依此类推。

train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)
print(len(train_sequences[0]))
print(len(train_padded[0]))

print(len(train_sequences[1]))
print(len(train_padded[1]))

print(len(train_sequences[10]))
print(len(train_padded[10]))

425

200

192

200

186

200

此外,还有padding_type和truncating_type,都有所有帖子,例如,对于第11篇文章,长度为186,我们填充了200个,我们在末端填充了14个零。

print(train_padded[10])

图片

对于第一篇文章,它的长度为426,我们截断为200。

然后,我们为验证集合做同样的事情。

validation_sequences = tokenizer.texts_to_sequences(validation_articles)
validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)

print(len(validation_sequences))
print(validation_padded.shape)

445

(445, 200)

现在,我们将查看标签。因为我们的标签是文本,所以我们将在训练时将其贴上标签,预计标签将为Numpy矩阵。因此,我们将标签列表变成像这样的Numpy矩阵格式:

label_tokenizer = Tokenizer()
label_tokenizer.fit_on_texts(labels)

training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))
validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))
print(training_label_seq[0])
print(training_label_seq[1])
print(training_label_seq[2])
print(training_label_seq.shape)

print(validation_label_seq[0])
print(validation_label_seq[1])
print(validation_label_seq[2])
print(validation_label_seq.shape)

图片

在训练深度神经网络之前,我们可以抽查原始文章和填充后的文章。运行以下代码,我们查看了第11篇文章,我们可以看到有些单词变成“”,因为它们没有进入前5,000名。

reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])

def decode_article(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])
print(decode_article(train_padded[10]))
print('---')
print(train_articles[10])

图片

现在是训练LSTM的时候了。

  • 我们构建一个tf.keras.Sequential模型,并从 embeddings 开始。一个 embeddings 将每个单词存储为一个矢量。当调用时,它将单词索引序列转换为向量序列。训练后,具有相似含义的单词通常具有相似的向量。

  • Bidirectional wrapper 与LSTM层一起使用,这可以通过LSTM层向前和向后传播输入,然后将输出串联。这有助于LSTM学习长期依赖性。

  • 我们使用 relu 代替 tahn。

  • 我们添加一个具有6个单元和 SoftMax 函数。当我们有多个输出时,SoftMax 将输出层转换为概率分布。

model = tf.keras.Sequential([
    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top
    tf.keras.layers.Embedding(vocab_size, embedding_dim),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),
#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    # use ReLU in place of tanh function since they are very good alternatives of each other.
    tf.keras.layers.Dense(embedding_dim, activation='relu'),
    # Add a Dense layer with 6 units and softmax activation.
    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.
    tf.keras.layers.Dense(6, activation='softmax')
])
model.summary()

图片

在我们的模型 summary 中,我们有 embeddings,Bidirectional LSTM,然后是两个 dense 层。双向的输出为128,因为它使我们在LSTM中输入的输出增加了一倍。我们也可以堆叠LSTM层,但我发现结果更糟。

print(set(labels))

图片

我们总共有5个标签,但是由于我们没有使用 one-hot 编码标签,因此我们必须使用 sparse_categorical_crossentropy 作为损失函数,因此似乎也认为0也是一个可能的标签。因此,最后一个密集层需要标签0、1、2、3、4、5的输出,而不是整数,尽管从未使用过0。

如果您希望最后一个密集的层是5,则需要从训练集和验证集的标签中减去1。(此处保留)

我决定训练10个epoch。

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
num_epochs = 10
history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)

图片

def plot_graphs(history, string):
  plt.plot(history.history[string])
  plt.plot(history.history['val_'+string])
  plt.xlabel("Epochs")
  plt.ylabel(string)
  plt.legend([string, 'val_'+string])
  plt.show()
  
plot_graphs(history, "accuracy")
plot_graphs(history, "loss")

图片

我们可能只需要3或4个epoch。(在训练结束时,我们可以看到有点过于拟合)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1078178.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一次性读懂Mendix的库间“数据同步”功能

Data sync&#xff0c;对于那些深谙其道的技术高手而言&#xff0c;意义不言自明。然鹅对整天在村工厂里打螺丝的我来说&#xff0c;却经历了一段难捱的时期。时至今日&#xff0c;我仍然时不时地选择性地遗忘某些概念和技术点。因此&#xff0c;本文章记录我之前一点实操的心得…

【测试】robotframework安装

目录 python安装 pip一系列安装 运行效果 参考文档 python安装 注意管理员权限安装&#xff0c;不然2503的错误 。 pip一系列安装 pip install robotframework pip install wxPython pip install robotframework-ride 运行python ride.py pip install setuptools 解决…

125KHz低频接收唤醒芯片:Si3933(TSSOP16)

Si3933 具有内部时钟产生器&#xff0c;可使用晶体振荡器或者RC振荡器&#xff0c;也可以使用外部时钟。 Si3933 是一款三通道的D功耗ASK接 收机&#xff0c;可用于检测15KHz-150KHz低频载波频率的数字信号&#xff0c;并产生唤醒信号。内部集成的校验器用于检测 16 位或 32 位…

双态IT乌镇大会 | 首批《数据中心业务连续性等级评价准则》试点单位将诞生

2023年10月13日-15日&#xff0c;由ITSS分会、证券基金行业信息技术应用创新联盟指导&#xff0c;ITSS数据中心运营管理组&#xff08;DCMG&#xff09;、双态IT论坛、智能运维国标工作组主办&#xff0c;ITSS媒体组、AI范儿协办的“2023第六届双态IT乌镇用户大会”将于浙江乌镇…

SpringBoot采用Dynamic-Datasource方式实现多JDBC数据源

目录 1. Dynamic-Datasource实现多JDBC数据源配置1.1 特性1.2 Mysql数据准备2.2 通过Dynamic-Datasource实现多JDBC数据源2.2.1 pom.xml依赖 2.2.2 application.properties配置2.2.3 使用DS注解选择DataSource2.2.4 使用Transactional DSTransactional实现事务 2.3 动态数据源…

下一代架构设计:云原生、容器和微前端的综合应用

文章目录 云原生&#xff1a;构建可弹性扩展的应用1. 微服务架构2. 容器化3. 自动化和自动扩展 容器化和云原生的结合1. 一致性和可移植性2. 弹性和可伸缩性3. 快速部署和更新4. 资源利用率 微前端&#xff1a;前端架构的演进1. 微前端应用2. 统一的外壳应用3. 独立部署 云原生…

TikTok在跨境电商中的作用:挖掘潜在客户的最佳途径

​随着全球数字化浪潮的不断发展&#xff0c;跨境电商行业也经历了巨大的变革。传统的市场营销渠道已经不再足够&#xff0c;企业们需要不断探寻新的方法来吸引潜在客户。在这个过程中&#xff0c;社交媒体平台TikTok逐渐崭露头角&#xff0c;成为了吸引潜在客户的一个选择。本…

[PwnThyBytes 2019]Baby_SQL - 代码审计+布尔盲注+SESSION_UPLOAD_PROGRESS利用

[PwnThyBytes 2019]Baby_SQL 1 解题流程1.1 分析1.2 解题 2 思考总结 1 解题流程 1.1 分析 此题参考文章&#xff1a;浅谈 SESSION_UPLOAD_PROGRESS 的利用 访问正常来讲用ctf-wscan是能扫出source.zip文件的&#xff0c;且F12后提示了有source.zip&#xff0c;那我们就下载…

Apache POI使用

1.导入坐标 <!-- poi --><dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>${poi}</version></dependency><dependency><groupId>org.apache.poi</groupId><a…

elasticSearch7.9数据占用磁盘存储空间情况

最近&#xff0c;在VMware Workstation虚拟机上安装了es7.9&#xff0c;单节点的es&#xff0c;不是集群&#xff0c;然后建了一个索引&#xff08;包含3个分片和一个副本&#xff09;&#xff0c;插入了500万条数据&#xff0c;占据磁盘空间17G。如下图&#xff1a; 索引的字…

什么样的人适合下班后做点兼职副业

我们身边不乏一些讨论兼职副业的人&#xff0c;可是很多人都只停留在“想”的层面上&#xff0c;真正有执行力的人早就偷偷做起了副业&#xff0c;能力强的还做得风生水起。 什么样的人适合下班后做点副业呢&#xff1f;我觉得下班后&#xff0c;时间很宽裕&#xff0c;或者经济…

S7-1200PLC与昆仑通态触摸屏通讯

测试环境&#xff1a;Win10、MCGS、博图V16、1214DCDCDC 博途工控人平时在哪里技术交流博途工控人社群 博途工控人平时在哪里技术交流博途工控人社群 将PLC端做如下配置 1-MCGS配置S7-1200驱动 1.1-添加驱动 双击设备窗口 点击设备组态窗口下的设备管理&#xff0c;选择西门…

串级/级联控制知识点整理

串级控制系统是改善控制质量的有效方法之一&#xff0c;在过程控制中得到了广泛的应用。所谓串级控制&#xff0c;就是采用两个控制器串联工作&#xff0c;外环控制器的输出作为内环控制器的设定值&#xff0c;由内环控制器的输出去操纵控制阀&#xff0c;从而对外环被控量具有…

【力扣LCP】速算机器人

&#x1f451;专栏内容&#xff1a;力扣刷题⛪个人主页&#xff1a;子夜的星的主页&#x1f495;座右铭&#xff1a;前路未远&#xff0c;步履不停 目录 一、题目描述二、题目分析1、常规解法2、取巧解法 一、题目描述 题目链接&#xff1a;力扣LCP.14 速算机器人 小扣在秋日…

app如何新增广告位以提升广告变现收益?

app广告位资源是平台变现能力之一&#xff0c;广告位资源包括开屏广告、首页轮播广告、首页弹窗等大家熟知的广告位&#xff0c;流量主为了获得更高的收益&#xff0c;通常会考虑在应用中增加广告位。 增设新的广告位&#xff0c;流量主应该从以下几方面考虑。 1、广告类型 …

overflow真实使用场景-表格最右侧显示空白

问题 先看问题。下方滚动条滚动到右侧之后上下都有空白&#xff0c;但是缩放之后正常。分析之后是overflow的问题。 overflow作用是什么&#xff1f; overflow在内容大于元素框高度或者宽度时候设置&#xff0c;保证内容显示正常。 单独一个内容大于元素框高度或者宽度比较…

手机端下载文件时显示0B问题

文章目录 下载文件时显示文件大小如果是OutputStream输出流&#xff0c;如何设置大小扩展问题pdfjs预览pdf文件时遇到的问题 下载文件时显示文件大小 设置下载文件的大小 File filenew File("D:/test.txt");response.setHeader("Accept-Ranges","byt…

Axios 封装

请注意以下文件夹: utils下的setToken.js 是token封装(封装 Token-CSDN博客),service.js 是axios封装。 Axios封装: 1.安装axios 在项目终端下 输入: npm install axios --save 2.在main.js全局引入axios import axios from axiosVue.prototype.$axios =axios //挂…

python psutil库之——获取网络信息(网络接口信息、网络配置信息、以太网接口、ip信息、ip地址信息)

文章目录 使用Python psutil库获取网络信息安装psutil库获取网络连接信息查看所有网络连接过滤特定状态的连接 获取网络接口信息获取网络IO统计信息实例1实例2 总结 使用Python psutil库获取网络信息 Python的psutil库是一个跨平台库&#xff0c;能够方便地获取系统使用情况和…

C200/10/1/1/1/00 VPM04D300000 VDM01U30AL00

C200/10/1/1/1/00 VPM04D300000 VDM01U30AL00 受其客户对集成、远程和日益自主的运营的关注&#xff0c;横河于2022年6月6日推出了OpreX Asset Health Insights&#xff0c;以使资产数据更加可见、集成和可操作。 Asset Health Insights的原始版本支持Amazon Web Services和…