pytorch集智4-情绪分类器

news2024/11/18 17:30:19

1 目标

从中文文本中识别出句子里的情绪。和上一章节单车预测回归问题相比,这个问题是分类问题,不是回归问题

2 神经网络分类器

2.1 如何用神经网络分类

第二章节用torch.nn.Sequantial做的回归预测器,输出神经元只有一个。分类器和其区别如下

1 分类器输出单元有几个,就有几个分类

2 分类器各输出单元输出值加和为1(值表示某个预测分类的概率,概率求和为1)

3 回归预测最后一层可以用sigmoid,分类不行,因为sigmoid无法保证分类器输出单元值求和为1,可以用softmax,表达式如下

2.2 分类问题损失函数

分类问题损失函数一般用交叉熵:

N表示样本数量,Ci表示第i个样本的类别标签,y表示每个样本的损失值

3 词袋模型分类器

直白理解:就是统计句子里各词语出现数量,忽略语法,语义等因素

思想:1 将句子转化为词袋,然后可以onehot处理,方便用神经网络预测;2 大概思想:在词袋里寻找和输出相关关键词语,然后统计词语数量,哪个输出类别的敏感词语统计数量高,结果就是哪个输出类别

数据预处理:可以归一化处理,网络计算会更快

3.1 简单文本分类器

背景与数据:获取网购商城顾客评论信息,从评论信息获取句子情绪

神经网络组成:用前馈神经网络,有3层,分别是输入,中间,输出层,输出层数量有两个,分别代表情绪是正面还是负面

3.2 程序实现

3.2.1 数据处理

主要步骤:处理标点,句子分词,建立词表

标点处理:正则,涉及的标点全部删除

句子分词:使用jieba模块对句子分词

建立词表:python字典建立词表


def deal_punc(sentence):
    sentence = re.sub('[\s+\.\!\/_,$%^*(+\"\'“”《》?]+|[+——!,。?、~@#¥%……&*():)]+', '', sentence)
    return sentence


def prepare_data(good_file, bad_file, is_filter=True):
    print('prepare data begin')
    all_words = []
    pos_sen, neg_sen = [], []
    for path, sen in zip((good_file, bad_file), (pos_sen, neg_sen)):
        with open(path, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                if is_filter:
                    line = deal_punc(line)
                
                words = jieba.lcut(line)
                if len(words) > 0:
                    all_words += words
                    sen.append(words)
            print(f'{path} include {index} rows, all words:{len(all_words)}')
    print(f'pos_sen len:{len(pos_sen)}, neg_sen len:{len(neg_sen)}')
            
    diction = {}
    cnt = Counter(all_words)
    for word, freq in cnt.items():
        diction[word] = [len(diction), freq]
    print(f'diction len:{len(diction)}')
    return (pos_sen, neg_sen, diction)

def word2index(word, diction):
    if word in diction:
        value = diction[word][0]
    else:
        value = -1
    return (value)

def index2word(index, diction):
    for w, v in diction.items():
        if v[0] == index:
            return (w)
    return (None)

3.2.2 文本数据向量化

    pos_sen, neg_sen, diction = prepare_data_sample(good_file, bad_file)
    st = sorted([(v[1], w) for w, v in diction.items()])
    datasets, labels, sentences = [], [], []
    '''
    for sens, tag in zip((pos_sen, neg_sen), (0, 1)):
        for sen in sens:
            new_sen = []
            for l in sen:
                if l in diction:
                    new_sen.append(word2index_sample(1, diction))
            datasets.append(sentence2vec_sample(new_sen, diction))
            labels.append(tag)
            sentences.append(sen)
    '''
    for sentence in pos_sen:
        new_sentence = []
        for l in sentence:
            if l in diction:
                new_sentence.append(word2index(l, diction))
        datasets.append(sentence2vec_sample(new_sentence, diction))
        labels.append(0) #正标签为0
        sentences.append(sentence)

    # 处理负向评论
    for sentence in neg_sen:
        new_sentence = []
        for l in sentence:
            if l in diction:
                new_sentence.append(word2index(l, diction))
        datasets.append(sentence2vec_sample(new_sentence, diction))
        labels.append(1) #负标签为1
        sentences.append(sentence)
    indices = np.random.permutation(len(datasets))
    datasets = [datasets[i] for i in indices]
    labels = [labels[i] for i in indices]
    sentences = [sentences[i] for i in indices]

3.3.3 划分数据集

训练集,校验集,测试集,比例一般为10:1:1

    # split train, test, verify datasets
    test_size = int(len(datasets) // 10)
    train_data = datasets[2 * test_size :]
    train_label = labels[2 * test_size : ]
    
    valid_data = datasets[: test_size]
    valid_label = labels[: test_size]
    
    test_data = datasets[test_size : 2 * test_size]
    test_label = labels[test_size : 2 * test_size]

3.3.4 建立神经网络

结构:输入层(约7000个样本),一个中间层(10个隐单元),输出层(2个单元)

注意,此处用的是relu,不是sigmoid。原因:虽然x>0时没处理数据,但x<0时为0让relu有了非线性的能力,且运算比sigmoid简单,速度更快,方便反向误差传导


def plot(records):
    print('plot begin')
    a = [i[0] for i in records]
    b = [i[1] for i in records]
    c = [i[2] for i in records]
    pyplot.plot(a, label='train loss')
    pyplot.plot(b, label='verify loss')
    pyplot.plot(c, label='valid accuracy')
    pyplot.xlabel('step')
    pyplot.ylabel('losses & accuracy')
    pyplot.legend()
    pyplot.show()


def main():
    model = torch.nn.Sequential(
        torch.nn.Linear(len(diction), 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.LogSoftmax(dim=1),
    )
    cost = torch.nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    records = []
    losses = []
    for epoch in range(3):
        for i, data in enumerate(zip(train_data, train_label)):
            x, y = data
            x = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)
            y = torch.tensor(np.array([y]), dtype=torch.long)
            optimizer.zero_grad()
            predict = model(x)
            loss = cost(predict, y)
            losses.append(loss.data.numpy())
            loss.backward()
            optimizer.step()

        val_losses = []
        rights = []
        for j, val in enumerate(zip(valid_data, valid_label)):
            x, y = val
            x = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)
            y = torch.tensor(np.array([y]), dtype = torch.long)
            predict = model(x)
            right = rightness_sample(predict, y)
            rights.append(right)
            loss = cost(predict, y)
            val_losses.append(loss.data.numpy())
        right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])
        print(f'No.{epoch}, train loss:{np.mean(losses):.4f}, verify loss:{np.mean(val_losses):.4f}, verify accuracy:{right_ratio}')
        records.append([np.mean(losses), np.mean(val_losses), right_ratio])
    # plot
    plot(records)

3.3 运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1384855.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安装nodejs出现问题

Error: EPERM: operation not permitted, mkdir… 全局安装express模块进行测试时&#xff1a; npm install express -g出现&#xff1a; 表示nodejs的安装目录无权限&#xff0c;根据错误日志的信息&#xff0c;定位到安装目录下&#xff1a; 点击属性&#xff1a; 点击编…

【江科大STM32合集】day2按键控制LED光敏传感器控制峰鸣器

【STM32合集】day2按键控制LED&光敏传感器控制峰鸣器 电路基础c语言基础main.ckey.c结果 实现一个键开关灯实验结果避坑 电路基础 运算放大器-在江科大51单片机b站视频&#xff08;AD/DA&#xff09;复习 原理&#xff1a;两个极端 同相输入端电压 》反相输入端 电压输出最…

基于RTOS(实时操作系统)的CMT液晶屏控制器驱动程序开发与实现

RTOS&#xff08;实时操作系统&#xff09;提供了一种有效的方式来管理和调度多任务系统&#xff0c;对于液晶屏控制器的驱动程序开发来说&#xff0c;RTOS能够提供良好的实时性和可靠性。本文以RTOS为基础&#xff0c;设计并实现了一个用于控制CMT液晶屏的驱动程序。在设计过程…

【android】rk3588-android-bt

文章目录 蓝牙框架HCI接口蓝牙VENDORLIBvendorlib是什么 代码层面解读vendorlib1、 vendorlib实现&#xff0c;协议栈调用2、协议栈实现&#xff0c;vendorlib调用&#xff08;回调函数&#xff09;2.1、 init函数2.2、BT_VND_OP_POWER_CTRL对应处理2.3、BT_VND_OP_USERIAL_OPE…

【LV13 DAY16 轮询与中断】

轮询实现按键实验 #include "exynos_4412.h"int main() {//GPX1_1设置为输入模式//GPX1.CONGPX1.CON & (~ (0XF<<4));while(1){if(!(GPX1.DAT&(1<<1))){printf("key pressed\n");while(!(GPX1.DAT&(1<<1)));}else{}}return…

i18n多国语言Internationalization的实现

i18n 是"Internationalization”的缩写&#xff0c;这个术语来源于英文单词中首尾字母“”和“n”以及中间的字符数(共计18个字符) 当我们需要开发不同语言版本时&#xff0c;就可以使用i18n多国语言的一个操作处理&#xff0c;i18n主要实现那一方面的内容呢&#xff1f;…

蓝桥杯省赛无忧 STL 课件17 map

01 map 02 multimap 03 unordered_map 04 代码示例

训练DAMO-YOLO(damoyolo_tinynasL25_S.py)

文章目录 参考链接1 准备数据1.1 转为COCO格式1.2 指明数据路径 2 设置训练配置文件&#xff0c;在configs/damoyolo_tinynasL25_S.py进行如下两块修改2.1 关于训练参数的设置2.2 根据自己数据集设置 3 开始训练4 调用tools/eval.py进行测试5 训练时可能遇到的报错5.1 RuntimeE…

Java生成四位数随机验证码

引言&#xff1a; 我们生活中登录的时候都要输入验证码&#xff0c;这些验证码是为了增加注册或者登录难度&#xff0c;减少被人用脚本疯狂登录注册导致的一系列危害&#xff0c;减少数据库的一些压力。 毕竟那些用脚本生成的账号都是垃圾账号 本次实践&#xff1a;生成这样的…

Spring Boot - 利用Resilience4j-RateLimiter进行流量控制和服务降级

文章目录 Resilience4j概述Resilience4j官方地址Resilience4j-RateLimiter微服务演示Payment processorPOM配置文件ServiceController Payment servicePOMModelServiceRestConfigController配置验证 探究 Rate Limiting请求三次 &#xff0c;观察等待15秒连续访问6次 Resilienc…

mysql原理--undo日志1

1.事务回滚的需求 我们说过 事务 需要保证 原子性 &#xff0c;也就是事务中的操作要么全部完成&#xff0c;要么什么也不做。但是偏偏有时候事务执行到一半会出现一些情况&#xff0c;比如&#xff1a; (1). 事务执行过程中可能遇到各种错误&#xff0c;比如服务器本身的错误&…

前端对接电子秤、扫码枪设备serialPort 串口使用教程

因为最近工作项目中用到了电子秤&#xff0c;需要对接电子秤设备。以前也没有对接过这种设备&#xff0c;当时也是一脸懵逼&#xff0c;脑袋空空。后来就去网上搜了一下前端怎么对接&#xff0c;然后就发现了SerialPort串口。 Serialport 官网地址&#xff1a;https://serialpo…

软件工程:黑盒测试等价分类法相关知识和多实例分析

目录 一、黑盒测试和等价分类法 1. 黑盒测试 2. 等价分类法 二、黑盒测试等价分类法实例分析 1. 工厂招工年龄测试 2. 规定电话号码测试 3. 八位微机测试 4. 三角形判断测试 一、黑盒测试和等价分类法 1. 黑盒测试 黑盒测试就是根据被测试程序功能来进行测试&#xf…

通过开源端点可见性改善网络安全响应

在当今复杂的数字环境中&#xff0c;企业内的许多不同端点&#xff08;从数据中心的服务器到咖啡店的笔记本电脑&#xff09;创建了巨大且多样化的攻击面。每个设备都存在网络安全威胁的机会&#xff0c;每个设备都有其独特的特征和复杂性。攻击者使用的多种攻击媒介不仅是一个…

【占用网络】FlashOcc:快速、易部署的占用预测模型

前言 FlashOcc是一个它只需2D卷积就能实现“占用预测模型”&#xff0c;具有快速、节约内存、易部署的特点。 它首先采用2D卷积提取图形信息&#xff0c;生成BEV特征。然后通过通道到高度变换&#xff0c;将BEV特征提升到3D空间特征。 对于常规的占用预测模型&#xff0c;将…

web前端算法简介之字典与哈希表

回顾 栈、队列 &#xff1a; 进、出 栈&#xff08;Stack&#xff09;&#xff1a; 栈的操作主要包括&#xff1a; 队列&#xff08;Queue&#xff09;&#xff1a; 队列的操作主要包括&#xff1a; 链表、数组 &#xff1a; 多个元素存储组成的 简述链表&#xff1a;数组&…

阶段十-分布式锁

5.1 节 为什么要使用分布式锁 锁是多线程代码中的概念&#xff0c;只有当多任务访问同一个互斥的共享资源时才需要。如下图&#xff1a; 在我们进行单机应用开发&#xff0c;涉及并发同步的时候&#xff0c;我们往往采用synchronized或者lock的方式来解决多线程间的代码同步问…

分布式任务调度平台XXL-JOB使用(二)

说明&#xff1a;之前总结过在CentOS系统里使用XXL-JOB。但在代码开发阶段&#xff0c;我们有时需要在本地环境测试定时任务代码&#xff0c;本文介绍如何在Windows系统使用XXL-JOB。 下载 &#xff08;1&#xff09;下载代码&#xff0c;解压打开 首先&#xff0c;去Github…

权责发生制和收付实现制

目录 一. 权责发生制(应记制)二. 收付实现制 \quad 一. 权责发生制(应记制) 应计制就是应该记入的意思 各项收入和费用的确认应当以“实际发生”&#xff08;归属期&#xff09;而不是以款项的实际收付作为记账的基础。 正是有会计期间假设&#xff0c;才有权责发生制和收付实…

逸学Docker【java工程师基础】3.1安装Jenkins

1.下载镜像 docker pull jenkins/jenkins:lts 2.运行容器 docker run -d -u root -p 8080:8080 -p 50000:50000 -v /var/jenkins_home:/var/jenkins_home -v /etc/localtime:/etc/localtime --name jenkins jenkins/jenkins:lts 3.要启动名为 jenkins 的 Docker 容器 docker st…