6.2 通过构建情感分类器训练词向量

news2024/10/6 2:21:48

        在上一节中,我们简要地了解了词向量,但并没有去实现它。在本节中,我们将下载一个名为IMDB的数据集(其中包含了评论),然后构建一个用于计算评论的情感是正面、负面还是未知的情感分类器。在构建过程中,还将为 IMDB 数据集中存在的词进行词向量的训练。我们将使用一个名为 torchtext 的库,这个库使下载、向量化文本和批处理等许多过程变得更加容易。训练情感分类器将包括以下步骤。

  1. 下载 IMDB 数据并对文本分词;
  2. 建立词表;
  3. 生成向量的批数据;
  4. 使用词向量创建网络模型;
  5. 训练模型。

6.2.1 下载 IMDB 数据并对文本分词

        对于与计算机视觉相关的应用,我们使用过 torchvision 库。它提供了许多实用功能,并帮助我们构建计算机视觉应用程序。同样,有一个名为 torchtext 的库,它也是 PyTorch 的一部分,它与 PyTorch 一起工作,通过为文本提供不同的数据加载器和抽象,简化了许多与自然语言处理相关的活动。在本书写作时,torchtext 没有包含在 PyTorch 包内,需要独立安装。可以在计算机的命令行中运行以下代码来安装torchtext:

pip install torchtext

        安装完成后就可以使用它了。torchtext 提供了两个重要的模块:torchtext.data和torchtext.datasets。

        1. torchtext.data

        torchtext.data 实例定义了一个名为 Field 的类,它可以用来定义数据如何读取和分词。让我们看一下使用它来准备 IMDB 数据集的示例:

from torchtext import data
TEXT = data.Field(lower=True, batch_first=True, fix_length=20)
LABEL = data.Field(sequential=False)

        在上述代码中,我们定义了两个 Field 对象,一个用于实际的文本,另一个用于标签数据。对于实际的文本,我们期望 torchtext 将所有文本都小写并对文本分词,同时将其修整为最大长度为20。如果我们正在为生产环境构建应用程序,则可以将长度修正为更大的数字。当然对于当前练习的例子,20的长度够用了。Field 的构造函数还接受另一个名为 tokenize 的参数,该参数默认使用str.split 函数。还可以指定spaCy作为参数或任何其他分词器。我们的例子将使用 str.split。

        2. torchtext.datasets

        torchtext.datasets 实例提供了使用不同数据集的封装,如IMDB、TREC(问题分类)、语言建模(WikiText-2)和一些其他数据集。我们将使用 torch.datasets 下载 IMDB 数据集并将其拆分为 train 和 test 数据集。以下代码执行此操作,当第一次运行它时,可能需要几分钟,具体取决于网络连接速度,因为它是从 Internet 上下载 IMDB 数据集的:

train,test=datasets.IMDB.splits(TEXT,LABEL)

        之前的数据集的 IMDB 类抽象出了下载、分词和将数据库拆分为 train 和 test 数据集涉及的所有复杂度。train.fields 包含一个字典,其中 TEXT 是键,值是 LABEL。让我们看看 train.fields 和 train 集合的每个元素:

print('train.fields',train.fields)

        从这些结果中可以看到,单个元素包含了一个字段 text 和表示 text 的所有 token,以及包含了文本标签的字段 label。现在已准备好对 IMDB 数据集进行批处理了。

6.2.2 构建词表

        当为 thor_review 创建独热编码时,同时创建了一个作为词表的 word2idx 字典,它包含文档中唯一词的所有细节。torchtext 实例使处理更加容易。在加载数据后,可以调用 build_vocab 并传入负责为数据构建词表的必要参数。以下代码说明了如何构建词表:

TEXT.build_vocab(train,vectors=GloVe(name=,6B,dim=300),max_size=10000,min_freq=10)
LABEL.build_vocab(train)

        在上述代码中,传入了需要构建词表的 train 对象,并让它使用维度为 300 的预训练词向量来初始化向量。当使用预训练权重训练情感分类器时,build_vocab 对象只是下载并创建稍后将使用的维度。max_size 实例限制了词表中词的数量,而min_freg删除了出现不超过10 次的词,其中 10是可配置的。
        当词汇表构建完成后,我们就可以获得例如词频、词索引和每个词的向量表示等不同的值。下面的代码演示了如何访问这些值:

print(TEXT.vocab.freqs)

        以下代码演示了如何访问结果:

print(TEXT.vocab.vectors)

        使用 stoi 访问包含词及其索引的字典。

6.2.3 生成向量的批数据

        torchtext 提供了 BucketIterator,它有助于批处理所有文本并将词替换成词的索引。BucketIterator 实例带有许多有用的参数,如batch_size、device(GPU或CPU)和 shuffle (是否必须对数据进行混洗)。下面的代码演示了如何为 train 和 test 数据集创建生成批处理的迭代器:

train_iter, test_iter = data.BucketIterator.splits((train, test),
batch_size=128,device=-1,shuffle=True)
#device = -1 表示使用 cpu,设置为 None 时使用 gpu.

        上述代码为 train 和 test 数据集提供了一个 BucketIterator 对象。以下代码将说明如何创建 batch 并显示 batch 的结果:

batch = next(iter(train_iter))
batch.text

        从上面代码段的结果中,可以看到文本数据如何转换为 batch_size * fix_len (即128x20) 大小的矩阵。

6.2.4 使用词向量创建网络模型

        我们之前简要地讨论过词向量。在本节中,我们将创建作为网络架构的一部分的词向量,并训练整个模型用以预测每个评论的情感。在训练结束时,将得到一个情感分类器模型,以及 IMDB 数据集的词向量。以下代码演示了如何使用词向量创建用于情感预测的网络架构:

class EmbNet(nn.Module):
    def _init_(self,emb_size,hidden_sizel,hidden_size2 = 400):
        super()._init_()
        self.embedding = nn.Embedding(emb_size,hidden_sizel)
        self.fc = nn.Linear(hidden_size2,3)
    def forward(self,x):
        embeds = self.embedding(x).view(x.size(0),-1)
        out = self.fc(embeds)
        return F.log_softmax(out, dim = -1)

        在上述代码中,EmbNet 创建了情感分类模型。在_init_函数中,我们使用两个参数初始化了 nn.Embedding 类的一个对象,它接收两个参数,即词表的大小和希望为每个单词创建的维度。由于限制了唯一单词的数量,因此词表的大小将为10,000,并且我们可以从一个小的向量尺寸(比如10)开始。为了快速运行程序,有必要使用个小尺寸的向量值,但是当试图为生产系统构建应用程序时,请使用大尺寸的词向量。我们还有一个线性层,将词向量映射到情感的类别(如正面、负面或未知)。
        forward 函数确定了输入数据的处理方式。对于批量大小为 32 以及最大长度为 20 个词的句子,输入形状为 32x20。第一个 embedding 层充当查找表,用相应的词向量替换掉每个词。对于向量维度 10,当每个词被其相应的词向量替换时,输出形状变成了 32x20x10。view 函数将使 embedding 层的结果变得扁平。传递给 view 函数的第一个参数将保持维数不变。在我们的例子中,我们不希望组合来自不同批次的数据,因此保留第一个维数并将张量中的其余值扁平化。在应用 view 函数后,张量形状变为 32x200。全连接层将扁平化的词向量映射到类别的编号。定义了网络后就可以像往常一样训练它了。

6.2.5 训练模型

        训练模型与在构建图像分类器时看到的非常类似,因此将使用相同的函数。我们把批数据传入模型并计算输出和损失,然后优化包括词向量权重在内的模型权重。以下代码执行此操作:

def fit(epoch,model,data_loader,phase=,training,,volatile=False):
    if phase == 'training':
        model.train()
    if phase == 'validation':
        model.eval()
        volatile = True 
    running_loss = 0.0
    running_correct = 0
    for batch_idx r batch in enumerate(data_loader):
        text, target = batch.text r batch.label
        if is_cuda:
            text,target = text.cuda(), target.cuda()
        if phase =='training':
            optimizer.zero_grad()
        output = model(text)
        loss = F.nll_loss(output,target)
        running loss += F.nll loss(output,target,size_average=False).data[0]
        preds = output.data.max(dim=1,keepdim=True)[1]
        running_correct += preds.eq(target.data.view_as(preds)).cpu().sum()
        if phase == 'training': 
            loss.backward()
            optimizer.step()
        loss = running_loss/len(data_loader.dataset)
        accuracy = 100. * running_correct/len(data_loader.dataset)
        print(f'{phase} loss is (loss:(5}.{2}} and {phase} accuracy is 
            {running_correct}/{len(data_loader.dataset)}{accuracy:{10}.{4}}')
        return loss,accuracy
    train_losses,train_accuracy = [],[]
    val_losses,val_accuracy = [],[]
    train_iter.repeat = False
    test_iter.repeat = False
    for epoch in range(1,10):
        epoch_loss,epoch_accuracy = fit(epoch,model,train_iter,phase='training')
        val_epoch_loss,val_epoch_accuracy = 
            fit(epoch,model,test_iter,phase='validation')
        train_losses.append(epoch_loss)
        train_accuracy.append(epoch_accuracy)
        val_losses.append(val_epoch_loss)
        val_accuracy.append(val_epoch_accuracy)

        在上述代码中,通过传入为批处理数据创建的 BucketIterator 对象来调用 fit 方法。默认情况下,迭代器不会停止生成批数据,因此必须将 BucketIterator 对象的 repeat 变量设置为 False。如果不将 repeat 变量设置为 False,那么 fit 函数将无限地运行。模型训练10轮后得到的验证准确率约为70%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1866634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分享一个 MySQL 简单快速进行自动备份和还原的脚本和方法

前言 数据备份和还原在信息技术领域中具有非常重要的作用,不论是人为误操作、硬件故障、病毒感染、自然灾害还是其他原因,数据丢失的风险都是存在的。如果没有备份,一旦数据丢失,可能对个人、企业甚至整个组织造成巨大的损失。 …

2-17 基于matlab的改进的遗传算法(IGA)对城市交通信号优化分析

基于matlab的改进的遗传算法(IGA)对城市交通信号优化分析。根据交通流量以及饱和流量,对城市道路交叉口交通信号灯实施合理优化控制,考虑到交通状况的动态变化,及每个交叉口的唯一性。通过实时监测交通流量&#xff0c…

部署企业级AI知识库最重要的是什么?✍

随着人工智能技术的迅猛发展,企业级AI知识库成为提升企业管理效率和信息获取能力的重要工具。那么,在部署企业级AI知识库时,最重要的是什么呢?本文将从数据质量、系统可扩展性、用户体验以及智能化这四个关键方面进行详细分析。 …

单片机是否有损坏,怎沫判断

目录 1、操作步骤: 2、单片机损坏常见原因: 3、 单片机不工作的原因: 参考:细讲寄存器读写与Bit位操作原理--单片机C语言编程Bit位的与或非屏蔽运算--洋桃电子大百科P019_哔哩哔哩_bilibili 1、操作步骤: 首先需要…

Objects and Classes (对象和类)

Objects and Classes [对象和类] 1. Procedural and Object-Oriented Programming (过程性编程和面向对象编程)2. Abstraction and Classes (抽象和类)2.1. Classes in C (C 中的类)2.2. Implementing Class Member Functions (实现类成员函数)2.3. Using Classes References O…

第 28 篇 : SSH秘钥登录

1 生成秘钥 ssh-keygen -t rsa ls -a ./.ssh/一直回车就行了 2. 修改配置 vi /etc/ssh/sshd_config放开注释 公钥的位置修改 关闭密码登录 PubkeyAuthentication yes AuthorizedKeysFile .ssh/id_rsa.pub PasswordAuthentication no3. 下载id_rsa私钥, 自行解决 注意…

Vue中数组的【响应式】操作

在 Vue.js 中,当你修改数组时,Vue 不能检测到以下变动的数组: 当你利用索引直接设置一个项时,例如:vm.items[indexOfItem] newValue当你修改数组的长度时,例如:vm.items.length newLength 为…

常见图像分割模型介绍:FCN、U-Net、SegNet、Mask R-CNN

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

AI助力校园安全:EasyCVR视频智能技术在校园欺凌中的应用

一、背景分析 近年来,各地深入开展中小学生欺凌行为治理工作,但有的地方学生欺凌事件仍时有发生,严重损害学生身心健康,引发社会广泛关注。为此,教育部制定了《防范中小学生欺凌专项治理行动工作方案》进一步防范和遏…

软件构造 | 期末查缺补漏

软件构造 | 期末查缺补漏 总体观 软件构造的三维度八度图是由软件工程师Steve McConnell提出的概念,用于描述软件构建过程中的三个关键维度和八个要素。这些维度和要素可以帮助软件开发团队全面考虑软件构建的方方面面,从而提高软件质量和开发效率。 下…

文件批量重命名001到100 最简单的数字序号递增的改名技巧

文件批量重命名001到100 最简单的数字序号递增的改名方法。最近看到很多人都在找怎么批量修改文件名称,还要按固定的ID需要递增,这个办法用F2或者右键改名是不能做到的。 这时候我们可以通过一个专业的文件批量重命名软件来批量处理这些文档。 芝麻文件…

A-8 项目开源 qt1.0

A-8 2024/6/26 项目开源 由于大家有相关的需求,就创建一个项目来放置相关的代码和项目 欢迎交流,QQ:963385291 介绍 利用opencascade和vulkanscene实现stp模型的查看器打算公布好几个版本的代码放在不同的分支下,用qt实现&am…

如何获取特定 HIVE 库的元数据信息如其所有分区表和所有分区

如何获取特定 HIVE 库的元数据信息如其所有分区表和所有分区 1. 问题背景 有时我们需要获取特定 HIVE 库下所有分区表,或者所有分区表的所有分区,以便执行进一步的操作,比如通过 使用 HIVE 命令 MSCK REPAIR TABLE table_name sync partiti…

Redis实战—基于setnx的分布式锁与Redisson

本博客为个人学习笔记,学习网站与详细见:黑马程序员Redis入门到实战 P56 - P63 目录 分布式锁介绍 基于SETNX的分布式锁 SETNX锁代码实现 修改业务代码 SETNX锁误删问题 SETNX锁原子性问题 Lua脚本 编写脚本 代码优化 总结 Redisson 前言…

基于盲信号处理的人声分离

1.问题描述 在实际生活中,存在一种基本现象称为“鸡尾酒效应”,该效应指即使在非常嘈杂的环境中,人依然可以从噪声中提取出自己所感兴趣的声音。 在实际应用中,我们可能需要对混合的声音进行分离,此时已知的只有混合…

springcloud第4季 springcloud-alibaba之openfegin+sentinel整合案例

一 介绍说明 1.1 说明 1.1.1 消费者8081 1.1.2 openfegin接口 1.1.3 提供者9091 9091微服务满足: 1 openfegin 配置fallback逻辑,作为统一fallback服务降级处理。 2.sentinel访问触发了自定义的限流配置,在注解sentinelResource里面配置…

DVWA 靶场 SQL Injection 通关解析

前言 DVWA代表Damn Vulnerable Web Application,是一个用于学习和练习Web应用程序漏洞的开源漏洞应用程序。它被设计成一个易于安装和配置的漏洞应用程序,旨在帮助安全专业人员和爱好者了解和熟悉不同类型的Web应用程序漏洞。 DVWA提供了一系列的漏洞场…

数学建模--Matlab求解线性规划问题两种类型实际应用

1.约束条件的符号一致 (1)约束条件的符号一致的意思就是指的是这个约束条件里面的,像这个下面的实例里面的三个约束条件,都是小于号,这个我称之为约束条件符号一致; (2)下面的就是上…

关于linux的图形界面

关于linux的图形界面 1. 概述1.1 X1.2 DM(显示管理器/登录管理器)1.3 WM(窗口管理器)1.4 GUI Toolkits1.5 Desktop Environment1.6 基本架构 2. 安装桌面2.1 Centos安装桌面2.2 Ubuntu安装桌面(未实践) 3. …