NLP项目实战01之电影评论分类

news2024/10/7 16:18:13

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):
    def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):
        super(NetWork,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,output_dim)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self,x):
        embedded = self.embedding(x)
        embedded = embedded.permute(1,0,2) 
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        pooled = self.relu(pooled)
        pooled = self.dropout(pooled)
        
        output = self.fc(pooled)
        return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0
 train_acc = 0 
model.train()
for batch in train_iterator:
        optimizer.zero_grad()
        preds = model(batch.text).squeeze(1)
        loss = criterion(preds,batch.label)
        total_loss += loss.item()

        batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()
        train_acc += batch_acc
        
        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(train_iterator)
    train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()
    valid_loss = 0
    valid_acc = 0
    best_valid_acc = 0
    with torch.no_grad():
        for batch in valid_iterator:
            preds = model(batch.text).squeeze(1)
            loss = criterion(preds,batch.label)
            valid_loss += loss.item()
            batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())
            valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():
    output = saved_model(tensor_text).squeeze(1)
    prediction = torch.round(torch.sigmoid(output)).item()
    probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1296830.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ABAP - Function ALV 01 Function ALV的三大基石

森莫是Function ALV? 业务顾问和用户方面的名词定义为报表,在开发顾问方面定义的名词为ALV 通过调用Function方式展示的ALV叫做FunctionALV.Function的解释:封装好的函数 Function ALV的三大基石 Fieldcat :Function ALV字段级别的处理 Layout …

CentOS服务自启权威指南:手动启动变为开机自启动(以Jenkins服务为例)

前言 CentOS系统提供了多种配置服务开机自启动的方式。本文将介绍其中两种常见的方式, 一种是使用Systemd服务管理器配置,不过,在实际中,如果你已经通过包管理工具安装的,那么服务通常已经被配置为Systemd服务&#…

利用法线贴图渲染逼真的3D老虎模型

在线工具推荐: 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时,有几种不同的风格&#xf…

让聪明的车连接智慧的路,C-V2X开启智慧出行生活

“聪明的车 智慧的路”形容的便是车路协同的智慧交通系统,从具备无钥匙启动,智能辅助驾驶和丰富娱乐影音功能的智能网联汽车,到园区的无人快递配送车,和开放的城市道路上自动驾驶的公交车、出租车,越来越多的车联网应用…

ELK(四)—els基本操作

目录 elasticsearch基本概念RESTful API创建非结构化索引(增)创建空索引(删)删除索引(改)插入数据(改)数据更新(查)搜索数据(id)&…

【Copilot】Edge浏览器的copilot消失了怎么办

这种原因,可能是因为你的ip地址的不在这个服务的允许范围内。你需要重新使用之前出现copilot的ip地址,然后退出edge的账号,重新登录一遍,最后重启edge,就能够使得copilot侧边栏重新出现了。

mac苹果电脑清除数据软件CleanMyMac X4.16

在数字时代,保护个人隐私变得越来越重要。当我们出售个人使用的电脑,亦或者离职后需要上交电脑,都需要对存留在电脑的个人信息做彻底的清除。随着越来越多的人选择使用苹果电脑,很多人想要了解苹果电脑清除数据要怎样做才是最彻底…

Endnote在word中加入参考文献及自定义参考文献格式方式

第一部分:在word中增加引用步骤 1、先下载对应文献的endnote引用格式,如在谷歌学术中的下载格式如下: 2、在endnote中打开存储env的格式库,导入对应下载的文件格式:file>import>file>choose,import对应文件&a…

套接字应用程序

这章节是关于实现 lib_chan 库的 。 lib_chan 的代码在 TCP/IP 之上实现了一个完整的网络层,能够提供认证和Erlang 数据流功能。一旦理解了 lib_chan 的原理,就能量身定制我们自己的通信基础结构,并把它叠加在TCP/IP 之上了。 就lib_chan 本身…

PSP - 计算蛋白质复合物链间接触的残基与面积

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/134884889 在蛋白质复合物中,通过链间距离,可以计算出在接触面的残基与接触面的面积,使用 BioPython 库 与 SA…

在Pytorch中使用Tensorboard可视化训练过程

这篇是我对哔哩哔哩up主 霹雳吧啦Wz 的视频的文字版学习笔记 感谢他对知识的分享 本节课我们来讲一下如何在pytouch当中去使用我们的tensorboard 对我们的训练过程进行一个可视化 左边有一个visualizing models data and training with tensorboard 主要是这么一个教程 那么这里…

28BYJ-48步进电机的驱动

ULN2003的工作原理 28BYJ48可以用ULN2003来驱动,STM32使用开漏模式外接5V上拉电阻也可以产生5V电压,为什么不直接使用单片机的 GPIO来驱动的原因是虽然电压符合电机的驱动要求,但单片机引脚产生的驱动电流太小,因此驱动步进电机要…

IBM Qiskit量子机器学习速成(四)

量子核机器学习 一般步骤 量子核机器学习的一般步骤如下 定义量子核 我们使用FidelityQuantumKernel类创建量子核,该类需要传入两个参数:特征映射和忠诚度(fidelity)。如果我们不传入忠诚度,该类会自动创建一个忠诚度。 注意各个类所属的…

leaflet:经纬度坐标转为地址,点击鼠标显示地址信息(137)

第137个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+leaflet中将经纬度坐标转化为地址,点击鼠标显示某地的地址信息 。主要利用mapbox的api将坐标转化为地址,然后在固定的位置显示出来。 直接复制下面的 vue+leaflet源代码,操作2分钟即可运行实现效果 文章目录 示…

UniDBGrid序号列添加标题

有人想要在UniDBGrid的序号列加上标题,就是这里 可以使用如下代码 UniSession.AddJS(MainForm.UniDBGrid1.columnManager.columns[0].setText("序号"));

VMware安装Ubuntu20.04并使用Xshell连接虚拟机

文章目录 虚拟机环境准备重置虚拟网络适配器属性(可选)配置NAT模式的静态IP创建虚拟机虚拟机安装配置 Xshell连接虚拟机 虚拟机环境准备 VMware WorkStation Pro 17.5:https://customerconnect.vmware.com/cn/downloads/details?downloadGr…

样本数量对问卷信度效度分析的影响及应对策略

问卷调研是一种常见的数据收集方法。明确问卷的真实性和效率是保证其靠谱性和有效性的重要一步。但问卷的真实性和品质会受到样本数量的影响吗? 一、问卷信度的认识 1、信度的概念和重要性:在问卷实验中,信度是指问卷测量值的稳定性和一致性。高信度代…

go grpc高级用法

文章目录 错误处理常规用法进阶用法原理 多路复用元数据负载均衡压缩数据 错误处理 gRPC 一般不在 message 中定义错误。毕竟每个 gRPC 服务本身就带一个 error 的返回值,这是用来传输错误的专用通道。gRPC 中所有的错误返回都应该是 nil 或者 由 status.Status 产…

希尔排序详解:一种高效的排序方法

在探索排序算法的世界中,我们经常遇到需要对大量数据进行排序的情况。传统的插入排序虽然简单,但在处理大规模数据时效率并不高。这时,希尔排序(Shell Sort)就显得尤为重要。本文将通过深入解析希尔排序的逻辑&#xf…

分享十几个适合新手练习的软件测试项目

说实话,在找项目的过程中,我下载过(甚至付费下载过)N多个项目、联系过很多项目的作者,但是绝大部分项目,在我看来,并不适合你拿来练习,它们或多或少都存在着“问题”,比如…