基于 chinese-roberta-wwm-ext 微调训练 6 分类情感分析模型

news2024/10/6 12:31:13

一、模型和数据集介绍

1.1 预训练模型

chinese-roberta-wwm-ext 是基于 RoBERTa 架构下开发,其中 wwm 代表 Whole Word Masking,即对整个词进行掩码处理,通过这种方式,模型能够更好地理解上下文和语义关联,提高中文文本处理的准确性和效果。

与原始的 BERT 模型相比,chinese-roberta-wwm-ext 在训练数据规模和训练步数上做了一些调整,以进一步提升模型的性能和鲁棒性。并且在大规模无监督语料库上进行了预训练,使其具备强大的语言理解和生成能力。它能够广泛应用于各种自然语言处理任务,如文本分类、命名实体识别、情感分析等。我们可以使用这个模型作为基础,在不同的任务上进行微调和迁移学习,以实现更准确、高效的中文文本处理。

huggingface地址:https://huggingface.co/hfl/chinese-roberta-wwm-ext

进到 huggingface 中下载预训练模型:

在这里插入图片描述

1.2 数据集

数据集采用 SMP2020微博情绪分类评测 ,进入下面链接下载数据集:

https://smp2020ewect.github.io/

在这里插入图片描述

下载后数据分了三种类型,训练集、测试集、验证集

在这里插入图片描述

数据的格式为 JSON 格式,结构如下:

在这里插入图片描述

其中 label 为当前 content 的分类,一共分了 6 种类别,如下所示:

label说明
fear恐惧
neutral无情绪
sad悲伤
surprise惊奇
angry愤怒
happy积极

二、模型微调训练

2.1、处理数据集构建DataLoader

from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
import numpy as np
import torch
import json
import os


class GenDateSet():
    def __init__(self, tokenizer, train_file, val_file, label_dict, max_length=128, batch_size=10):
        self.train_file = train_file
        self.val_file = val_file
        self.max_length = max_length
        self.batch_size = batch_size
        self.label_dict = label_dict
        self.tokenizer = tokenizer

    def gen_data(self, file):
        if not os.path.exists(file):
            raise Exception("数据集不存在")
        input_ids = []
        input_types = []
        input_masks = []
        labels = []
        with open(file, encoding='utf8') as f:
            data = json.load(f)
        if not data:
            raise Exception("数据集不存在")

        # 处理数据
        for index, item in enumerate(data):
            text = item['content']
            label = item['label']
            tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length)
            input_id, types, masks = tokens['input_ids'], tokens['token_type_ids'], tokens['attention_mask']
            input_ids.append(input_id)
            input_types.append(types)
            input_masks.append(masks)
            y_ = self.label_dict[label]
            labels.append([y_])

            if index % 1000 == 0:
                print('处理', index, '条数据')

        # 构建 TensorDataset
        data_gen = TensorDataset(torch.LongTensor(np.array(input_ids)),
                                 torch.LongTensor(np.array(input_types)),
                                 torch.LongTensor(np.array(input_masks)),
                                 torch.LongTensor(np.array(labels)))
        # 打乱
        sampler = RandomSampler(data_gen)
        # 构建 DataLoader
        return DataLoader(data_gen, sampler=sampler, batch_size=self.batch_size)

    def gen_train_data(self):
        # 生成训练集
        return self.gen_data(self.train_file)

    def gen_val_data(self):
        # 生成验证集
        return self.gen_data(self.val_file)

2.2 构建模型迭代训练

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn as nn
from tqdm import tqdm
from gen_datasets import GenDateSet

# 标签结构
label_dict = {
    'fear': 0,
    'neutral': 1,
    'sad': 2,
    'surprise': 3,
    'angry': 4,
    'happy': 5
}

# 预训练模型位置
model_dir = 'D:\\AIGC\\model\\chinese-roberta-wwm-ext'
# 这里暂时使用测试集训练,数据较少
train_file = 'data/usual_train.txt'
# train_file = 'data/usual_test_labeled.txt'
# 验证集
val_file = 'data/virus_eval_labeled.txt'
# 训练模型存储位置
save_model_path = './model/'
# 最大长度
max_length = 128
# 分类数
num_classes = 6
batch_size = 10
epoch = 10

def val(model, device, data):
    model.eval()
    test_loss = 0.0
    acc = 0
    for (input_id, types, masks, y) in tqdm(data):
        input_id, types, masks, y = input_id.to(device), types.to(device), masks.to(device), y.to(device)
        with torch.no_grad():
            y_ = model(input_id, token_type_ids=types, attention_mask=masks)
            logits = y_['logits']
        test_loss += nn.functional.cross_entropy(logits, y.squeeze())
        pred = logits.max(-1, keepdim=True)[1]
        acc += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= len(data)
    return acc / len(data.dataset)


def main():
    # 加载 tokenizer 和  model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=num_classes)
    # 加载数据集
    dateset = GenDateSet(tokenizer, train_file, val_file, label_dict, max_length, batch_size)
    # 训练集
    train_data = dateset.gen_train_data()
    # 验证集
    val_data = dateset.gen_val_data()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    model = model.to(device)
    # 优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    best_acc = 0.0
    for epoch_index in range(epoch):
        batch_epoch = 0
        for (input_id, types, masks, y) in tqdm(train_data):
            input_id, types, masks, y = input_id.to(device), types.to(device), masks.to(device), y.to(device)
            # 前向传播
            outputs = model(input_id, token_type_ids=types, attention_mask=masks, labels=y)
            # 梯度清零
            optimizer.zero_grad()
            # 计算 loss
            loss = outputs.loss
            # 反向传播
            loss.backward()
            optimizer.step()
            batch_epoch += 1
            if batch_epoch % 10 == 0:
                print('Train Epoch:', epoch_index, ' , batch_epoch: ', batch_epoch, ' , loss = ', loss.item())

        # 评估准确度
        acc = val(model, device, val_data)
        print('Train Epoch:', epoch_index, ' val acc = ', acc)
        # 存储 best model
        if best_acc < acc:
            # torch.save(model.state_dict(), save_model_path)
            model.save_pretrained("./model")
            tokenizer.save_pretrained("./model")
            best_acc = acc

if __name__ == '__main__':
    main()

运行之后可以看到训练进度:

在这里插入图片描述

训练中可以看到验证集的准确率,等待训练结束后,可以在 model 下看到保存的模型和 tokenizer 文件:

在这里插入图片描述

三、模型测试

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

label_dict = {
    0: '恐惧',
    1: '无情绪',
    2: '悲伤',
    3: '惊奇',
    4: '愤怒',
    5: '积极'
}
model_dir = "./model"
num_classes = 6
max_length = 128


def main():
    # 加载预训练模型和分词器
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=num_classes)

    while True:
        text = input("请输入内容: \n ")
        if not text or text == "":
            continue
        if text == "q":
            break

        encoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=max_length)

        input_ids = torch.tensor([encoded_input['input_ids']])
        token_type_ids = torch.tensor([encoded_input['token_type_ids']])
        attention_mask = torch.tensor([encoded_input['attention_mask']])
        # 前向传播
        y_ = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        output = y_['logits'][0]
        pred = output.max(-1, keepdim=True)[1][0].item()
        print('预测结果:', label_dict[pred])

if __name__ == '__main__':
    main()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/772105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NAS 问题处理记录

在解决自动配网的过程中&#xff0c;突然NAS不给力&#xff0c;偏偏这个时间找事情。上面这两个问题&#xff0c;说不复杂也不复杂&#xff0c;主要是自己在完全远程处理&#xff0c;很多不方便。当然少不了师弟的助攻&#xff0c;很感谢我的师弟帮忙&#xff0c;实验室的网络不…

Flink 启动就报错,但exception没提示。其中一个task failure 该怎么办?

文章目录 前言一、排查二、解决 前言 最近我在生产又遇到一个问题&#xff0c;就是消费着一段时间之后&#xff0c;忽然就不再消费了&#xff0c;但也不报错。观察了几次&#xff0c;我发现时间基本是停留在上下班高峰期数据量最大的时候。我主观猜测可能是同时间进来的数据过…

css通过子元素选择父元素

伪类:has选择父元素 td:has(> .unfoldTable){//可选中所有td下包含unfoldTable的class标签的td属性color: red; }td:has(> div){//可选中所有td下包含div标签的td属性color: red; } 特殊举例分析&#xff1a; 个别UI框架个别标签通过事件直接生成或者无法选中的情况。…

爆肝整理,Postman接口测试-全局变量/接口关联/加密/解密(超细)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 全局变量和环境变…

随手记——前端安全策略 Content-Security-Policy – CSP

随手记——前端安全策略 Content-Security-Policy – CSP 一、问题 1. 问题&#xff1a;前端meta标签中配置了CSP安全策略&#xff0c;导致使用第三方地图插件的时间报错不展示 2. 原安全配置&#xff1a; <meta http-equiv"Content-Security-Policy" content&…

STM32MP157驱动开发——按键驱动(查询方式)

文章目录 概述APP 读取按键的 4 种方法查询方式休眠-唤醒方式poll 方式异步通知方式 查询方式的按键驱动程序&#xff08;框架&#xff09;按键驱动编写思路board_xxx.cbutton_drv.cbutton_drv.hbutton_test.cMakefile编译测试 查询方式的按键驱动程序(stm32mp157)board_stm32m…

常见Redis使用问题

一 lettuce使用问题 1 问题描述 Redis Cluster集群&#xff0c;当master宕机&#xff0c;主从切换&#xff0c;客户端报错 timed out 2 原因 SpringBoot2.X版本开始Redis默认的连接池都是采用的Lettuce。当节点发生改变后&#xff0c;Letture默认是不会刷新节点拓扑的。 3…

Spring+SpringMvc+Mybatis整合小Demo

原始方式整合SSM 不使用spring-mybatis包 项目内容 整合ssm完成对account表新增和查询的操作 项目大体结构 创建mavenWeb项目 pom文件中引入依赖 spring核心、aspectj(aop)、spring-jdbc(jdbcTemplate)、spring-tx(事务)、 数据源&#xff1a;mysql、c3p0、mybatis my…

linux之Ubuntu系列 find 、 ln 、 tar、apt-get 指令 软链接和硬链接

查找文件 find 命令 功能非常强大&#xff0c;通常用来在 特定的目录下 搜索 符合条件的文件 find [path] -name “.txt” 记得要加 “ ” 支持通配符 &#xff0c;正则表达式 包括子目录 ls 不包括 子目录 如果省略路径&#xff0c;表示 在当前路径下&#xff0c;搜索 软链接…

测试老鸟总结,性能测试-最佳并发和最大并发,性能测试实施...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 性能测试&#xf…

无涯教程-Javascript - Switch语句

从JavaScript 1.2开始&#xff0c;您可以使用 switch 语句来处理这种情况&#xff0c;它比重复的 if ... else if 语句更有效。 流程图 以下流程图说明了switch-case语句的工作原理。 switch 语句的目的是给出一个要求值的表达式&#xff0c;并根据表达式的值执行多个不同的语…

云曦暑期学习第一周——sql注入

1浅谈sql注入 1.1sql注入 sql注入是指web应用程序对用户输入数据的合法性没有判断&#xff0c;前端传入后端的参数是攻击者可控的&#xff0c;并且参数带入数据库查询&#xff0c;攻击者可以通过构造不同的sql语句来实现对数据库的任意操作 1.2原理 条件&#xff1a; 1.参…

接入端口与中继端口

交换机端口是支持 IT 的基本组件&#xff0c;可实现网络通信。这些有线硬件设备负责连接并允许在不同设备和连接到其端口的网络部分之间进行数据传输。由于网络管理员在确保网络连接和可用性方面发挥着关键作用&#xff0c;因此网络管理员必须清楚地了解、映射和查看其网络交换…

面向对象Java基础

前言 看大话设计模式的时候&#xff0c;发现自己的基础不是很扎实&#xff0c;重新回顾一些存在有点点不确定的内容&#xff0c;并从书中截取下来&#xff0c;做成笔记快速复习。 1、字段和属性 字段&#xff1a;用private修饰&#xff0c;也叫私有变量。属性&#xff1a;字…

2.Docker操作

文章目录 Docker操作Docker镜像操作搜索镜像获取镜像镜像加速下载查看镜像详细信息为镜像添加标签删除镜像导出导入镜像上传镜像 Docker容器操作创建容器查看容器状态启动容器创建并启动容器进入容器停止容器删除容器复制容器文件到宿主机容器的导出导入 Docker操作 ###查看do…

【天工Godwork精品教程】天工3.1.7安装教程(附Godwork完整版下载地址)

本文讲解天工3.1.7安装过程(附Godwork完整版网盘下载地址)。 文章目录 一、天工3.1.7安装教程1. 安装GodWork-AT 3.1.72. 安装GodWork-AT 3.1.7补丁3. 安装GodWork-EOS-Setup-2017B-12314. 安装GodWork-EOS补丁5. 运行godwokr软件6. 生成ZC码7. 输入ZC码8. eos插件调用二、天…

AtcoderABC245场

A - Good morningA - Good morning 题目大意 给定Takahashi和Aoki的起床时间&#xff0c;判断谁先起床。 思路分析 题目要求比较Takahashi和Aoki的起床时间。首先&#xff0c;将起床时间转换为以分钟为单位。然后&#xff0c;通过比较两者的起床时间来确定谁先起床。 时间复…

文献阅读笔记——求解车辆路径问题及其变体的元启发式算法的分类综述

论文题目&#xff1a;A taxonomic review of metaheuristic algorithms for solving the vehicle routing problem and its variants 其他信息&#xff1a;Computers & Industrial Engineering|2020|Raafat Elshaer⁎, Hadeer Awad 文章贡献&#xff1a;1&#xff09;对使…

RabbitMQ安装及简单使用

说明&#xff1a;RabbitMQ&#xff08;官网&#xff1a;&#xff09;是一门异步通讯技术&#xff0c;使用异步通讯技术&#xff0c;可解决同步通讯的一些问题。 安装 本文介绍在云服务器上安装RabbitMQ&#xff0c;操作系统是CentOS 7&#xff0c;远程连接工具是WindTerm&…

抖音seo源码部署/开源不加密可二开/抖音seo优化开发方案

一、前言 抖音是目前国内非常流行的短视频平台之一&#xff0c;用户数量庞大&#xff0c;更是吸引了许多企业和个人在上面开设账号&#xff0c;通过发布内容来进行流量变现。但是&#xff0c;在一个账号发布内容的同时&#xff0c;管理员又需要同时关注多个账号&#xff0c;对账…