Pytorch 文本情感分类案例

news2024/12/26 22:15:40

一共六个脚本,分别是:

        ①generateDictionary.py用于生成词典

        ②datasets.py定义了数据集加载的方法

        ③models.py定义了网络模型

        ④configs.py配置一些参数

        ⑤run_train.py训练模型

        ⑥run_test.py测试模型

数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486959?spm=1001.2014.3001.5501停用词表icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486973?spm=1001.2014.3001.5501

generateDictionary.py如下

import jieba

data_path = "./weibo_senti_100k.csv"
data_stop_path = "./hit_stopwords.txt"
data_list = open(data_path,encoding='utf-8').readlines()[1:]
stops_word = open(data_stop_path,encoding='utf-8').readlines()
stops_word = [line.strip() for line in stops_word]
stops_word.append(" ")
stops_word.append("\n")

voc_dict = {}
min_seq = 1
top_n = 1000
UNK = "UNK"
PAD = "PAD"
for item in data_list:
    label = item[0]
    content = item[2:].strip()
    seg_list = jieba.cut(content,cut_all=False)

    seg_res = []
    for seg_item in seg_list:
        if seg_item in stops_word:
            continue
        seg_res.append(seg_item)
        if seg_item in voc_dict.keys():
            voc_dict[seg_item] += 1
        else:
            voc_dict[seg_item] = 1

    # print(content)
    # print(seg_res)

    voc_list = sorted([_ for _ in voc_dict.items() if _[1] > min_seq],key=lambda x:x[1],reverse=True)[:top_n]
    voc_dict = {word_count[0]:idx for idx,word_count in enumerate(voc_list)}
    voc_dict.update({UNK:len(voc_dict),PAD:len(voc_dict)+1})

ff = open("./dict","w")
for item in voc_dict.keys():
    ff.writelines("{},{}\n".format(item,voc_dict[item]))
ff.close()

datasets.py如下

from torch.utils.data import Dataset, DataLoader
import jieba
import numpy as np


def read_dict(voc_dict_path):
    voc_dict = {}
    with open(voc_dict_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line == '':
                continue
            word, index = line.split(",")
            voc_dict[word] = int(index)
    return voc_dict


def load_data(data_path, data_stop_path,isTest):
    data_list = open(data_path, encoding='utf-8').readlines()[1:]
    stops_word = open(data_stop_path, encoding='utf-8').readlines()
    stops_word = [line.strip() for line in stops_word]
    stops_word.append(" ")
    stops_word.append("\n")

    voc_dict = {}
    data = []
    max_len_seq = 0
    for item in data_list:
        label = item[0]
        content = item[2:].strip()
        seg_list = jieba.cut(content, cut_all=False)

        seg_res = []
        for seg_item in seg_list:
            if seg_item in stops_word:
                continue
            seg_res.append(seg_item)
            if seg_item in voc_dict.keys():
                voc_dict[seg_item] += 1
            else:
                voc_dict[seg_item] = 1
        if len(seg_res) > max_len_seq:
            max_len_seq = len(seg_res)
        if isTest:
            data.append([label, seg_res,content])
        else:
            data.append([label, seg_res])
    return data, max_len_seq


class text_ClS(Dataset):
    def __init__(self, data_path, data_stop_path,voc_dict_path,isTest=False):
        self.isTest = isTest
        self.data_path = data_path
        self.data_stop_path = data_stop_path
        self.voc_dict = read_dict(voc_dict_path)
        self.data, self.max_len_seq = load_data(self.data_path, self.data_stop_path,isTest)
        np.random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        data = self.data[item]
        label = int(data[0])
        word_list = data[1]
        if self.isTest:
            content = data[2]
        input_idx = []
        for word in word_list:
            if word in self.voc_dict.keys():
                input_idx.append(self.voc_dict[word])
            else:
                input_idx.append(self.voc_dict["UNK"])
        if len(input_idx) < self.max_len_seq:
            input_idx += [self.voc_dict["PAD"] for _ in range(self.max_len_seq - len(input_idx))]
        data = np.array(input_idx)
        if self.isTest:
            return label,data,content
        else:
            return label, data

def data_loader(dataset,config):
    return DataLoader(dataset,batch_size=config.batch_size,shuffle=config.is_shuffle,num_workers=4,pin_memory=True)

models.py如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Model(nn.Module):
    def __init__(self,config):
        super(Model,self).__init__()
        self.embeding = nn.Embedding(config.n_vocab,config.embed_size,padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed_size,config.hidden_size,config.num_layers,batch_first=True,bidirectional=True,dropout=config.dropout)
        self.maxpool = nn.MaxPool1d(config.pad_size)
        self.fc = nn.Linear(config.hidden_size * 2 + config.embed_size,config.num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self,x):
        embed = self.embeding(x)
        out, _ = self.lstm(embed)
        out = torch.cat((embed, out), 2)
        out = F.relu(out)
        out = out.permute(0, 2, 1)
        out = self.maxpool(out).reshape(out.size()[0],-1)
        out = self.fc(out)
        out = self.softmax(out)
        return out

configs.py如下

import torch.types


class Config():
    def __init__(self):
        self.n_vocab = 1002
        self.embed_size = 256
        self.hidden_size = 256
        self.num_layers = 5
        self.dropout = 0.8
        self.num_classes = 2
        self.pad_size = 32
        self.batch_size = 32
        self.is_shuffle = True
        self.learning_rate = 0.001
        self.num_epochs = 100
        self.devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mp

if __name__ == '__main__':
    mp.freeze_support()
    cfg = Config()

    data_path = "./weibo_senti_100k.csv"
    data_stop_path = "./hit_stopwords.txt"
    dict_path = "./dict"

    dataset = text_ClS(data_path, data_stop_path, dict_path)
    train_dataloader = data_loader(dataset,cfg)

    cfg.pad_size = dataset.max_len_seq

    model_text_cls = Model(cfg)
    model_text_cls.to(cfg.devices)

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model_text_cls.parameters(), lr=cfg.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    for epoch in range(cfg.num_epochs):
        running_loss = 0
        correct = 0
        total = 0
        epoch_start_time = time.time()
        for i,(labels,datas) in enumerate(train_dataloader):
            datas = datas.to(cfg.devices)
            labels = labels.to(cfg.devices)

            pred = model_text_cls.forward(datas)
            loss_val = loss_func(pred,labels)
            running_loss += loss_val.item()
            loss_val.backward()
            if ((i + 1) % 4 == 0) or (i + 1 == len(train_dataloader)):
                optimizer.step()
                optimizer.zero_grad()
            _, predicted = torch.max(pred.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        scheduler.step()
        accuracy_train = 100 * correct / total
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        tain_loss = running_loss / len(train_dataloader)
        print("Epoch [{}/{}],Time: {:.4f}s,Loss: {:.4f},Acc: {:.2f}%".format(epoch + 1, cfg.num_epochs, epoch_time, tain_loss,accuracy_train))
        torch.save(model_text_cls.state_dict(),"./text_cls_model/text_cls_model{}.pth".format(epoch))

run_test.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mp

if __name__ == '__main__':
    mp.freeze_support()
    cfg = Config()
    data_path = "./test.csv"
    data_stop_path = "./hit_stopwords.txt"
    dict_path = "./dict"
    cfg.batch_size = 1
    dataset = text_ClS(data_path, data_stop_path, dict_path,isTest=True)
    dataloader = data_loader(dataset,cfg)

    cfg.pad_size = dataset.max_len_seq

    model_text_cls = Model(cfg)
    model_text_cls.load_state_dict(torch.load('./text_cls_model/text_cls_model0.pth'))
    model_text_cls.to(cfg.devices)
    classes_name = ['负面的','正面的']
    for i,(label,input,content) in enumerate(dataloader):
        label = label.to(cfg.devices)
        input = input.to(cfg.devices)
        pred = model_text_cls.forward(input)
        _, predicted = torch.max(pred.data, 1)
        print("内容:{}, 实际结果:{}, 预测结果:{}".format(content,classes_name[label],classes_name[predicted[0]]))

测试结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1152923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

可视化文件编辑与SSH传输神器WinSCP如何公网远程访问本地服务器

&#x1f525;博客主页&#xff1a; 小羊失眠啦. &#x1f3a5;系列专栏&#xff1a;《C语言》 《数据结构》 《Linux》《Cpolar》 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 可视化文件编辑与SSH传输神器WinSCP如何公网远程本地服务器 文章目录 可视化文件编辑与SSH传输神…

如何在 Photoshop 中使用污点修复画笔

学习污点修复画笔工具的基础知识&#xff0c;以及如何使用它来修复、平滑和删除图像中不需要的部分 1.如何在 Photoshop 中使用污点修复画笔 步骤1 在 Photoshop 中使用污点修复画笔的方法有很多。今天&#xff0c;让我们用它来去除这只手臂上的一些较小的纹身。 步骤2 在我…

在科技展厅设计中,如何通过空间规划来突出展品和主题?

数字多媒体技术在各行业内的广泛应用&#xff0c;使内容展览展示技术得到了更新&#xff0c;尤其是在科技展厅设计中&#xff0c;更是将各类多媒体互动装置的优势发挥到了极致&#xff0c;为观众提供现代化的感官体验&#xff0c;而这其中有效的空间规划对于现代化科技展厅的效…

【开发日记】必须记录一下困扰我两天的问题 MyBatisPlus适配达梦insert时提示:无效的列

【需求】 项目ORM框架使用的是MyBatisPlus&#xff0c;数据库原来使用的是MySQL&#xff0c;现在需要适配达梦。 【问题】 项目ORM框架使用的是MyBatisPlus&#xff0c;数据库原来使用的是MySQL&#xff0c;现在需要适配达梦数据库。 在适配过程中查询、更新、删除都没有问题…

LeetCode刷题---找出字符串中第一个匹配项的下标(Java实现KMP算法)

朴素算法 朴素算法是用来解决字符串匹配的问题的&#xff0c;现有主串aaaab和子串aab,如果使用朴素算法解决该问题&#xff0c;它首先会将主串的第一个字符和子串的第一个字符进行比较&#xff0c;如果主串和子串第一个字符相同&#xff0c;则比较第二个字符&#xff0c;依次往…

基于springboot鞋包商城-计算机毕设 附源码 28653

springboot鞋包商城 摘 要 鞋包商城采用B/S结构、java开发语言、以及Mysql数据库等技术。系统主要分为管理员和用户及卖家三部分&#xff0c;管理员管理主要功能包括&#xff1a;首页、网站管理&#xff08;轮播图、网站公告&#xff09;、人员管理&#xff08;管理员、卖家、…

KT6368A蓝牙芯片的4脚也就是蓝牙天线脚对地短路了呢?是不是坏了

一、问题简介 KT6368A芯片的4脚&#xff0c;也就是蓝牙天线脚&#xff0c;万用表测量对地短路了呢&#xff1f;是不是芯片坏掉了&#xff0c;能不能重新寄样品给我。 详细说明 首先&#xff0c;芯片没有坏&#xff0c;遇到自己不懂的地方&#xff0c;不要轻易的去怀疑。 而是…

ZKP7.3 Linear-time encodable code based on expanders

ZKP学习笔记 ZK-Learning MOOC课程笔记 Lecture 7: Polynomial Commitments Based on Error-correcting Codes (Yupeng Zhang) 7.3 Linear-time encodable code based on expanders SNARKs with linear prover time Linear-time encodable code [Spielman’96][Druk-Ishai…

会议邀请 | 思腾合力邀您共赴第二十一届中国电博会与元宇宙产业论坛

由国务院台湾事务办公室、江苏省人民政府主办的「第二十一届中国&#xff08;苏州&#xff09;电子信息博览会」将于2023年11月9日-11日在苏州国际博览中心举办。思腾合力作为行业领先的人工智能基础架构解决方案商&#xff0c;受邀参加本次盛会。思腾合力市场总监徐莉受邀出席…

构造最小堆、最小堆排序

堆是一种特殊的完全二叉树 堆具有以下方法 关键点&#xff1a; 插入&#xff1a;在 store 末端插入新元素&#xff0c;然后把新元素上浮。弹出&#xff1a;将 store 顶端&#xff08;索引为0处&#xff09;弹出&#xff0c;作为最小元素&#xff1b;把末端元素放到索引0处&a…

动态表单生成Demo(Vue+elment)

摘要&#xff1a;本文将介绍如何使用vue和elment ui组件库实现一个简单的动态表单生成的Demo。主要涉及两个.vue文件的书写&#xff0c;一个是动态表单生成的组件文件&#xff0c;一个是使用该动态表单生成的组件。 1.动态表单生成组件 这里仅集成了输入框、选择框、日期框三种…

WebAssembly完全入门——了解wasm的前世今身

前言 接触WebAssembly之后&#xff0c;在google上看了很多资料。感觉对WebAssembly的使用、介绍、意义都说的比较模糊和笼统。感觉看了之后收获没有达到预期&#xff0c;要么是文章中的例子自己去实操不能成功&#xff0c;要么就是不知所云、一脸蒙蔽。本着业务催生技术的态度&…

Linux 环境下 安装 Elasticsearch 7.13.2

Linux 环境下 安装 Elasticsearch 7.13.2 前言镜像下载&#xff08;国内镜像地址&#xff09;解压安装包修改配置文件用 Es 自带Jdk 运行配置 Es 可被远程访问然后启动接着启动本地测试一下能不能连 Es 前言 借公司的 centos 7 服务器&#xff0c;搭建一个 Es&#xff0c;正好熟…

8.3 矢量图层点要素单一符号使用五

文章目录 前言单一符号&#xff08;Single symbol&#xff09;渲染几何生成器&#xff08;Geometry generator&#xff09;QGis代码实现 总结 前言 上一篇教程介绍了矢量图层点要素单一符号中填充标记的用法本章继续介绍单一符号中各种标记的用法说明&#xff1a;文章中的示例…

AR的光学原理?

AR智能眼镜的光学成像系统 AR眼镜的光学成像系统由微型显示屏和光学镜片组成&#xff0c;可以将其理解为智能手机的屏幕。 增强现实&#xff0c;从本质上说&#xff0c;是将设备生成的影像与现实世界进行叠加融合。这种技术基本就是通过光学镜片组件对微型显示屏幕发出的光线…

java项目之中学校园网站(ssm框架)

项目简介 中学校园网站实现了以下功能&#xff1a; 管理员&#xff1a;个人中心、教师管理、学生管理、校园概况管理、名师风采管理、校园公告管理、试卷管理、试题管理、校园论坛、系统管理、考试管理。教师&#xff1a;个人中心、校园概况管理、名师风采管理、校园公告管理…

PM866 3BSE050200R1 L003748-AR 3BSX108237R300

PM866 3BSE050200R1 L003748-AR 3BSX108237R300 工业自动化制造商和工业物联网工具开发商Opto 22宣布推出新版groov&#xff0c;将IIoT technologies MQTT和OPC-UA驱动程序直接嵌入其工业边缘设备。新版本添加到用于web和移动可视化的groov View软件以及开源的Node-RED开发环境…

第三方支付预付卡业务详解

第三方支付预付卡业务详解 第三方支付预付卡业务是指由第三方支付公司提供的一种预先充值后消费的支付方式。用户可以在第三方支付平台上购买预付卡&#xff0c;然后在指定的商户或者服务提供商那里进行消费。 运作模式&#xff1a; 1. 用户在第三方支付平台购买预付卡&#xf…

Django项目单字段的区间查询

在Django项目中会碰到一些需求就是查询某个表中的一些字段从某日到某日的数据&#xff0c;而且是对但字段查询这个时候我们有两两种方法解决 单字段类型是DateTimeField的 查询日期范围的 这个时候在filter.py里面重写DateTimeFromToRangeFilter&#xff0c;为什么要重写呢&am…

kubernetes部署(web界面)

基本队对象 pod 最小单位 service 跟网络相关 Volume Namespace 准备工作&#xff1a; master node1 node2 修改主机名&#xff1a; 做本地解析 10.0.0.51 master 10.0.0.56 node-1 10.0.0.186 node-2 关闭swap分区&#xff1a; swapoff -a  临时关闭 …