MXNet中使用卷积神经网络textCNN对文本进行情感分类

news2025/1/15 17:40:34

在图像识别领域,卷积神经网络是非常常见和有用的,我们试图将它应用到文本的情感分类上,如何处理呢?其实思路也是一样的,图片是二维的,文本是一维的,同样的,我们使用一维的卷积核去处理一维的文本(当作一维的图片)即可。这样也可以达到图片抽取特征类似的效果,也可以捕捉到临近词之间的关联。

下面是这节将需要导入的包跟模块

import d2lzh as d2l
from mxnet import gluon,init,nd
from mxnet.contrib import text
from mxnet.gluon import data as gdata,loss as gloss,nn

一维卷积层

一维卷积层的原理跟前面学到的二维卷积层是一样的,一维卷积层使用一维的互相关运算,在一维互相关运算中,卷积窗口从输入数组的最左边开始,按照从左往右的顺序,依次在输入数组上滑动,当卷积窗口滑动到某一个位置时,窗口中的输入子数组就跟核数组按元素相乘并求和。

我们来直观的看图就明白了,输入是宽为7的一维数组,核数组宽为2,输出的宽度:7-2+1=6,高亮颜色的地方按照元素相乘再相加:0x1+1x2=2,如下图:

以前的图很多都是通过CorelDraw来画图,现在使用小画桌来在线画图还是挺方便快捷的,推荐大家使用。

一维互相关运算函数如下:

def corr1d(X,K):
    w=K.shape[0]
    Y=nd.zeros((X.shape[0]-w+1))
    for i in range(Y.shape[0]):
        Y[i]=(X[i:i+w]*K).sum()
    return Y

X,K=nd.array([0,1,2,3,4,5,6]),nd.array([1,2])
print(corr1d(X,K))
'''
[ 2.  5.  8. 11. 14. 17.]
<NDArray 6 @cpu(0)>
'''

跟图片中的结果是一样的,接下来看下多通道的输入和多个卷积核是怎么计算的,先看图:

然后我们也使用一个函数来验证下多通道的结果:

def corr1d_multi_in(X,K):
    return nd.add_n(*[corr1d(x,k) for x,k in zip(X,K)])
X=nd.array([[0,1,2,3,4,5,6],[1,2,3,4,5,6,7],[2,3,4,5,6,7,8]])
K=nd.array([[1,2],[3,4],[-1,-3]])
print(corr1d_multi_in(X,K))
'''
[ 2.  8. 14. 20. 26. 32.]
<NDArray 6 @cpu(0)>
'''

没有问题,其中*星号是将结果列表变为add_n函数的位置参数,然后进行相加运算。上图的三通道输入的一维卷积运算,是可以看作单通道输入的二维卷积互相关运算。如下图:

时序最大池化层

前面的文章介绍的卷积运算有接池化层,这里同样的,也有一维的池化层。textCNN中使用的时序最大池化(max-over-time pooling)层实际上对应的是一维全局最大池化层:假设输入包含多个通道,各通道由不同时间步上的数值组成,各通道的输出即该通道所有时间步中最大的数值。因此,时序最大池化层的输入在各个通道上的时间步数可以不同。

在textCNN模型中是怎么使用卷积层与时序最大池化层的,我们先画个图来直观感受下:

图片比较直观感受到这个模型的流程,接下来我们设计这个模型,在此之前整理数据集,还是使用前面介绍的电影评论数据集来做情感分析

batch_size=64
d2l.download_imdb()
train_data,test_data=d2l.read_imdb('train'),d2l.read_imdb('test')
vocab=d2l.get_vocab_imdb(train_data)
train_iter=gdata.DataLoader(gdata.ArrayDataset(*d2l.preprocess_imdb(train_data,vocab)),batch_size,shuffle=True)
test_iter=gdata.DataLoader(gdata.ArrayDataset(*d2l.preprocess_imdb(test_data,vocab)),batch_size)

创建textCNN模型

textCNN模型主要步骤如下:

1、定义多个一维卷积,分别对这些输入做卷积计算,宽度不同的卷积核可能会捕捉到不同个数的相邻词的相关性,从图中我们也可以看到卷积核的一个宽度是2,另一个是4
2、对输出的所有通道分别做时序最大池化,再将这些通道的池化输出值连结为向量
3、通过全连接层将连接后的向量变换为有关各类别的输出,这里可以加一个Dropout丢弃层来应对过拟合

实现模型的代码,这里使用两个嵌入层,一个的权重固定,另一个的权重参与训练

class TextCNN(nn.Block):
    def __init__(self,vocab,embed_size,kernel_sizes,num_channels,**kwargs):
        super(TextCNN,self).__init__(**kwargs)
        self.embedding=nn.Embedding(len(vocab),embed_size)
        # 不参与训练的嵌入层
        self.constant_embedding=nn.Embedding(len(vocab),embed_size)
        self.dropout=nn.Dropout(0.5)
        self.decoder=nn.Dense(2)
        # 时序最大池化层没有权重,所以可以共用一个实例
        self.pool=nn.GlobalMaxPool1D()
        self.convs=nn.Sequential()
        # 添加多个一维的卷积层
        for c,k in zip(num_channels,kernel_sizes):
            # NCW
            self.convs.add(nn.Conv1D(c,k,activation='relu'))
    
    def forward(self,inputs):
        # NWC(批量大小,词数,词向量维度[通道])的两个嵌入层的输出按照词向量维度dim=2连结
        embeddings=nd.concat(self.embedding(inputs),self.constant_embedding(inputs),dim=2)
        # 一维卷积的输入格式是NCW,所以进行形状变换
        embeddings=embeddings.transpose((0,2,1))
        # 对于每个一维卷积层,在时序最大池化后会得到一个形状为(批量大小,通道大小,1)的NDArray
        # 使用flatten函数去掉最后一维,然后在通道维上连结
        encoding=nd.concat(*[nd.flatten(self.pool(conv(embeddings))) for conv in self.convs],dim=1)
        # 应用丢弃法后使用全连接层得到输出
        outputs=self.decoder(self.dropout(encoding))
        return outputs

#创建textCNN实例,3个卷积层,其核宽分别是3,4,5,输出通道数均为100
embed_size,kernel_size,num_channels=100,[3,4,5],[100,100,100]
ctx=d2l.try_all_gpus()
net=TextCNN(vocab,embed_size,kernel_size,num_channels)
net.initialize(init.Xavier(),ctx=ctx)

训练模型

模型的创建,这里使用100维的GloVe词向量,对于GloVe的了解可以参阅:自然语言处理(NLP)之求近义词和类比词<MXNet中GloVe和FastText的模型使用>

glove_embedding=text.embedding.create('glove',pretrained_file_name='glove.6B.100d.txt',vocabulary=vocab)
# 这个嵌入层的权重参数训练
net.embedding.weight.set_data(glove_embedding.idx_to_vec)
# 固定权重
net.constant_embedding.weight.set_data(glove_embedding.idx_to_vec)
net.constant_embedding.collect_params().setattr('grad_req','null')

预训练词向量搞定之后就开始训练模型

lr,num_epochs=0.001,5
trainer=gluon.Trainer(net.collect_params(),'adam',{'learning_rate':lr})
loss=gloss.SoftmaxCrossEntropyLoss()
d2l.train(train_iter,test_iter,net,loss,trainer,ctx,num_epochs)
# 预测
print(d2l.predict_sentiment(net,vocab,['this','movie','is','very','nice']))
print(d2l.predict_sentiment(net,vocab,['this','movie','is','so','bad']))
print(d2l.predict_sentiment(net,vocab,['this','movie','is','not','bad']))
print(d2l.predict_sentiment(net,vocab,['this','movie','is','too','bad']))

'''
epoch 1, loss 0.6138, train acc 0.714, test acc 0.832, time 44.2 sec
epoch 2, loss 0.3582, train acc 0.844, test acc 0.852, time 43.5 sec
epoch 3, loss 0.2646, train acc 0.892, test acc 0.864, time 43.5 sec
epoch 4, loss 0.1711, train acc 0.937, test acc 0.868, time 43.3 sec
epoch 5, loss 0.1081, train acc 0.962, test acc 0.858, time 43.4 sec
positive
negative
negative
negative
'''

可以看到训练的准确度还是很不错的,测试的准确度也可以,有待提高,第三条影评识别错误,其余都预测对了。

对于准确度的提高,有两个方向可以去做,还记得吗,就是 MXNet中使用双向循环神经网络BiRNN对文本进行情感分类<改进版>

这篇文章中的两种方法,使用SpaCy分词工具和扩大词向量的维度,有兴趣的伙伴们可以去试试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/372307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LLVM程序分析与编译转换框架论文分享

LLVM 2004年论文原文 概述 本文描述了 LLVM&#xff08;低级虚拟机&#xff09;&#xff0c;一种编译器框架&#xff0c;旨在通过在编译时、链接时、运行时&#xff0c;以及运行之间的空闲时间。 LLVM 以静态单一赋值 (SSA) 形式定义了一种通用的低级代码表示&#xff0c;具有…

多区域的OSPF实战配置

多区域的OSPF实战配置 需求 如图配置设备的接口IP地址如图规划OSPF网络的区域要求每个设备的 router-id 都是 x.x.x.x&#xff08;x是每个路由器的名字&#xff09;确保不同的PC之间可以互通 拓扑图 配置命令 PC1&#xff1a; 192.168.1.1 255.255.255.0 192.168.1.254PC2:…

【手把手一起学习】(六) Altium Designer 20 STM32核心板Demo----PCB设计

1 PCB设计 PCB设计是制作STM32核心板的关键步骤&#xff0c;其关系到最终生产厂家制作的电路板能否正常使用&#xff0c;PCB设计包括布局&#xff0c;裁板&#xff0c;布线&#xff0c;覆铜&#xff0c;DRC检查等&#xff0c;其中要求、细节、技巧比较多&#xff0c;以后会更详…

redis(7)哈希Hash

哈希Hash Redis hash 是一个键值对集合。 Redis hash 是一个 string 类型的 field 和 value 的映射表&#xff0c;hash 特别适合用于存储对象。 类似 Java 里面的 Map<String,Object>。 用户 ID 为查找的 key&#xff0c;存储的 value 用户对象包含姓名&#xff0c;年…

西北工业大学大学物理(II)选填解析2019-2020期末

2 又是考查“一个电子和一个光子具有相同的波长&#xff0c;则二者动量相等。”4 斯特恩盖拉赫实验&#xff0c;原子的自旋磁矩取向量子化。7 通常我们感受不到电子的波动性。因为其波长短&#xff0c;其实也就是粒子运动速率高。10 考查无限长直导线周围B分布。常见的模型要记…

【035】基于Vue的电商推荐管理系统(含源码数据库、超详细论文)

摘 要&#xff1a;基于Vue&#xff0b;Nodejs&#xff0b;mysql的电商推荐管理系统&#xff0c;这个项目论文超详细&#xff0c;er图、接口文档、功能展示、技术栈等说明特别全&#xff01;&#xff01;&#xff01; &#xff08;文末附源码数据库、课设论文获取方式&#xff0…

xgboost学习-原理

文章目录一、xgboost库与XGB的sklearn APIXGBoost的三大板块二、梯度提升树提升集成算法&#xff1a;重要参数n_estimators三、有放回随机抽样&#xff1a;重要参数subsample四、迭代决策树&#xff1a;重要参数eta总结一、xgboost库与XGB的sklearn API 现在&#xff0c;我们有…

【ROS学习笔记4】话题通信

【ROS学习笔记4】话题通信 文章目录【ROS学习笔记4】话题通信零、话题通信概述一、话题通信的理论模型二、话题通信基本操作的Cpp实现三、话题通信基本操作的Python实现四、话题通信自定义msg五、话题通信自定义msg调用的Cpp实现六、话题通信自定义msg的Python实现七、Referenc…

《MySQL学习》 Order by 工作原理

《MySQL学习》 Order by 工作原理 一.排序缓存 sort_buffer 当我们使用explain 分析一条带有排序操作的SQL语句时&#xff0c;会看到Extra中有使用 Using filesort explain select * from t order by k descMySQL 会为每个线程分配固定大小的 sort buffer 用作排序。 sort b…

SAP 怎么冲销已冲销的凭证?

假如有这么一种场景&#xff1a;你做了一张凭证A&#xff0c;你第一次发现账务做错了。你用fb08来冲销A&#xff0c;生成了冲销凭证B&#xff0c;然后第二次检查的时候你发现&#xff0c;凭证A其实没错&#xff0c;这时候能不能把冲销凭证B撤销掉&#xff1f; 然后凭证A就恢复了…

【C语言每日一题】猜名次

【C语言每日一题】—— 猜名次&#x1f60e;&#x1f60e;&#x1f60e; &#x1f4a1;前言&#x1f31e;&#xff1a; &#x1f49b;猜名次题目&#x1f49b; &#x1f4aa; 解题思路的分享&#x1f4aa; &#x1f60a;题目源码的分享&#x1f60a; &#x1f449; 本菜鸡…

【MySQL】增删改操作(基础篇)

目录 1、新增操作(Create) 1.1 单行数据 全列插入 1.2 多行数据 全列插入 1.3 单行数据 指定列插入 2、修改操作(Update) 3、删除操作(Delete) 1、新增操作(Create) 如何给一张表新增数据呢&#xff1f; 新增(Create)&#xff0c;在我们数据库中&#xff0c;用 ins…

三行代码让你的git记录保持整洁

前言笔者最近在主导一个项目的架构迁移工作&#xff0c;由于迁移项目的历史包袱较重&#xff0c;人员合作较多&#xff0c;在迁移过程中免不了进行多分支、多次commit的情况&#xff0c;时间一长&#xff0c;git的提交记录便混乱不堪&#xff0c;随便截一个图形化的git提交历史…

location

目录 匹配的目标 格式 匹配符号&#xff1a; 优先级 要表达不匹配条件&#xff0c;则用 if 实现 例子&#xff1a;根目录的匹配最弱 例子&#xff1a;区分大小写 和 不区分大小写 例子&#xff1a;以根开头 和 不区分大小写 例子&#xff1a;等号 匹配的目标 ng…

Vue2.0开发之——使用ref引用组件实例(41)

一 概述 在本组件内部修改count的值在父组件内修改子组件的count值 二 在本组件内部修改count的值 2.1 Left.vue 布局代码 <template><div class"left-container"><h3 >Left 组件---{{count}}</h3><button click"count 1"&…

团队:在人身上,你到底愿意花多大精力?

你好&#xff0c;我是叶芊。 今天我们讨论怎么带团队这个话题&#xff0c;哎先别急着走&#xff0c;你可能跟很多人一样&#xff0c;觉得带团队离我还太远&#xff0c;或者觉得我才不要做管理&#xff0c;我要一路技术走到底&#xff0c;但是你知道吗&#xff1f;带团队做事&am…

华为OD机试用Python实现 -【MVP 争夺战】(2023-Q1 新题)

华为OD机试题 华为OD机试300题大纲MVP 争夺战题目描述输入描述输出描述示例一输入输出说明Python 代码实现代码实现思路华为OD机试300题大纲 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址:blog.csdn.net/hihell…

​AAAI 2023 | 利用脉冲神经网络扩展动态图表示学习

©PaperWeekly 原创 作者 | 李金膛单位 | 中山大学博士生研究方向 | 可信图学习2020 年国家双碳战略的确立与实施&#xff0c;绿色低碳已经成为全社会的重要议题&#xff0c;也是科技从业者的重要使命和责任。有文献指出&#xff0c;从 2012 年到 2018 年&#xff0c;用于…

CountDownLatch与CyclicBarrier原理剖析

1.CountDownLatch 1.1 什么是CountDownLatch CountDownLatch是一个同步工具类&#xff0c;用来协调多个线程之间的同步&#xff0c;或者说起到线程之间的通信&#xff08;而不是用作互斥的作用&#xff09;。 CountDownLatch能够使一个线程在等待另外一些线程完成各自工作之…

分布式算法 - Snowflake算法

Snowflake&#xff0c;雪花算法是由Twitter开源的分布式ID生成算法&#xff0c;以划分命名空间的方式将 64-bit位分割成多个部分&#xff0c;每个部分代表不同的含义。这种就是将64位划分为不同的段&#xff0c;每段代表不同的涵义&#xff0c;基本就是时间戳、机器ID和序列数。…