python pytorch实现RNN,LSTM,GRU,文本情感分类

news2024/9/22 22:31:05

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np

"""
作业:
一、完成优化
优化思路

1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru

二、完成测试代码
"""

# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())

# print(dd['label'].value_counts())

# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))

# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)

MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4

# 构建词表 word2id
vocab = []
for i in dd['text']:
    vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词
    # vocab.extend(list(i))

vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())

vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()

WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1

vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))

print(vocab[:10])
print(len(vocab))

vocab_dict = {k: v for v, k in enumerate(vocab)}

# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd


# 定义数据集 Dataset
class Dataset(data.Dataset):
    def __init__(self, split='train'):
        # ChnSentiCorp 情感分类数据集
        path =  r'E:/peixun/data/' + str(split) + '.csv'
        self.data = pd.read_csv(path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        text = self.data.loc[i, 'text']
        label = self.data.loc[i, 'label']

        return text, label


# 实例化 Dataset
dataset = Dataset('train')

# 样本数量
print(len(dataset))
print(dataset[0])


# 句子批处理函数
def collate_fn(batch):
    # [(text1, label1), (text2, label2), (3, 3)...]
    sents = [i[0][:MAX_LEN] for i in batch]
    labels = [i[1] for i in batch]

    inputs = []
    # masks = []

    for sent in sents:
        sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]
        pad_len = MAX_LEN - len(sent)

        # mask = len(sent) * [1] + pad_len * [0]
        # masks.append(mask)

        sent += pad_len * [WORD_PAD_ID]

        inputs.append(sent)

    # 只使用 lstm 不需要用 masks
    # masks = torch.tensor(masks)
   # print(inputs)
    inputs = torch.tensor(inputs)
    labels = torch.LongTensor(labels)

    return inputs.to(DEVICE), labels.to(DEVICE)


# 测试 loader
loader = data.DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         collate_fn=collate_fn,
                         shuffle=True,
                         drop_last=False)

inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)


# 定义模型
class Model(nn.Module):
    def __init__(self, vocab_size=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)

        # 多种 rnn
        self.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)
        self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)
        self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)

        self.l1 = nn.Linear(500 * 100 * 2, 100)
        self.l2 = nn.Linear(100, 2)

    def forward(self, inputs):
        out = self.embed(inputs)
        out, _ = self.lstm(out)
        out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000
        out = F.relu(self.l1(out))  # 16 * 100
        out = F.softmax(self.l2(out))  # 16 * 2

        return out


# 测试 Model
model = Model()
print(model)

# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         collate_fn=collate_fn,
                         shuffle=True)

model = Model().to(DEVICE)

# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

model.train()
for e in range(EPOCH):
    for idx, (inputs, labels) in enumerate(loader):
        # 前向传播,计算预测值
        out = model(inputs)
        # 计算损失
        loss = loss_fn(out, labels)
        # 反向传播,计算梯度
        loss.backward()
        # 参数更新
        optimizer.step()
        # 梯度清零
        optimizer.zero_grad()

        if idx % 10 == 0:
            out = out.argmax(dim=-1)
            acc = (out == labels).sum().item() / len(labels)

            print('>>epoch:', e,
                  '\tbatch:', idx,
                  '\tloss:', loss.item(),
                  '\tacc:', acc)

# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,
                              batch_size=BATCH_SIZE,
                              collate_fn=collate_fn,
                              shuffle=False)

loss_fn = nn.CrossEntropyLoss()

out_total = []
labels_total = []

model.eval()
for idx, (inputs, labels) in enumerate(test_loader):
    out = model(inputs)
    loss = loss_fn(out, labels)

    out_total.append(out)
    labels_total.append(labels)

    if idx % 50 == 0:
        print('>>batch:', idx, '\tloss:', loss.item())
        
correct=0
sumz=0
for i in range(len(out_total)):
   out = out_total[i].argmax(dim=-1)
   correct = (out == labels_total[i]).sum().item() +correct
   sumz=sumz+len(labels_total[i])
    #acc = (out_total == labels_total).sum().item() / len(labels_total)

print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1272863.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux多线程同步

Linux多线程同步 1、线程同步的概念1.1 为什么要同步1.2 同步方式 2、互斥锁2.1 互斥锁函数2.2 互斥锁使用 3、死锁4、读写锁4.1 读写锁函数4.2 读写锁使用 5、条件变量5.1 条件变量函数5.2 生产者和消费者 6、信号量6.1 信号量函数6.2 生产者和消费者6.3 信号量的使用6.3.1 总…

Android系统源码中添加可编译运行执行程序,java

文章目录 Android系统源码中添加可编译运行执行程序&#xff0c;java1.Android设备中执行编译运行java代码2.编译执行jar包 Android系统源码中添加可编译运行执行程序&#xff0c;java 1.Android设备中执行编译运行java代码 新建一个文件夹&#xff0c;以及Java类的包路径 测…

AI 编程如何助力开发者高效完成架构设计工作?

▼最近直播超级多&#xff0c;预约保你有收获 今晚直播&#xff1a;《AI 编程技术架构剖析和案例开发实战》 —1— AI 编程能帮我们完成哪些工作&#xff1f; 从目前企业级种种现实场景应用来看&#xff0c;AI 编程已经成为一种帮助开发者解决架构设计复杂问题、提高编程效率以…

C/C++转义符:\x

文章目录 什么是转义符使用"\x"定义char数组宏定义中的\ 什么是转义符 在C语言中&#xff0c;转义符用于将一些特殊字符表示为单个字符&#xff0c;常用的转义符有&#xff1a; \\&#xff1a;反斜杠符号\&#xff1a;单引号\"&#xff1a;双引号\a&#xff1…

力扣 --- 删除有序数组中的重复项 II

题目描述&#xff1a; 给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次 &#xff0c;返回删除后数组的新长度。 不要使用额外的数组空间&#xff0c;你必须在 原地 修改输入数组 并在使用 O(1) 额外空间的…

考试复习

选择20道 填空10道 判断10道 简答4-5道 编程题2道 一、选择题 1.js中更改一个input框的值&#xff1a; <input ida type"text" value"123456"> 通过a.value改变他的值 方法&#xff1a; 在script标签中通过id获得该输入框对象&#xff0c;然…

记录一次爱快路由ACL策略引起的大坑

环境&#xff1a; A公司和B公司采用爱快的ipsec互联 B公司同时有加密软件限制网络 问题&#xff1a;对方ERP无法连接我们的数据库服务器 先简单测试了下1433端口是不是通的 下面的测试结果&#xff0c;直接ping是通的&#xff0c;但是加上1433端口后就不通 排查过程&#xff1…

高等数学上岸宝典笔记

①不单调的函数也可能有反函数 ②注意反函数与函数转换时的定义域与值域 ③收敛数列不一定有最值 收敛数列必有上界和下界&#xff0c;但不一定有最值&#xff0c;比如{An}1/n&#xff0c;下界为0&#xff0c;但永远取不到0 ④数列与其子数列的关系 例题&#xff1a; ⑤带根号…

道路病害检测数据集RDD2022的标签映射关系【参考自官网给出的label_map.pbtxt文件,附查看代码】

TOC 结论 Label ID: 1, Label Name: D00 Label ID: 2, Label Name: D10 Label ID: 3, Label Name: D20 Label ID: 4, Label Name: D40链接地址 https://github.com/sekilab/RoadDamageDetector/ 查看代码 # 打开 label_map.pbtxt 文件 def read_label_map(file_path):label…

模拟算法【2】

文章目录 &#x1f958;6. N 字形变换&#x1f372;题目&#x1fad5;算法原理&#x1f963;代码实现 &#x1f957;38. 外观数列&#x1f37f;题目&#x1f9c2;算法原理&#x1f9c8;代码实现 &#x1f958;6. N 字形变换 &#x1f372;题目 题目链接&#xff1a;6. N 字形变…

【linux网络】补充网关服务器搭建,综合应用SNAT、DNAT转换,dhcp分配、dns分离解析,nfs网络共享以及ssh免密登录

目录 linux网络的综合应用 1&#xff09;网关服务器&#xff1a;ens35&#xff1a;12.0.0.254/24&#xff0c;ens33&#xff1a;192.168.100.254/24&#xff1b;Server1&#xff1a;192.168.100.101/24&#xff1b;PC1和server2&#xff1a;自动获取IP&#xff1b;交换机无需…

python中的字符串

字符串 字符串是编程语言中的一种基本数据类型&#xff0c;用于表示一串字符序列。在Python中&#xff0c;字符串是不可变的&#xff0c;也就是说一旦字符串被创建&#xff0c;就无法修改其中的字符。 Python中的字符串可以用单引号或双引号括起来&#xff0c;例如&#xff1…

手机电脑同步的时间管理工具

有不少上班族会发现自己有太多的工作要完成&#xff0c;并且在工作中往往会浪费很多时间在无关紧要的事情上&#xff0c;而不是专注于真正重要的任务&#xff0c;因此没有足够的时间来完成所有任务。在这种情况下&#xff0c;我们可以使用时间管理软件来帮助自己优先考虑重要的…

Android 架构实战MVI进阶

MVI架构的原理和流程 MVI架构是一种基于响应式编程的架构模式&#xff0c;它将应用程序分为四个核心组件&#xff1a;模型&#xff08;Model&#xff09;、视图&#xff08;View&#xff09;、意图&#xff08;Intent&#xff09;和状态&#xff08;State&#xff09;。 原理&…

Mybatisplus同时向两张表里插入数据[事务的一致性]

一、需求&#xff1a;把靶器官的数据&#xff0c;单独拿出来作为一个从表&#xff0c;以List的方式接收这段数据&#xff1b; 此时分析&#xff0c;是需要有两个实体的&#xff0c;一个是主表的实体&#xff0c;一个是从表的实体&#xff0c;并在主表实体新增一个List 字段来接…

免费WordPress站群插件-批量管理站群的免费软件

WordPress站群插件&#xff1a;让文章管理如丝般顺滑 在众多网站建设工具中&#xff0c;WordPress一直以其简便易用、丰富的插件生态而备受青睐。对于站群管理者而言&#xff0c;如何高效地更新、发布和推送文章是一项不可忽视的任务。本文将专注分享一款WordPress站群插件&am…

Rust的Vec优化

本篇是对Rust编程语言17_Rust的Vec优化[1]学习与记录 MiniVec https://crates.io/crates/minivec enum DataWithVec { // tag,uint64,8字节 I32(i32), // 4字节,但需内存对齐到8字节? F64(f64), // 8字节 Bytes(Vec<u8>), // 24字节}fn main()…

watch函数与watchEffect函数

watach函数&#xff1a; 与vue2.x的配置功能一致 监视指定的一个或多个响应式数据, 一旦数据变化, 就自动执行监视回调 默认初始时不执行回调, 但可以通过配置immediate为true, 来指定初始时立即执行第一次 通过配置deep为true, 来指定深度监视 watchEffect函数&#xff1a;…

电商项目之Web实时消息推送(附源码)

文章目录 1 问题背景2 前言3 什么是消息推送4 短轮询5 长轮询5.1 demo代码 6 iframe流6.1 demo代码 7 SSE7.1 demo代码7.2 生产环境的应用 &#xff08;重要&#xff09; 8 MQTT 1 问题背景 扩宽自己的知识广度&#xff0c;研究一下web实时消息推送 2 前言 文章参考自Web 实时消…

什么是PDN的交流阻抗?

什么是PDN的交流阻抗&#xff1f; 在电力电子领域&#xff0c;PDN&#xff08;Power Distribution Network&#xff09;的交流阻抗是一个重要的概念&#xff0c;它反映了PDN在交流电源和负载之间传输电能的能力。了解PDN的交流阻抗对于优化电源设计、提高系统性能和可靠性具有重…