NLP 算法实战项目:使用 BERT 进行模型微调,进行文本情感分析

news2024/9/28 9:24:36

本篇我们使用公开的微博数据集(weibo_senti_100k)进行训练,此数据集已经进行标注,0: 负面情绪,1:正面情绪。数据集共计82718条(包含标题)。如下图:

图片

下面我们使用bert-base-chinese预训练模型进行微调并进行测试。 技术交流,文末获取。

1. 导入必要的库

import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset, random_split
import pandas as pd
from tqdm import tqdm
import random

2. 加载数据集和预训练模型

# 读取训练数据集
df = pd.read_csv("weibo_senti_100k.csv")  # 替换为你的训练数据集路径
# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')

3. 对数据集进行预处理

注意:此处需要打乱数据行,为了快速训练展示,下面程序只加载了1500条数据。

# 设置随机种子以确保可重复性
random.seed(42)
# 随机打乱数据行
df = df.sample(frac=1).reset_index(drop=True)
# 数据集中1为正面,0为反面
class SentimentDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=128):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['review']
        label = self.dataframe.iloc[idx]['label']
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# 创建数据集对象
dataset = SentimentDataset(df[:1500], tokenizer)

4. 将数据集分为训练集、验证集

# 创建数据集对象
dataset = SentimentDataset(df[:1500], tokenizer)

# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

5. 设置训练参数

# 设置训练参数
optimizer = AdamW(model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

6. 训练模型

# 训练模型
model.train()
for epoch in range(3):  # 3个epoch作为示例
    for batch in tqdm(train_loader, desc="Epoch {}".format(epoch + 1)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
# 输出
Epoch 1: 100%|██████████| 150/150 [00:28<00:00,  5.28it/s]
Epoch 2: 100%|██████████| 150/150 [00:29<00:00,  5.15it/s]
Epoch 3: 100%|██████████| 150/150 [00:27<00:00,  5.36it/s]

7. 评估模型

# 评估模型
model.eval()
total_eval_accuracy = 0
for batch in tqdm(val_loader, desc="Evaluating"):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    
    logits = outputs.logits
    preds = torch.argmax(logits, dim=1)
    accuracy = (preds == labels).float().mean()
    total_eval_accuracy += accuracy.item()

average_eval_accuracy = total_eval_accuracy / len(val_loader)
print("Validation Accuracy:", average_eval_accuracy)
# 输出
Evaluating: 100%|██████████| 38/38 [00:02<00:00, 16.57it/s]Validation Accuracy: 0.9407894736842105

8. 进行预测

# 使用微调后的模型进行预测
def predict_sentiment(sentence):
    inputs = tokenizer(sentence, padding='max_length', truncation=True, max_length=128, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=1)
    positive_prob = probs[0][1].item()  # 1表示正面
    print("Positive Probability:", positive_prob)

# 测试一个句子
predict_sentiment("我要发火了")
# 输出
Positive Probability: 0.19748596847057343

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了NLP技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 用通俗易懂的方式讲解:面试字节大模型算法岗(实习)
  • 用通俗易懂的方式讲解:大模型算法岗(含实习)最走心的总结
  • 用通俗易懂的方式讲解:大模型微调方法汇总

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1506841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Gauge)

数据量规图表组件&#xff0c;用于将数据展示为环形图表。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 可以包含单个子组件。 说明&#xff1a; 建议使用文本组件构建当前数值文本和辅…

信息系统项目管理师005:工业互联网(1信息化发展—1.2现代化基础设施—1.2.2工业互联网)

文章目录 1.2.2 工业互联网1.内涵和外延2.平台体系3.融合应用 记忆要点总结 1.2.2 工业互联网 工业互联网(Industrial Internet)是新一代信息通信技术与工业经济深度融合的新型基础设施、应用模式和工业生态&#xff0c;通过对人、机、物、系统等的全面连接&#xff0c;构建起覆…

【EDK II】作为UEFI的实现,EDK II 的架构是什么样的

目录 前言 EDK II 架构 配置文件 结语 前言 基本输入输出系统 (Basic Input Output System, BIOS) 最早由 IBM&#xff08;International Business Machines Corporation) 公司于1981年提出并开发&#xff0c;后来成为个人计算机(PC)的标准固件接口。但受限于传统BIOS (Le…

Git分支管理(IDEA)

文章目录 Git分支管理&#xff08;IDEA&#xff09;1.Git分支管理&#xff08;IDEA&#xff09;1.基本介绍1.分支理解2.示意图 2.搭建分支和合并的环境1.创建Gitee仓库2.创建普通maven项目3.克隆Gitee项目到E:\GiteeRepository4.复制erp文件夹下的内容到IDEA项目下5.IDEA项目中…

Kafka的分区机制

Kafka的分区机制是其核心功能之一&#xff0c;旨在提高可扩展性和并行处理能力。下面概述了Kafka分区的基本概念和工作原理&#xff1a; Kafka分区基本概念 分区&#xff08;Partition&#xff09;&#xff1a;Kafka中的主题&#xff08;Topic&#xff09;可以细分为多个分区…

软件测试APP完整测试作业流程(附流程图),公司级软件测试流程化办公

目录 1. 概述 2. 软件测试流程 3. 软件测试周期人员活动图 4. 总结 1. 概述 1.1 目的 有效的保证软件质量&#xff1b; 有效的制定不同测试类型&#xff08;软件系统测试、音频主观性测试、Field Trial、专项测试、自动化测试、性 能测试、用户体验测试&#xff09;的软件…

【HarmonyOS】ArkUI - 自定义卡片样式

ArkUI - 自定义卡片样式 HarmonyOS API 9 没有提供原生的卡片样式&#xff0c;我定义了一个卡片样式&#xff0c;可以方便大家在日常开发中使用。 效果图&#xff1a; 卡片样式代码如下&#xff1a; Styles function card() {.width(95%).padding(20).backgroundColor(Col…

【CSP】2022-03-2 出行计划 经典差分和前缀和 (包含完整思路、代码和写代码过程中遇到的问题)

2022-03-2 出行计划 差分和前缀和 2022-03-2 出行计划 差分和前缀和思路遇到的问题&#xff08;不小心出现的细节问题&#xff09;完整代码 2022-03-2 出行计划 差分和前缀和 这题很久之前做过一次&#xff0c;现在已经基本忘记了&#xff0c;所以重新做一遍&#xff0c;然后一…

Linux动态追踪——ftrace

目录 摘要 1 初识 1.1 tracefs 1.2 文件描述 2 函数跟踪 2.1 函数的调用栈 2.2 函数调用栈 2.3 函数的子调用 3 事件跟踪 4 简化命令行工具 5 总结 摘要 Linux下有多种动态追踪的机制&#xff0c;常用的有 ftrace、perf、eBPF 等&#xff0c;每种机制适应于不同的场…

银河麒麟V10 安装部署大数据管理软件 DataSophon

一、概览 1、愿景 致力于快速实现部署、管理、监控以及自动化运维大数据云原生平台&#xff0c;帮助您快速构建起稳定、高效、可弹性伸缩的大数据云原生平台。 2、DataSophon是什么 《三体》&#xff0c;这部获世界科幻文学最高奖项雨果奖的作品以惊艳的"硬科幻"…

Jmeter+Ant+Git/SVN+Jenkins实现持续集成接口测试,一文精通(一)

前言 Jmeter&#xff0c;Postman一些基本大家相比都懂。那么真实在项目中去使用&#xff0c;又是如何使用的呢&#xff1f;本文将一文详解jmeter接口测试 一、接口测试分类 二、目前接口架构设计 三、市面上的接口测试工具 四、Jmeter简介&#xff0c;安装&#xff0c;环境…

【Kafka系列 08】生产者消息分区机制详解

一、前言 我们在使用 Apache Kafka 生产和消费消息的时候&#xff0c;肯定是希望能够将数据均匀地分配到所有服务器上。 比如很多公司使用 Kafka 收集应用服务器的日志数据&#xff0c;这种数据都是很多的&#xff0c;特别是对于那种大批量机器组成的集群环境&#xff0c;每分…

Visio无空白无黑边导出PDF

步骤1 文件->选项->自定义功能区->勾选开发工具 步骤2 开发工具->显示ShapeSheet->页->将Print Properties中的Margin都设置为0 步骤3 设计->大小->适应绘图 步骤4 导出PDF->选项->取消勾选【辅助功能文档结构标记】->发布

BetterDisplay for mac V2.2.5 强大的mac显示器管理开源工具

BetterDisplay是Mac OS 一个很棒的工具&#xff01; 它允许您将显示器转换为完全可扩展的屏幕 管理显示器配置覆盖 允许亮度和颜色控制 提供 XDR/HDR 亮度升级&#xff08;Apple Silicon 和 Intel Mac 上兼容的 XDR 或 HDR 显示器的额外亮度超过 100% - 多种方法可用&#x…

opengl 学习(三)-----纹理

纹理就是贴图 分类前提demo效果解析 分类 前提 需要使用一个库来处理图片&#xff1a;#include <stb_image.h> https://github.com/nothings/stb 你下载好了之后&#xff0c;把目目录包含了就好 然后再引入 #define STB_IMAGE_IMPLEMENTATION #include "stb_i…

鸿蒙开发学习:【ets_frontend组件】

简介 ets_frontend组件是方舟运行时子系统的前端工具&#xff0c;结合ace-ets2bundle组件&#xff0c;支持将ets文件转换为方舟字节码文件。 ets_frontend组件架构图 目录 /arkcompiler/ets_frontend/ ├── test262 # test262测试配置和运行脚本 ├── testTs…

如何轻松打造属于自己的水印相机小程序?

水印相机小程序源码 描述&#xff1a;微信小程序。本文将为您详细介绍小程序水印相机源码的搭建过程&#xff0c;教您如何轻松打造属于自己的水印相机小程序。无论您是初学者还是有一定基础的开发者&#xff0c;都能轻松掌握这个教程。 一&#xff1a;水印相机搭建教程 1 隐…

Node.js:构建高性能网络应用的利器

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

IDEA + Git + GitHub(保姆级教学)

文章目录 IDEA Git GitHub1.IDEA克隆远程仓库到本地仓库1.创建一个GitHub远程仓库test12.IDEA克隆仓库到本地1.复制远程仓库地址2.创建一个版本控制项目3.克隆到本地仓库4.克隆成功 2.IDEA将本地项目push到远程仓库1.在这个项目下新建一个java模块1.新建模块2.填写模块名3.在…

【机器学习】一文掌握逻辑回归全部核心点(上)。

逻辑回归核心点-上 1、引言2、逻辑回归核心点2.1 定义与目的2.2 模型原理2.2.1 定义解析2.2.2 公式2.2.3 代码示例 2.3 损失函数与优化2.3.1 定义解析2.3.2 公式2.3.3 代码示例 2.4 正则化2.4.1 分类2.4.2 L1正则化2.4.3 L2正则化2.4.4 代码示例 3、总结 1、引言 小屌丝&#…