机器学习深度学习——NLP实战(情感分析模型——textCNN实现)

news2024/11/17 21:44:07

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——NLP实战(情感分析模型——RNN实现)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

NLP实战(情感分析模型——textCNN实现)

  • 引入
  • 一维卷积
  • 最大时间池化层
  • textCNN模型
    • 定义模型
    • 加载预训练词向量
    • 训练和评估模型
  • 小结

引入

之前已经讨论过使用二维卷积神经网络来处理二维图像数据的机制,并将其应用于局部特征,如相邻像素。虽然卷积神经网络最初是为计算机视觉而设计的,但它也被广泛应用于NLP。简单来说,只要将任何文本序列想象成一维图像即可。通过这种方式,一维卷积神经网络可以处理文本中的局部特征,例如n元语法。
本节将使用textCNN模型来演示如何设计一个表示单个文本的卷积神经网络架构。与上一节的情感分析相比,唯一的区别只有架构的选择不同。
在这里插入图片描述

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

一维卷积

在介绍该模型前,让我们看下一维卷积的工作原理,其实就可以看作是二维卷积的特例:
在这里插入图片描述
我们在下面的corr1d函数实现了一维互相关,给定输入张量X和核张量K,它返回输出张量Y。

def corr1d(X, K):
    w = K.shape[0]
    Y = torch.zeros((X.shape[0] - w + 1))
    for i in range(Y.shape[0]):
        Y[i] = (X[i: i + w] * K).sum()
    return Y

构造输入张量X和核张量K来验证上述一维互相关实现的输出。

X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])
print(corr1d(X, K))

输出结果:

tensor([ 2., 5., 8., 11., 14., 17.])

对于任何具有多个通道的一维输入,卷积核需要具有相同数量的输入通道。然后,对于每个通道,对输入的一维张量和卷积核的一维张量执行互相关运算,将所有通道上的结果相加以产生一维输出张量。如图:
在这里插入图片描述
我们可以实现多个输入通道的一维互相关运算:

def corr1d_multi_in(X, K):
    # 首先,遍历'X'和'K'的第0维(通道维)。然后,把它们加在一起
    return sum(corr1d(x, k) for x, k in zip(X, K))

可以验证结果:

X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
              [1, 2, 3, 4, 5, 6, 7],
              [2, 3, 4, 5, 6, 7, 8]])
K = torch.tensor([[1, 2], [3, 4], [-1, -3]])
print(corr1d_multi_in(X, K))

输出结果:

tensor([ 2., 8., 14., 20., 26., 32.])

多输入通道的一维互相关等同于单输入通道的二维互相关,如上图中的例子可以等价为下图:
在这里插入图片描述

最大时间池化层

类似地,我们可以使用池化层从序列表示中提取最大值,作为跨时间步的最重要特征。textCNN中使用的最大时间池化层的工作原理类似于一维全局池化。对于每个通道在不同时间步存储值的多通道输入,每个通道的输出是该通道的最大值。注意,最大时间池化允许在不同通道上使用不同数量的时间步。

textCNN模型

使用一维卷积核最大时间池化,textCNN模型将单个预训练的词元表示作为输入,然后获得并转换用于下游任务的序列表示。
对于具有由d维向量表示的n个词元的单个文本序列,输入张量的宽度、高度和通道数分别为n、1和d。textCNN模型将输入转换为输出,如下所示:
1、定义多个一维卷积核,并分别对输入执行卷积运算。具有不同宽度的卷积核可以捕获不同数目的相邻词元之间的局部特征。
2、在所有输出通道上执行最大时间池化层,然后将所有标量池化输出连结为向量。
3、使用全连接层将连结后的向量转换为输出类别。dropout可以用来减少过拟合。
如下所示:
在这里插入图片描述
上面所示的例子,我们有一个宽度为11的6通道输入。定义两个宽度为2和4的一维卷积核,分别具有4个和5个输出通道。它们产生4个宽度为11-2+1=10的输出通道和5个宽度为11-4+1=8的输出通道。尽管这9个通道的宽度不同,但最大时间池化层给出了一个连结的9维向量,该向量最终被转换为用于二元情感预测的2维输出向量。

定义模型

现在实现textCNN模型,与之前的双向循环神经网络模型相比,除了用卷积层代替循环神经网络层外,我们还使用了两个嵌入层:一个是可训练权重,另一个是固定权重。

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels,
                 **kwargs):
        super(TextCNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 这个嵌入层不需要训练
        self.constant_embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Linear(sum(num_channels), 2)
        # 最大时间汇聚层没有参数,因此可以共享此实例
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        # 创建多个一维卷积层
        self.convs = nn.ModuleList()
        for c, k in zip(num_channels, kernel_sizes):
            self.convs.append(nn.Conv1d(2 * embed_size, c, k))

    def forward(self, inputs):
        # 沿着向量维度将两个嵌入层连结起来,
        # 每个嵌入层的输出形状都是(批量大小,词元数量,词元向量维度)连结起来
        embeddings = torch.cat((
            self.embedding(inputs), self.constant_embedding(inputs)), dim=2)
        # 根据一维卷积层的输入格式,重新排列张量,以便通道作为第2维
        embeddings = embeddings.permute(0, 2, 1)
        # 每个一维卷积层在最大时间汇聚层合并后,获得的张量形状是(批量大小,通道数,1)
        # 删除最后一个维度并沿通道维度连结
        encoding = torch.cat([
            torch.squeeze(self.relu(self.pool(conv(embeddings))), dim=-1)
            for conv in self.convs], dim=1)
        outputs = self.decoder(self.dropout(encoding))
        return outputs

让我们创建一个textCNN实例。它有3个卷积层,卷积核宽度分别为3、4和5,均有100个输出通道。

embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
devices = d2l.try_all_gpus()
net = TextCNN(len(vocab), embed_size, kernel_sizes, nums_channels)

def init_weights(m):
    if type(m) in (nn.Linear, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

加载预训练词向量

我们加载预训练的100维GloVe嵌入作为初始化的词元表示。这些词元表示(嵌入权重)在embedding中将被训练,在constant_embedding中将被固定。

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.requires_grad = False

训练和评估模型

现在我们可以训练textCNN模型进行情感分析。

lr, num_epochs = 0.001, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

运行结果:

loss 0.064, train acc 0.979, test acc 0.873
138.7 examples/sec on [device(type=‘cpu’)]

运行图片:
在这里插入图片描述

下面,我们使用训练好的模型来预测两个简单句子的情感。

d2l.predict_sentiment(net, vocab, 'this movie is so great')

预测结果:

‘positive’

d2l.predict_sentiment(net, vocab, 'this movie is so bad')

预测结果:

‘negative’

小结

1、一维卷积神经网络可以处理文本中的局部特征,例如n元语法。
2、多输入通道的一维互相关等价于单输入通道的二维互相关。
3、最大时间池化层允许在不同通道上使用不同数量的时间步长。
4、textCNN模型使用一维卷积层和最大时间池化层将单个词元表示转换为下游应用输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/906115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NDK 的配置记录~

NDK 的配置 NDK配置 NDK设置在 AS 路径中设置在 local.properties设置在 build.gradle ndk 和 gradle 对应关系gradle的插件和版本对应关系gradle 插件和NDK对应关系 NDK NDK(Native Development Kit)是一组工具和库,用于在 Android 平台上开…

加强预算管理一体化,走进全面预算管理的数智化时代

自2021年,国务院印发《国务院关于进一步深化预算管理制度改革的意见》(国发〔2021〕5号)以来,中央政府对企业实施全面预算管理越来越重视,预算绩效管理逐渐进入大家的视线。各个企业逐步落实应用,推进预算管…

C++的常用基础知识100个

1、定义一个常量 2、数据类型-整型 3、数据类型-字符型 4、数据的输入 5、运算符 6、三目运算符 7、循环案例-99乘法表 8、数组 9、冒泡排序 10、函数的定义 11、函数的分文件编写 12、指针 12、结构体 13、通讯录项目 创建一个空项目,并命名为通讯录管理系统。 14…

“我来拿”APP设计报告

1.设计摘要 想必大家对学校的悬赏互助群并不陌生,学生们在群里提出要求并标明价格,就可以找人帮忙。我们的跑腿平台就是以此为灵感,让学生之间通过一个专门的020平台实现有报酬的互助跑腿,但是相比QQ、微信群,我们让定…

js数组常用的方法(总结)

目录 1.数组头和尾操作——push、pop、unshift/shift 2、数组转为字符串 —— join() 3、数组截取 —— slice() 4、数组更新 —— splice() 5、反转数组 —— reverse() 6、连接数组 —— concat() 7、ES6连接数组 —— ... ES5数组新增方法 8、索引方法 —— indexO…

PgSQL中的DATE_PART使用

用法: DATE_PART(field, source) 这个DATE_PART()函数返回类型为double precision的值 century decade year month day hour minute second microseconds milliseconds dow doy epoch isodow isoyear timezone timezone_hour timezone_minute

水经微图网页版发布

水经微图网页版,可轻松将关注的地点制作成你的个人地图。 你可以在任意位置添加标注点或绘制地图,查找地点并将其保存到你的地图中,或导入地图数据迅速制作地图并保存,你还可以运用图标和颜色展示个性风采,从而可让每…

线程阻塞队列

阻塞队列 一、BlockingQueue 接口 BlockingQueue 是阻塞队列接口实现机制是使用两条线程,允许两个线程同时操作队列一个线程用于写入 Put ,一个线程用于读取 Take当队列中没有数据的情况下,读取线程会自动阻塞,直到有数据放入队列当队列中数…

opencv进阶12-EigenFaces 人脸识别

EigenFaces 通常也被称为 特征脸,它使用主成分分析(Principal Component Analysis,PCA) 方法将高维的人脸数据处理为低维数据后(降维),再进行数据分析和处理,获取识别结果。 基本原理…

蓝牙资讯|消息称富士康投资4亿美元在印度生产苹果 AirPods 耳机

根据印度最大通讯社 PTI 报道,苹果和富士康已经签署一项新的协议,富士康将投资 4 亿美元在印度第四大城市海得拉巴扩建工厂,负责为苹果生产 AirPods TWS 耳机。 报道称苹果已经决定在印度本土生产 AirPods 耳机,富士康计划投资 …

测试框架pytest教程(2)-用例依赖库-pytest-dependency

对于 pytest 的用例依赖管理,可以使用 pytest-dependency 插件。该插件提供了更多的依赖管理功能,使你能够更灵活地定义和控制测试用例之间的依赖关系。 Using pytest-dependency — pytest-dependency 0.5.1 documentation 安装 pytest-dependency 插…

ipad手写笔有必要买吗?开学便宜又好用电容笔推荐

苹果电容笔之所以能够被iPad用户广泛使用,很大程度上是因为其的优秀性能,具有着独特的重力压感功能。但苹果原装的电容笔,价格相对比较高,所以很多人,都选择了普通的平替电容笔。如今许多人都爱用iPad来画图或写笔记&a…

Go 数组

一、复合类型: 二、数组 如果要存储班级里所有学生的数学成绩,应该怎样存储呢?可能有同学说,通过定义变量来存储。但是,问题是班级有80个学生,那么要定义80个变量吗? 像以上情况,最…

攻防世界-command_execution

原题 解题思路 题目告诉了,这可以执行ping命令且没WAF,那就可以在ping命令后连接其他命令。 服务器一般使用Linux,在Linux中可使用“&”连接命令。 ping 127.0.0.1&find / -name "flag*" ping 127.0.0.1&cat /home/f…

Linux中shell脚本常用命令、条件语句与if、case语句

目录 一.shell脚本常用命令 1.1.echo命令 1.2.date命令 1.3.cal命令 1.4.tr命令 1.5.cut命令 1.6.sort命令 1.7.uniq命令 1.8.cat多行重定向 二.条件语句 2.1.条件测试(三种测试方法) 2.2.正整数值比较 2.3.字符串比较 2.4.逻辑测试 三.i…

深入了解 Java 中 Files 类的常用方法及抽象类的 final 修饰

文章目录 Files 类常用方法抽象类的 final 修饰 🎉欢迎来到Java学习路线专栏~深入了解 Java 中 Files 类的常用方法及抽象类的 final 修饰 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博客🎈该系列文章专栏&#xff1a…

【C语言学习】二分法查找有序数组中的数

二分查找的基本原理 二分查找的基本逻辑就是每次找区间的中间数,然后与要查找的数进行比较,不断的缩小区间,最后区间中只剩一个数,即为要查找的数。如果不是,则没有该数。 二分查找只适用于有序数组 以数组中的数从左…

计算机视觉领域文献引用

Bag of freebies 炼丹白嫖加油包 Bag of freebies、致力于解决数据集中语义分布可能存在偏差的问题。在处理语义分布偏差问题时,一个非常重要的问题是不同类别之间存在数据不平衡的问题。 一、数据增强篇 Data Augmentation (1)图片像素调整…

安全模式进不去,解决方法在这!

“我想让电脑进入安全模式,但无论我怎么操作都无法进入。这是怎么回事呢?我怎么才能让电脑进入安全模式呢?请求帮助!” 安全模式是Windows操作系统的一种启动选项,用于解决系统问题和故障。然而,有时候用户…

PON测试,“信”助力 | 信而泰测试解决方案浅析

PON介绍 一、什么是PON网络 PON是“Passive Optical Network”的缩写,是一种基于光纤的网络技术。PON网络通过单向的光信号传输来实现数据、语音和视频等信息的传输。PON网络可以支持多个传输速率和距离要求,因此广泛应用于FTTH、FTTB(Fibe…