基于 LSTM 进行多类文本分类( TensorFlow 2.0)

news2024/11/30 12:28:35

NLP 的许多创新是如何将上下文添加到词向量中。一种常见的方法是使用循环神经网络。以下是循环神经网络的概念:

  • 他们利用顺序信息。

  • 他们可以捕捉到到目前为止已经计算过的内容,即:我最后说的内容会影响我接下来要说的内容。

  • RNNs 是文本和语音分析的理想选择。

  • 最常用的 RNNs 是 LSTM。

c057cfc6109aab084100f1ef0a824c75.png

来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/

以上是循环神经网络的架构:

  • “A”是一层前馈神经网络。

  • 如果我们只看右边,它确实会循环通过每个序列的元素。

  • 如果我们打开左边,它看起来就像右边。

08115c9078c45b929c5d45d700b33edf.png

来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs

假设我们正在解决新闻文章数据集的文档分类问题:

  • 我们输入每个单词,单词以某种方式相互关联。

  • 当我们看到那篇文章中的所有单词时,我们会在文章末尾做出预测。

  • RNNs 通过传递来自最后输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。

ebd18260e88a16df07a2ae0cbb49f25c.png

https://colah.github.io/posts/2015-08-Understanding-LSTMs

  • 这适用于短句,当我们处理长文章时,会出现长期依赖问题。

因此,我们一般不使用 vanilla RNN,而是使用 Long Short Term Memory。LSTM 是一种可以解决这种长期依赖问题的 RNN。

bc2da8410727e2aadc604ae32104b349.png

在我们的新闻文章文档分类示例中,我们有这种多对一的关系。输入是单词序列,输出是单个类或标签。

现在我们将使用 TensorFlow 2.0 和 Keras,基于 LSTM 解决 BBC 新闻文档分类问题。

数据链接:https://raw.githubusercontent.com/susanli2016/PyCon-Canada-2019-NLP-Tutorial/master/bbc-text.csv

  • 首先,我们导入相关库并确保我们的 TensorFlow 是正确的版本。

import csv
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from nltk.corpus import stopwords
STOPWORDS = set(stopwords.words('english'))


print(tf.__version__)

2.0.0

  • 像这样将超参数放在顶部,以便更容易更改和编辑。

  • 当我们用到时,再解释每个超参数是如何工作的。

vocab_size = 5000
embedding_dim = 64
max_length = 200
trunc_type = 'post'
padding_type = 'post'
oov_tok = '<OOV>'
training_portion = .8

hyperparamenter.py

  • 定义两个包含文章和标签的列表。,与此同时,我们过滤掉了停用词。

articles = []
labels = []


with open("bbc-text.csv", 'r') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    next(reader)
    for row in reader:
        labels.append(row[0])
        article = row[1]
        for word in STOPWORDS:
            token = ' ' + word + ' '
            article = article.replace(token, ' ')
            article = article.replace(' ', ' ')
        articles.append(article)
print(len(labels))
print(len(articles))

articles_labels.py

2225

2225

数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。

train_size = int(len(articles) * training_portion)


train_articles = articles[0: train_size]
train_labels = labels[0: train_size]


validation_articles = articles[train_size:]
validation_labels = labels[train_size:]


print(train_size)
print(len(train_articles))
print(len(train_labels))
print(len(validation_articles))
print(len(validation_labels))

train_valid.py

1780

1780

1780

445

445

Tokenizer 为我们完成了所有繁重的工作。在我们对其进行标记的文章中,它将使用 5,000 个最常用的单词。oov_token 是在遇到看不见的单词时放入一个特殊值。这意味着我们希望用于不在 word_index 中的单词。fit_on_text 将遍历所有文本并创建如下字典:

tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)
tokenizer.fit_on_texts(train_articles)
word_index = tokenizer.word_index
dict(list(word_index.items())[0:10])

tokenize.py

3e2764179d5b54da4c944c3cfa3cf352.png

我们可以看到“”是我们语料库中最常见的标记,其次是“said”,然后是“mr”等等。

标记化之后,下一步是将这些标记转换为序列列表。以下是已转为序列的训练数据中的第 11 篇文章。

train_sequences = tokenizer.texts_to_sequences(train_articles)
print(train_sequences[10])

07bb884420705b707860196ee1ce420b.png

当我们为 NLP 训练神经网络时,我们需要数据长度大小相同,这就是我们使用填充的原因。如果你查一下,我们的 max_length 是 200,所以我们使用 pad_sequences 使我们所有文章的长度都相同,即 200。结果,你会看到第 1 篇文章的长度是 426,它变成了 200,第 2 篇文章 长度为 192,变为 200,依此类推。

train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)
print(len(train_sequences[0]))
print(len(train_padded[0]))


print(len(train_sequences[1]))
print(len(train_padded[1]))


print(len(train_sequences[10]))
print(len(train_padded[10]))

425

200

192

200

186

200

此外,还有padding_type和truncating_type,都有所有帖子,例如,对于第11篇文章,长度为186,我们填充了200个,我们在末端填充了14个零。

print(train_padded[10])

6b8c0d9ba9366be5892bdba001f0d849.png

对于第一篇文章,它的长度为426,我们截断为200。

然后,我们为验证集合做同样的事情。

validation_sequences = tokenizer.texts_to_sequences(validation_articles)
validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)


print(len(validation_sequences))
print(validation_padded.shape)

445

(445, 200)

现在,我们将查看标签。因为我们的标签是文本,所以我们将在训练时将其贴上标签,预计标签将为Numpy矩阵。因此,我们将标签列表变成像这样的Numpy矩阵格式:

label_tokenizer = Tokenizer()
label_tokenizer.fit_on_texts(labels)


training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))
validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))
print(training_label_seq[0])
print(training_label_seq[1])
print(training_label_seq[2])
print(training_label_seq.shape)


print(validation_label_seq[0])
print(validation_label_seq[1])
print(validation_label_seq[2])
print(validation_label_seq.shape)

7624b7a3e806050f3d6ffc840052ab13.png

在训练深度神经网络之前,我们可以抽查原始文章和填充后的文章。运行以下代码,我们查看了第11篇文章,我们可以看到有些单词变成“”,因为它们没有进入前5,000名。

reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])


def decode_article(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])
print(decode_article(train_padded[10]))
print('---')
print(train_articles[10])

9f425cb2b1654f8f09dea3c42a43ae19.png

现在是训练LSTM的时候了。

  • 我们构建一个tf.keras.Sequential模型,并从 embeddings 开始。一个 embeddings 将每个单词存储为一个矢量。当调用时,它将单词索引序列转换为向量序列。训练后,具有相似含义的单词通常具有相似的向量。

  • Bidirectional wrapper 与LSTM层一起使用,这可以通过LSTM层向前和向后传播输入,然后将输出串联。这有助于LSTM学习长期依赖性。

  • 我们使用 relu 代替 tahn。

  • 我们添加一个具有6个单元和 SoftMax 函数。当我们有多个输出时,SoftMax 将输出层转换为概率分布。

model = tf.keras.Sequential([
    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top
    tf.keras.layers.Embedding(vocab_size, embedding_dim),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),
#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    # use ReLU in place of tanh function since they are very good alternatives of each other.
    tf.keras.layers.Dense(embedding_dim, activation='relu'),
    # Add a Dense layer with 6 units and softmax activation.
    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.
    tf.keras.layers.Dense(6, activation='softmax')
])
model.summary()

3484c98d9cd2888c363dc9d114e16b99.png

在我们的模型 summary 中,我们有 embeddings,Bidirectional LSTM,然后是两个 dense 层。双向的输出为128,因为它使我们在LSTM中输入的输出增加了一倍。我们也可以堆叠LSTM层,但我发现结果更糟。

print(set(labels))

4d451ef184aa108b2b380ba86094e847.png

我们总共有5个标签,但是由于我们没有使用 one-hot 编码标签,因此我们必须使用 sparse_categorical_crossentropy 作为损失函数,因此似乎也认为0也是一个可能的标签。因此,最后一个密集层需要标签0、1、2、3、4、5的输出,而不是整数,尽管从未使用过0。

如果您希望最后一个密集的层是5,则需要从训练集和验证集的标签中减去1。(此处保留)

我决定训练10个epoch。

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
num_epochs = 10
history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)

7a781c88c0b3766431d24211d99c4470.png

def plot_graphs(history, string):
  plt.plot(history.history[string])
  plt.plot(history.history['val_'+string])
  plt.xlabel("Epochs")
  plt.ylabel(string)
  plt.legend([string, 'val_'+string])
  plt.show()
  
plot_graphs(history, "accuracy")
plot_graphs(history, "loss")

4f0da67813a1971a5a66b42442090747.png

我们可能只需要3或4个epoch。(在训练结束时,我们可以看到有点过于拟合)

·  END  ·

HAPPY LIFE

6acd2461ceaa0cf16c8a5bfd2b4e3b71.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/512133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mac下删除python3.7,并将版本更新到3.9

如何卸载python3.7 有些小伙伴想直接从3.7升级到3.9 那恐怕是不行的&#xff0c;python3.7的库占的空间不少&#xff0c;所以首先我们应该来删除它. python安装后的路径分类 在删除之前需要先了解&#xff1a;python安装后有几类路径需要我们去查看删除 python存储库路径&am…

【分组码系列】线性分组码的网格图和维特比译码

线性分组码的网格图 由于码字的比特位是统计独立的,所以编码过程可以利用有限状态机来描述,它能精确地确定初始和最终状态。可以利用网格图进一步描述编码过程[36],采用维特比算法进行最大似然译码. 在GF(2)上定义线性分组码(n,k)。相应的(n-k)Xn维校验阵可以写成 令码字为系…

TensorFlow vs PyTorch:哪一个更适合您的深度学习项目?

在深度学习领域中&#xff0c;TensorFlow 和 PyTorch 都是非常流行的框架。这两个框架都提供了用于开发神经网络模型的工具和库&#xff0c;但它们在设计和实现上有很大的差异。在本文中&#xff0c;我们将比较 TensorFlow 和 PyTorch&#xff0c;并讨论哪个框架更适合您的深度…

队列、栈专题

队列、栈专题 LeetCode 20. 有效的括号解题思路代码实现 LeetCode 921. 使括号有效的最少添加解题思路代码实现 LeetCode 1541. 平衡括号字符串的最少插入次数解题思路代码实现 总结 不要纠结&#xff0c;干就完事了&#xff0c;熟练度很重要&#xff01;&#xff01;&#xff…

icevision环境安装

Installation - IceVision # 1. git clone 代码# pip 换源&#xff1a; ~/.pip/pip.conf 隐藏文件[global] index-url https://pypi.tuna.tsinghua.edu.cn/simple [install] trusted-hostmirrors.aliyun.compip install -e .[all,dev]ImportError: cannot import name Multi…

chatgpt-4它的未来是什么?该如何应用起来?

在当今快节奏的数字通信世界中&#xff0c;ChatGPT已成为一个强大的在线聊天平台&#xff0c;改变了人们互动和沟通的方式。凭借其先进的AI功能、用户友好的界面和创新技术&#xff0c;ChatGPT已成为个人和企业的热门选择。 然而&#xff0c;ChatGPT的未来有望更加激动人心和具…

VSCode的安装以及相关插件配置

VSCode是什么&#xff1f; VSCode严格来说&#xff0c;也是一款编辑器&#xff0c;强大之处在于集成了各种各样的插件。至此往后&#xff0c;将使用VSCode来取代vim。话不多说&#xff0c;步骤如下&#xff1a; 安装步骤 1、VSCode的下载 https://vscode.cdn.azure.cn/stabl…

NSSCTF (3)

[GDOUCTF 2023]hate eat snake 我们打开js源码 很明显这里当score大于60会出flag score = getScore 我们寻找到了getScore方法所在的地方 之后发现他存在于Snake

Python多线程之_thread与threading模块

Python多线程之_thread与threading模块 在Python程序中&#xff0c;多线程的应用程序会创建一个函数&#xff0c;来执行需要重复执行多次的程序代码&#xff0c;然后创建一个线程执行该函数。一个线程是一个应用程序单元&#xff0c;用于在后台并行执行多个耗时的动作。 在多…

DBWeaver 连接H2数据库 详细教程

1.DBWwaver下载网址 https://github.com/dbeaver/dbeaver/releases 2. 软件安装 点击安装文件&#xff0c;一直下一步即可 3. DBWeaver连接H2数据库 3.1打开软件在搜索框里面输入&#xff1a;h2 3.2 查询到h2数据库 3.3 点击选中的数据库&#xff0c;出现这样的页面&#xf…

铁路信号计轴设备简介

设备概述 计轴设备是铁路信号系统中的一个重要组成部分。它的主要功能是&#xff1a; 利用安装在钢轨上的传感器&#xff0c;来探测进入和出清轨道区段的车轮对数&#xff0c;进而判别轨道区段的占用和出清&#xff0c;其作用与轨道电路等效。 根据两站办理发车进路情况及区…

浪涌保护器的类型和应用

我们可能经常遇到电子设备损坏的情况。发生这种情况是由于多种情况造成的&#xff0c;例如大气变化&#xff08;闪电和雷声&#xff09;、电压击穿以及使用压缩机等重型设备。所有这些中断都可能会对电气设备造成破坏。进入这种情况的一种设备是浪涌保护器&#xff0c;也称为浪…

EndNote X9 导入知网文献 插入引用文献 方法

文章目录 1 EndNote X9 导入知网文献2 EndNote X9 插入参考文献常见问题总结3 EndNote X9 快速上手教程&#xff08;毕业论文参考文献管理器&#xff09; 1 EndNote X9 导入知网文献 下载知网参考文献引用&#xff1a; ①下载 引用&#xff1b; ②格式为 EndNote&#xff1b; 知…

流水线三维可视化运维,装配自动化提质增效

大家带来智慧生产线/设备流水线合集。 智慧仓储产线 智慧仓储产线通过对仓储现场的数字化建模&#xff0c;利用先进的物联网、大数据、人工智能等技术&#xff0c;对仓储现场设备、环境、人员进行全流程数字化管理。 为贯彻仓储行业应用的全面性&#xff0c;图扑 HT 应用 Web…

15个提高Javascript开发技巧

大厂面试题分享 面试题库 前后端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 web前端面试题库 VS java后端面试题库大全 劈柴不照纹&#xff0c;累死劈柴人。上学的时候就总有那些“小怪物们”总能解出来难题&…

固态硬盘无法识别,怎么办?4招教您解决!

案例&#xff1a;电脑识别不了固态硬盘怎么办&#xff1f; 【我的电脑识别不了固态硬盘&#xff0c;这给我带来了很大的困扰。我尝试了很多方法&#xff0c;还没有解决。求一个有效的解决方法&#xff01;】 固态硬盘在计算机领域中越来越普遍&#xff0c;其快速读取和写入速…

PyCharm十大提高生产力的插件

PyCharm是一个非常流行的Python开发IDE。除了支持Python语言&#xff0c;PyCharm还支持其他流行的语言&#xff0c;如C、C、JavaScript等。PyCharm被广泛使用&#xff0c;是因为它拥有许多方便而实用的插件&#xff0c;这些插件能够显著提高开发者的生产力。下面我们将介绍十大…

鲸鸿动能广告接入如何高效变现流量?

广告是App开发者最常用的流量变现方法之一&#xff0c;当App拥有一定数量用户时&#xff0c;开发者就需要考虑如何进行流量变现&#xff0c;帮助App实现商业可持续增长。 鲸鸿动能流量变现服务是广告服务依托华为终端强大的平台与数据能力为开发者提供的App流量变现服务&#…

用例评审的正确姿势,2个要点不容忽视

&#xff0c;点击蓝字&#x1f446; 关注Agilean&#xff0c;获取一手干货 导语 用例评审的作用已经不言而喻&#xff0c;但是在很多组织的实际落地过程中&#xff0c;却收效甚微。研发管理人员常常会发现即使做了用例评审&#xff0c;一些显而易见的问题还是会出现&#xff1a…

ECharts折线图堆叠和不堆叠的问题

今天配合后台联调数据的时候遇到一种情况 第三条数据为0时候并没有在y轴为0上&#xff0c;而是跟上一条线重合了 ECharts折线图是堆叠的&#xff0c;折线图堆叠的意思就是&#xff1a;第二条线的数值本身的数值第一条线的数值&#xff0c;第三条的数值第二条线图上的数值本身的…