NLP项目实战01--电影评论分类

news2024/11/16 10:47:18

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):
    def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):
        super(NetWork,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,output_dim)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self,x):
        embedded = self.embedding(x)
        embedded = embedded.permute(1,0,2) 
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        pooled = self.relu(pooled)
        pooled = self.dropout(pooled)
        
        output = self.fc(pooled)
        return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0
 train_acc = 0 
model.train()
for batch in train_iterator:
        optimizer.zero_grad()
        preds = model(batch.text).squeeze(1)
        loss = criterion(preds,batch.label)
        total_loss += loss.item()

        batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()
        train_acc += batch_acc
        
        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(train_iterator)
    train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()
    valid_loss = 0
    valid_acc = 0
    best_valid_acc = 0
    with torch.no_grad():
        for batch in valid_iterator:
            preds = model(batch.text).squeeze(1)
            loss = criterion(preds,batch.label)
            valid_loss += loss.item()
            batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())
            valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():
    output = saved_model(tensor_text).squeeze(1)
    prediction = torch.round(torch.sigmoid(output)).item()
    probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1299353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

消息队列使用指南

介绍 消息队列是一种常用的应用程序间通信方法,可以用来在不同应用程序或组件之间传递数据或消息。消息队列就像一个缓冲区,接收来自发送方的消息,并存储在队列中,等待接收方从队列中取出并处理。 在分布式系统中,消…

Git的安装以及SSH配置

前言 近期工作需要,所以版本管理工具要用到Git,某些操作需要ssh进行操作,在某次操作中遇到:git bash报错:Permission denied, please try again。经排查是ssh没有配置我的key,所以就借着这篇文章整理了一下…

【小白专用】使用PHP创建和操作MySQL数据库,数据表

php数据库操作 php连接mysql数据库 <?php $hostlocalhost; // 数据库主机名 $username"root"; // 数据库用户名 $password"al6"; // 数据库密码 $dbname"mysql"; // 数据库名 $connIDmysqli_connect($host,$username,$password,$dbn…

Electron[4] Electron最简单的打包实践

1 背景 前面三篇已经完成通过Electron搭建的最简单的HelloWorld应用了&#xff0c;虽然这个应用还没添加任何实质的功能&#xff0c;但是用来作为打包的案例&#xff0c;足矣。下面再分享下通过Electron-forge来将应用打包成安装包。 2 依赖 在Electron[2] Electron使用准备…

AXURE地图获取方法

AXURE地图截取地址 https://axhub.im/maps/ 1、点击上方地图或筛选所需地区的地图&#xff0c;点击复制到 Axure 按钮&#xff0c;到 Axure 粘贴就可以了 2、复制到 Axure 后&#xff0c;转化为 svg 图形&#xff0c;就可以随意更改尺寸/颜色/边框&#xff0c;具体操作如下&am…

RocketMQ-源码架构二

梳理一些比较完整&#xff0c;比较复杂的业务线 消息持久化设计 RocketMQ的持久化文件结构 消息持久化也就是将内存中的消息写入到本地磁盘的过程。而磁盘IO操作通常是一个很耗性能&#xff0c;很慢的操作&#xff0c;所以&#xff0c;对消息持久化机制的设计&#xff0c;是…

使用Java8的Stream流的Collectors.toMap来生成Map结构

问题描述 在日常开发中总会有这样的代码&#xff0c;将一个List转为Map集合&#xff0c;使用其中的某个属性为key&#xff0c;某个属性为value。 常规实现 public class CollectorsToMapDemo {DataNoArgsConstructorAllArgsConstructorpublic static class Student {private…

基于YOLOv8深度学习的舰船目标分类检测系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

【pycharm】Pycharm中进行Git版本控制

本篇文章主要记录一下自己在pycharm上使用git的操作&#xff0c;一个新项目如何使用git进行版本控制。 文章使用的pycharm版本PyCharm Community Edition 2017.2.4&#xff0c;远程仓库为https://gitee.com/ 1.配置Git&#xff08;File>Settings&#xff09; 2.去Gitee创建…

【C语言】位运算实现二进制数据处理及BCD码转换

文章目录 1&#xff0e;编程实验&#xff1a;按short和unsigned short类型分别对-12345进行左移2位和右移2位操作&#xff0c;并输出结果。2&#xff0e;编程实验&#xff1a;利用位运算实现BCD码与十进制数之间的转换&#xff0c;假设数据类型为unsigned char。3&#xff0e;编…

边缘计算系统设计与实践:引领科技创新的新浪潮

文章目录 一、边缘计算的概念二、边缘计算的设计原则三、边缘计算的关键技术四、边缘计算的实践应用《边缘计算系统设计与实践》特色内容简介作者简介目录前言/序言本书读者对象获取方式 随着物联网、大数据和人工智能等技术的快速发展&#xff0c;传统的中心化计算模式已经无法…

用php和mysql制作一个网站

当使用PHP和MySQL制作网站时&#xff0c;我们可以利用PHP的强大功能来与MySQL数据库进行交互&#xff0c;从而实现动态网页的创建和数据存取。下面是一个关于如何使用PHP和MySQL制作网站的简单说明&#xff0c;以及一些示例代码。 ​ 1、R5Ai智能助手 chatgpt国内版本 :R5Ai智…

P7 Linux C三种终止进程的方法

前言 &#x1f3ac; 个人主页&#xff1a;ChenPi &#x1f43b;推荐专栏1: 《C_ChenPi的博客-CSDN博客》✨✨✨ &#x1f525; 推荐专栏2: 《Linux C应用编程&#xff08;概念类&#xff09;_ChenPi的博客-CSDN博客》✨✨✨ &#x1f6f8;推荐专栏3: ​​​​​​《 链表_Chen…

C语言——字符函数和字符串函数(一)

&#x1f4dd;前言&#xff1a; 这篇文章对我最近学习的有关字符串的函数做一个总结和整理&#xff0c;主要讲解字符函数和字符串函数&#xff08;strlen&#xff0c;strcpy和strncpy&#xff0c;strcat和strncat&#xff09;的使用方法&#xff0c;使用场景和一些注意事项&…

机器人、智能小车常用的TT电机/310电机/370电机选型对比

在制作智能小车或小型玩具时&#xff0c;在电机选型上一些到各种模糊混淆的概念&#xff0c;以及各种错综复杂的电机参数&#xff0c;本文综合对比几种常用电机的参数及特性适应范围&#xff0c;以便快速选型&#xff0c;注意不同生产厂家的电机参数规则会有较大差异。 普通TT…

2023.12.09小爆发(31.56元) 穿山甲SDK接入收益·android广告接入·app变现·广告千展收益·eCPM收益

接入穿山甲SDK的app 数独训练APP 广告接入示例: Android 个人开发者如何接入广告SDK&#xff0c;实现app流量变现 接入穿山甲SDK app示例&#xff1a; android 数独小游戏 经典数独休闲益智 2023.12.09 广告收入有31.56R,小爆发了一下 1.用户统计上图&#xff1a; 昨天新增…

揭秘字符串的奥秘:探索String类的深层含义与源码解读

文章目录 一、导论1.1 引言&#xff1a;字符串在编程中的重要性1.2 目的&#xff1a;深入了解String类的内部机制 二、String类的设计哲学2.1 设计原则&#xff1a;为什么String类如此重要&#xff1f;2.2 字符串池的概念与作用 三、String类源码解析3.1 成员变量3.2 构造函数3…

【小聆送书第二期】人工智能时代之AIGC重塑教育

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;网络奇遇记、数据结构 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;正文&#x1f4dd;活动参与规则 参与活动方式文末详见。 &#x1f4cb;正文 AI正迅猛地…

倪海厦:教你正确煮中药,发挥最大药效

同样的一个汤剂&#xff0c;我开给你&#xff0c;你如果煮的方法不对&#xff0c;吃下去效果就没那么好。 所以&#xff0c;汤&#xff0c;取它的迅捷&#xff0c;速度很快&#xff0c;煮汤的时候还有技巧&#xff0c;你喝汤料的时候&#xff0c;你到底是喝它的气&#xff0c;…

自动驾驶学习笔记(十六)——目标跟踪

#Apollo开发者# 学习课程的传送门如下&#xff0c;当您也准备学习自动驾驶时&#xff0c;可以和我一同前往&#xff1a; 《自动驾驶新人之旅》免费课程—> 传送门 《Apollo 社区开发者圆桌会》免费报名—>传送门 文章目录 前言 匹配关联 轨迹记录 状态预测 总结 前…