Pytorch实战笔记(2)——CNN实现情感分析

news2024/12/26 11:03:10

本文展示的是使用 Pytorch 构建一个 TextCNN 来实现情感分析。本文的架构是第一章详细介绍 TextCNN(不带公式版),第二章是核心代码部分。

目录

  • 1. TextCNN
  • 2. TextCNN 实现情感分析
  • 参考

1. TextCNN

相较于 LSTM 而言,我个人其实是没看过 CNN 的任何公式的,主要是我觉得也没必要,因为从使用的角度上讲,会用就行;从 CNN 的角度上讲,你只需要知道 CNN 提取的是一种聚合关系就行(与 GNN 不同的是,CNN 提取的是欧式数据的聚合关系,GNN 提取的是非欧数据的聚合关系)。

TextCNN [1] 的模型图如下图所示。其中一共包含了有 3 个模块:卷积层最大池化层,和输出层
TextCNN
在原论文中,作者采用了多个通道来提取不同词嵌入的特征,然而如果只有一种词嵌入输入的话,可以参考下面这篇论文[2]中的模型图:
TextCNN
我们现在就以上图作为例子,详细介绍下 TextCNN(仅有一种词嵌入输入时)的具体流程(放心,没有任何公式):

  1. 首先是输入,TextCNN 的输入是词嵌入,设序列长度为 s s s(在图中为 s = 7 s=7 s=7),设嵌入维度为 d d d(在图中 d = 5 d=5 d=5)。
  2. 接着 TextCNN 会经过一次二维卷积。首先是卷积核,从图中可以看到,一共有三个卷积核,大小分别为 ( 4 × 5 ) (4\times 5) (4×5) ( 3 × 5 ) (3\times 5) (3×5) ( 2 × 5 ) (2\times 5) (2×5)。先从卷积核的第二维开始说,我们发现卷积核的第二维都是 5,这个尺寸大小与嵌入维度 d d d 相同,就是说对于 TextCNN 而言,一次卷积要囊括所有词嵌入,这个也很好理解,因为只有 d d d 才能够代表整个词语。而对于第一维,分别为 432,这个就指的是,一次卷积考虑几个词语的依赖关系。假设卷积核大小为 ( 4 × 5 ) (4\times5) (4×5),那么就说明该卷积核一次聚合 4 个词语的依赖关系。(:图中从左到右第二个区域是卷积核,第三个区域才是卷积后的输出)
  3. 然后从图上可以发现,该图中每个卷积核共有 2 个滤波器(filter),所以一共有 2 ( 4 × 5 ) (4\times5) (4×5) 的卷积核、2 ( 3 × 5 ) (3\times5) (3×5) 的卷积核、2 ( 2 × 5 ) (2\times5) (2×5) 的卷积核。这里就是说,我们用不同数量的滤波器来捕获这个区间内不同的特征,以 ( 4 × 5 ) (4\times5) (4×5) 的卷积核举例,我们用滤波器 A 来捕获这4个词语之间的 A 特征,用滤波器 B 来捕获这4个词语之间的 B 特征。
  4. 卷积之后是一次一维最大池化操作。该过程提取卷积后的向量中的最大值,作为该滤波器的特征。
  5. 最后将所有滤波器通过最大池化后得到的特征拼接在一起,放到分类器中进行输出。

通过上述的解释,我们可以发现,TextCNN 是通过卷积核的尺寸,来控制模型捕获多少个词语之间的上下文关系。

2. TextCNN 实现情感分析

  • 全部代码在 github 上,网址为:https://github.com/Balding-Lee/Pytorch4NLP
  • 我采用的是 IMDb 数据集,由于数据集没有验证集,而且读取起来很麻烦,所以我将数据给读取出来,放到了一个文件中,并且将训练集中的10%划分为了验证集,数据集链接如下: https://pan.baidu.com/s/128EYenTiEirEn0StR9slqw,提取码:xtu3 。
  • 采用的词嵌入是谷歌的词嵌入,词嵌入的链接如下:链接:https://pan.baidu.com/s/1SPf8hmJCHF-kdV6vWLEbrQ,提取码:r5vx
    在本博客中仅介绍模型部分,详细代码见 github。

具体的模型代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Config:
    def __init__(self):
        # 训练配置
        self.seed = 22
        self.batch_size = 64
        self.lr = 1e-3
        self.weight_decay = 1e-4
        self.num_epochs = 100
        self.early_stop = 512
        self.max_seq_length = 128
        self.save_path = '../model_parameters/CNN_SA.bin'

        # 模型配置
        self.filter_sizes = (3, 4, 5)
        self.num_filters = 100
        self.dense_hidden_size = 128
        self.dropout = 0.5
        self.embed_size = 300
        self.num_outputs = 2


class Model(nn.Module):
    def __init__(self, embed, config):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embed, freeze=False)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed_size)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.relu = nn.ReLU()
        self.ffn = nn.Linear(config.num_filters * len(config.filter_sizes), config.dense_hidden_size)
        self.classifier = nn.Linear(config.dense_hidden_size, config.num_outputs)

    def max_pooling(self, x):
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, inputs):
        # shape: (batch_size, max_seq_length, embed_size)
        embed = self.embedding(inputs)

        # CNN 接受四维数据输入,
        # 第一维: batch,
        # 第二维: 通道数 (Channel), 在图像中指的是 RGB 这样的通道, 在自然语言里面指的是多少种词嵌入, 本项目中仅采用一种词嵌入, 所以就是 1 通道
        # 第三维: 高度 (Height), 在图像中指的是图片的高, 在自然语言里面就是序列长度
        # 第四维: 宽度 (Weight), 在图像中指的是图片的宽, 在自然语言里面就是嵌入维度
        # shape: (batch_size, 1, max_seq_length, embed_size)
        embed = embed.unsqueeze(1)

        cnn_outputs = []
        for conv in self.convs:
            # shape: (batch_size, filter_size, max_seq_length - kernel_size + 1, 1)
            conv_output = conv(embed)
            # shape: (batch_size, filter_size, max_seq_length - kernel_size + 1, 1)
            relu_output = self.relu(conv_output)
            # shape: (batch_size, filter_size, max_seq_length - kernel_size + 1, 1)
            relu_output = relu_output.squeeze(3)
            # shape: (batch_size, filter_size)
            pooling_output = self.max_pooling(relu_output)
            cnn_outputs.append(pooling_output)

        # shape: (batch, num_filters * len(num_filters))
        cnn_outputs = torch.cat(cnn_outputs, 1)
        cnn_outputs = self.dropout(cnn_outputs)
        # shape: (batch, dense_hidden_size)
        ffn_output = self.relu(self.ffn(cnn_outputs))
        # shape: (batch, num_outputs)
        logits = self.classifier(ffn_output)

        return logits

在该代码中,我才用的卷积核尺寸是 ( 3 × e m b e d _ s i z e ) (3\times {\rm embed\_size}) (3×embed_size) ( 4 × e m b e d _ s i z e ) (4\times {\rm embed\_size}) (4×embed_size) ( 5 × e m b e d _ s i z e ) (5\times {\rm embed\_size}) (5×embed_size),每个卷积核共有100个滤波器。同时,分类器一共有两层,一层的尺寸大小为 ( n u m _ f i l t e r s × l e n ( n u m _ f i l t e r s ) , d e n s e _ h i d d e n _ s i z e ) ({\rm num\_filters \times len(num\_filters)}, {\rm dense\_hidden\_size}) (num_filters×len(num_filters),dense_hidden_size),一层的尺寸大小为 ( d e n s e _ h i d d e n _ s i z e , n u m _ o u t p u t s ) ({\rm dense\_hidden\_size}, {\rm num\_outputs}) (dense_hidden_size,num_outputs)

实验结果如下:

test loss 0.367522 | test accuracy 0.833840 | test precision 0.822838 | test recall 0.850880 | test F1 0.836624

参考

[1] Yoon Kim. Convolutional Neural Networks for Sentence Classification [EB/OL]. https://arxiv.org/pdf/1408.5882.pdf, 2014.
[2] Ye Zhang, Byron C. Wallace. A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural Networks for Sentence Classification [EB/OL]. https://arxiv.org/pdf/1510.03820.pdf, 2015.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/173387.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手把手教你学51单片机-C语言基础

二进制、十进制和十六进制 对于二进制来说,8 位二进制我们称之为一个字节。 我们在进行 C 语言编程的时候,我们只写十进制和十六进制,那么不带 0x 的就 是十进制,带了 0x 符号的就是十六进制。 C 语言变量类型和范围 C 语言的数据基本类型分为字符型、整型、长整型以及…

【算法题解】12. 删除链表中的节点

文章目录题目题解Java 代码实现Go 代码实现复杂度分析这是一道简单题,题目来自 leetcode 题目 给定一个单链表的 head,我们想删除它其中的一个节点 node。 只给你一个需要删除的节点 node 。你将 无法访问 第一个节点 head。 链表的所有值都是 唯一的&…

1.吴恩达机器学习课程笔记:多元梯度下降法

1.吴恩达机器学习课程笔记:多元梯度下降法 笔记来源:吴恩达机器学习课程笔记:多元梯度下降法 仅作为个人学习笔记,若各位大佬发现错误请指正 1.1 多元特征(变量) 每一列代表一个特征,例如&…

FMC子卡设计资料原理图:FMC550-基于ADRV9002双窄带宽带射频收发器FMC子卡

FMC550-基于ADRV9002双窄带宽带射频收发器FMC子卡一、产品概述 ADRV9002 是一款高性能、高线性度、高动态范围收发器,旨在针对性能与功耗系统进行优化。该设备是可配置的,非常适合要求苛刻、低功耗、便携式和电池供电的设备。ADRV9002 的工作频率为 …

计算机基础——无处不网络(2)

作者简介:一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭:低头赶路,敬事如仪 个人主页:网络豆的主页​​​​​​ 目录 前言 一.计算机网络的接入方式 1

青龙脚本-稳定阅读

稳定阅读积分 查看请求头的cookie 变量 export zzbhd 多号或换行 */ 脚本附后 /* 查看请求头的cookie 变量 export zzbhd 多号或换行*/const $ new Env(至尊宝阅读); const axios require(axios); let request require("request"); request request.defaults(…

纯手工模拟Vue中的数据劫持和代理

为什么要实现数据劫持和代理 举一个场景:比如在小程序开发中,我们需要逻辑层修改的数据能同步响应更新到视图层的页面上,那么底层框架在实现这种效果的时候,机制是什么样的呢? 其实这里的底层原理类似于Vue中的数据劫…

基于蜣螂算法优化的BP神经网络(预测应用) - 附代码

基于蜣螂算法优化的BP神经网络(预测应用) - 附代码 文章目录基于蜣螂算法优化的BP神经网络(预测应用) - 附代码1.数据介绍3.蜣螂优化BP神经网络3.1 BP神经网络参数设置3.2 蜣螂算法应用4.测试结果:5.Matlab代码摘要&am…

ClickHouse快速复习

ClickHouse​一.特性​1.列式数据库管理系统​2.数据压缩​3.数据的磁盘存储​4.支持SQL​5.索引​6.适合在线查询​7.支持数据复制和数据完整性​8.实时的数据更新​9.处理大量短查询的吞吐量​10.处理大量短查询的吞吐量​11.限制​二.数据类型​1.数字类型​2.浮点数(float)…

前端都在聊什么 - 第 1 期

Hello 小伙伴们早上、中午、下午、晚上、深夜好,我是爱折腾的 jsliang~「前端都在聊什么」是 jsliang 日常写文章/做视频/搞直播过程中,小伙伴们的提问以及我的解疑整理。本期对应 2023 年的 01.01-01.15 这个时间段。本期针对「工作」「学习」「规划」「…

迭代器、可迭代对象、生成器的区别和联系

目录1 迭代器2 可迭代对象3 生成器1 迭代器 迭代器是一种可以更新迭代的工具,迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。但是他不能像列表一样使用下标来获取数据,也就是说迭代器是不能返回的。迭代器只能往前不会…

Universal Links方式:私有化部署服务器来托管apple-app-site-association文件创建通用链接

Universal Links方式:私有化部署服务器来托管apple-app-site-association第一步:开启Associated Domains服务第二步:配置Associated Domains(域名)第三步:服务器配置apple-app-site-association文件第四步:…

java的数据类型:引用数据类型(String、数组、枚举)

2.3.3 引用数据类型 引用数据类型大致包括:类、 接口、 数组、 枚举、 注解、 字符串等 它和基本数据类型的最大区别就是: 基本数据类型是直接保存在栈中的引用数据类型在栈中保存的是一个地址引用,这个地址指向的是其在堆内存中的实际位置…

四旋翼无人机学习第22节--padstack editor创建过孔

1 首先打开padstack editor软件。 2、选择过孔,注意与前面的博客不同,这里的单位最好使用mil。 在小马哥的教程中,过孔可以分为几类,下面主要对下图的五种过孔进行设置。 3、接着对过孔的孔径进行设置。 4、不做修改。 5、修…

网络交换机常见故障及解决方法

在日常的网络故障维护中我们接触最多的设备就是交换机,特别是接入层交换机,它是连接用户和交换路由设备的桥梁。但是交换机设备无论性能多么好,都会存在潜在故障问题,就像人一样,无论多么健康,也总会出现一…

MindMaster思维导图及亿图图示会员 优惠活动

MindMaster思维导图及亿图图示会员 超值获取途径 会员九折优惠方法分享给大家!如果有需要,可以上~ 以下是食用方法: MindMaster 截图 亿图图示 截图 如果需要MindMaster思维导图或者亿图图示会员,可按照如下操作领取超值折扣优惠…

java成员变量/局部变量2023017

成员变量/局部变量 1.定义位置不同,成员变量定义在类里,局部变量定义在类的方法里。 来自网络 2.成员变量中,其中类变量从该类的准备阶段起开始存在,直到系统完全销毁这个类,类变量的作用域与这个类的生存范围相同&…

超市进销存之openGauss数据库的应用与实践

目录 一、背景 二、目的 三、什么是“进销存”,什么是超市进销存管理系统? 四、什么是openGauss数据库? 五、应用与实践(模拟超市进销存系统) 1、超市进销存数据库表设计 2、创建数据库表 3、手工插入数据 4、…

Python:使用xlrd过滤execl表中数据

一、写代码前需要注意事项首先我们需要注意:python xlrd库的新版本2.0.1版本移除了对.xlsx格式的支持,只支持.xls格式。报错信息如下:File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/xlrd/__ini…

react17+ts 学习

文章目录前言一、创建一个react项目启动项目项目打包打包命令npm run eject的作用入口文件分析react的设计理念二、创建一个reacttypescript的项目创建项目命令如何让react支持json引入不报错react为什么使用jsxjsx特点jsx命令规范jsx表示对象如何在jsx中防止注入攻击&#xff…