3 文本分类入门finetune:bert-base-chinese

news2025/1/11 7:11:00

项目实战:

数据准备工作

        `bert-base-chinese` 是一种预训练的语言模型,基于 BERT(Bidirectional Encoder Representations from Transformers)架构,专门用于中文自然语言处理任务。BERT 是由 Google 在 2018 年提出的一种革命性的预训练模型,通过大规模的无监督训练,能够学习到丰富的语言表示。

        `bert-base-chinese` 是 BERT 在中文语料上进行预训练的版本,它包含了 12 层 Transformer 编码器和 110 万个参数。这个模型在中文文本上进行了大规模的预训练,可以用于各种中文自然语言处理任务,如文本分类、命名实体识别、情感分析等。

使用 `bert-base-chinese` 模型时,可以将其作为一个特征提取器,将输入的文本转换为固定长度的向量表示,然后将这些向量输入到其他机器学习模型中进行训练或推断。也可以对 `bert-base-chinese` 进行微调,将其用于特定任务的训练。

        预训练的 `bert-base-chinese` 模型可以通过 Hugging Face 的 Transformers 库进行加载和使用。在加载模型后,可以使用它的 `encode` 方法将文本转换为向量表示,或者使用 `forward` 方法对文本进行特定任务的预测。

        需要注意的是,`bert-base-chinese` 是一个通用的中文语言模型,但它可能在特定的任务上表现不佳。在某些情况下,可能需要使用更大的模型或进行微调来获得更好的性能。

进行微调时,可以按照以下步骤进行操作:

1. 准备数据集:首先,你需要准备一个与你的任务相关的标注数据集。这个数据集应该包含输入文本以及相应的标签或注释,用于训练和评估模型。

2. 加载预训练模型:使用 Hugging Face 的 Transformers 库加载预训练的 `bert-base-chinese` 模型。你可以选择加载整个模型或只加载其中的一部分,具体取决于你的任务需求。

3. 创建模型架构:根据你的任务需求,创建一个适当的模型架构。这通常包括在 `bert-base-chinese` 模型之上添加一些额外的层,用于适应特定的任务。

4. 数据预处理:将你的数据集转换为适合模型输入的格式。这可能包括将文本转换为输入的编码表示,进行分词、填充和截断等操作。

5. 定义损失函数和优化器:选择适当的损失函数来衡量模型预测与真实标签之间的差异,并选择合适的优化器来更新模型的参数。

6. 微调模型:使用训练集对模型进行训练。在每个训练步骤中,将输入文本提供给模型,计算损失并进行反向传播,然后使用优化器更新模型的参数。

7. 评估模型:使用验证集或测试集评估模型的性能。可以计算准确率、精确率、召回率等指标来评估模型在任务上的表现。

8. 调整和优化:根据评估结果,对模型进行调整和优化。你可以尝试不同的超参数设置、模型架构或训练策略,以获得更好的性能。

9. 推断和应用:在微调完成后,你可以使用微调后的模型进行推断和应用。将新的输入文本提供给模型,获取预测结果,并根据任务需求进行后续处理。

需要注意的是,微调的过程可能需要大量的计算资源和时间,并且需要对模型和数据进行仔细的调整和优化。此外,合适的数据集规模和质量对于获得良好的微调结果也非常重要。

准备模型

https://huggingface.co/bert-base-chinese/tree/main

数据集

与之前的训练数据一样使用;

代码部分

# 导入transformers
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup

# 导入torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
 
# 常用包
import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from tqdm import tqdm
 
%matplotlib inline
%config InlineBackend.figure_format='retina' # 主题
 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device


PRE_TRAINED_MODEL_NAME = '../bert-base-chinese/' # 英文bert预训练模型
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer

df = pd.read_csv("../data/online_shopping_10_cats.csv")
#myDataset[0]
RANDOM_SEED = 1012
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape



class MyDataSet(Dataset):
    def __init__(self,texts,labels,tokenizer,max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
    def __len__(self):
        return len(self.texts)
    def __getitem__(self,item):
        text = str(self.texts[item])
        label = self.labels[item]
 
        encoding = self.tokenizer(text=text,
                      max_length=self.max_len,
                      pad_to_max_length=True,
                      add_special_tokens=True,
                      return_attention_mask=True,
                      return_token_type_ids=True,
                      return_tensors='pt')
        return {"text":text,
               "input_ids":encoding['token_type_ids'].flatten(),
               "attention_mask":encoding['attention_mask'].flatten(),
               "labels":torch.tensor(label,dtype=torch.long)}
def create_data_loader(df,tokenizer,max_len,batch_size=4):
    ds = MyDataSet(texts=df['review'].values,
                  labels=df['label'].values,
                   tokenizer = tokenizer,
                   max_len=max_len
                  )
    return DataLoader(ds,batch_size=batch_size)
MAX_LEN = 512
BATCH_SIZE = 32
train_data_loader = create_data_loader(df_train,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
val_data_loader = create_data_loader(df_val,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
test_data_loader = create_data_loader(df_test,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)

class BaseBertModel(nn.Module):
    def __init__(self,n_class=2):
        super(BaseBertModel,self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(0.2)
        self.out = nn.Linear(self.bert.config.hidden_size, n_class)
        pass
    def forward(self,input_ids,attention_mask):
        _,pooled_output = self.bert(input_ids=input_ids,
                 attention_mask=attention_mask,
                 return_dict = False)
        out = self.drop(pooled_output)
        return self.out(out)
 
#pooled_output
model = BaseBertModel()
model = model.to(device)

EPOCHS = 10 # 训练轮数
optimizer = AdamW(model.parameters(),lr=2e-4,correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=0,
                               num_training_steps=total_steps)
loss_fn = nn.CrossEntropyLoss().to(device)

def train_epoch(model,data_loader,loss_fn,optimizer,device,schedule,n_examples):
    model = model.train()
    losses = []
    correct_predictions = 0
    count = 0
    for d in tqdm(data_loader):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        targets = d['labels'].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _,preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs,targets)
        losses.append(loss.item())
        
        correct_predictions += torch.sum(preds==targets)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        count += 1
        # if count > 100:
        #     break
        #break
    return correct_predictions.double() / n_examples, np.mean(losses)
 
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval() # 验证预测模式
 
    losses = []
    correct_predictions = 0
 
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)
 
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
 
            loss = loss_fn(outputs, targets)
 
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
 
    return correct_predictions.double() / n_examples, np.mean(losses)
 
 
# train model
history = defaultdict(list) # 记录10轮loss和acc
best_accuracy = 0
 
for epoch in range(EPOCHS):
 
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
 #(model,data_loader,loss_fn,device,schedule,n_exmaples)
    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )
 
    print(f'Train loss {train_loss} accuracy {train_acc}')
 
    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )
 
    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()
 
    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)
 
    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
 
 
# 模型评估
test_acc, _ = eval_model(
  model,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
test_acc.item()
 
 
def get_predictions(model, data_loader):
    model = model.eval()
 
    texts = []
    predictions = []
    prediction_probs = []
    real_values = []
 
    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)
 
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
 
            probs = F.softmax(outputs, dim=1)
 
            texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)
 
    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return texts, predictions, prediction_probs, real_values
 
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))
 
# 模型预测
 
sample_text='Hard but Robust, Easy but Sensitive: How Encod.'
encoded_text = tokenizer.encode_plus(
  sample_text,
  max_length=MAX_LEN,
  add_special_tokens=True,
  return_token_type_ids=False,
  pad_to_max_length=True,
  return_attention_mask=True,
  return_tensors='pt',
)
 
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)
 
output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)
 
print(f'Sample text: {sample_text}')
print(f'Danger label  : {label_id2cate[prediction.cpu().numpy()[0]]}')

runing....

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1292329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode—1466.重新规划路线【中等】

2023每日刷题&#xff08;五十二&#xff09; Leetcode—1466.重新规划路线 算法思想 实现代码 class Solution { public:int minReorder(int n, vector<vector<int>>& connections) {vector<pair<int, int>> g[n];for(auto e: connections) {in…

“掌握高效视频分割技巧,降低误差,提高精度“

如果你是一名视频编辑爱好者或者专业人士&#xff0c;那么你一定会在视频剪辑的过程中遇到各种挑战。其中&#xff0c;如何准确高效地进行视频分割是一个至关重要的问题。现在&#xff0c;我们将向你展示一种全新的解决方案&#xff0c;帮助你轻松解决这些问题。 首先第一步&a…

Vue使用百度地图以及实现轨迹回放 附完整代码

百度地图开放平台 https://lbs.baidu.com/index.php?title%E9%A6%96%E9%A1%B5 javaScript API https://lbs.baidu.com/index.php?titlejspopularGL 百度地图实例 https://lbsyun.baidu.com/index.php?titleopen/jsdemoVue Baidu Map文档 https://dafrok.github.io/vue-baidu…

数据结构与算法:红黑树

目录 什么是红黑树 ​编辑 红黑树的性质 红黑树的辨析 红黑树实现 红黑树的结构 红黑树的插入 红黑树的调试 红黑树平衡判断 什么是红黑树 这里引入一下NIL节点的概念&#xff1a; NIL节点也称为空节点或外部节点&#xff0c;是红黑树中一种特殊的节点类型。NIL节点不…

DeepIn,UOS统信专业版安装运行Java,JavaFx程序

因为要适配国产统信UOS系统&#xff0c;要求JavaFx程序能简便双击运行&#xff0c;由于网上UOS开发相关文章少&#xff0c;多数文章没用&#xff0c;因此花了不少时间&#xff0c;踩了不少坑&#xff0c;下面记录一些遇到的问题&#xff0c;我的程序环境是jdk1.8&#xff0c;为…

应用程序中实现用户隐私合规和数据保护合规的处理方案及建议

随着移动互联网的发展&#xff0c;用户隐私合规和数据保护合规已经成为应用开发过程中不可忽视的重要环节。为了帮助开发者实现隐私和数据保护合规&#xff0c;本文将介绍一些处理方案和建议。 图片来源&#xff1a;应用程序中实现用户隐私合规和数据保护合规的处理方案及建议 …

720度vr虚拟家居展厅提升客户的参观兴致

VR虚拟展厅线上3D交互展示的优势有以下几点&#xff1a; 打破了场馆的展示限制&#xff0c;可展示危险性制品、珍贵稀有物品、超大型设备等&#xff0c;同时提供了更大的展示空间和更丰富的展示内容。 可提供企业真实环境的实时VR全景参观&#xff0c;提升潜在客户信任度。 提供…

【附源码】完整版,Python+Selenium+Pytest+POM自动化测试框架封装

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、测试框架简介 …

【S2ST】Direct Speech-to-Speech Translation With Discrete Units

【S2ST】Direct Speech-to-Speech Translation With Discrete Units AbstractIntroductionRelated workModelSpeech-to-unit translation (S2UT) modelMultitask learningUnit-based vocoder ExperimentsDataSystem setupBaselineASRMTTTSS2TTransformer Translatotron Evaluat…

【广州华锐互动】VR沉浸式体验铝厂安全事故让伤害教育更加深刻

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术已经逐渐渗透到各个领域&#xff0c;为我们的生活带来了前所未有的便捷和体验。在安全生产领域&#xff0c;VR技术的应用也日益受到重视。 VR公司广州华锐互动就开发了多款VR安全事故体验系统&#xff0c…

大数据技术2:大数据处理流程

前言&#xff1a;下图是一个简化的大数据处理流程图&#xff0c;大数据处理的主要流程包括数据收集、数据存储、数据处理、数据应用等主要环节。 1.1 数据收集 大数据处理的第一步是数据的收集。现在的中大型项目通常采用微服务架构进行分布式部署&#xff0c;所以数据的采集需…

如何使用Docker本地搭建开源CMF Drupal并结合内网穿透公网访问

文章目录 前言1. Docker安装Drupal2. 本地局域网访问3 . Linux 安装cpolar4. 配置Drupal公网访问地址5. 公网远程访问Drupal6. 固定Drupal 公网地址 前言 Dupal是一个强大的CMS&#xff0c;适用于各种不同的网站项目&#xff0c;从小型个人博客到大型企业级门户网站。它的学习…

【目标检测算法】IOU、GIOU、DIOU、CIOU

目录 参考链接 前言 IOU(Intersection over Union) 优点 缺点 代码 存在的问题 GIOU(Generalized Intersection over Union) 来源 GIOU公式 实现代码 存在的问题 DIoU(Distance-IoU) 来源 DIOU公式 优点 实现代码 总结 参考链接 IoU系列&#xff08;IoU, GIoU…

Java程序员,你掌握了多线程吗?(文末送书)

目录 01、多线程对于Java的意义02、为什么Java工程师必须掌握多线程03、Java多线程使用方式04、如何学好Java多线程送书规则 摘要&#xff1a;互联网的每一个角落&#xff0c;无论是大型电商平台的秒杀活动&#xff0c;社交平台的实时消息推送&#xff0c;还是在线视频平台的流…

Java API接口强势对接:构建高效稳定的系统集成方案

文章目录 1. Java API接口简介2. Java API接口的优势2.1 高度可移植性2.2 强大的网络通信能力2.3 多样化的数据处理能力 3. 实战&#xff1a;Java API接口强势对接示例3.1 场景描述3.2 用户管理系统3.3 订单处理系统3.4 系统集成 4. 拓展&#xff1a;Java API接口在微服务架构中…

麒麟信安系统下的硬盘分区情况说明

目前飞腾平台上面麒麟信安系统分区情况如下&#xff1a; Tmpfs为内存文件系统&#xff0c;可以不考虑&#xff0c;真正使用的是两个分区 两个分区加起来为51G 查看cat /etc/fstab可以看到/data这个分区下包含了home opt root等常用文件夹 再加上这个分区容量只有17G&#xff0c…

基于Browscap对浏览器工具类优化

项目背景 原有的启动平台公共组件库comm-util的浏览器工具类BrowserUtils是基于UserAgentUtils的&#xff0c;但是该项目最后一个版本发布于 2018/01/24&#xff0c;之至今日23年底&#xff0c;已有5年没有维护更新&#xff0c;会造成最新版本的部分浏览器不能正确获取到浏览器…

嵌入式工程师校招经验与学习路线总结

前言&#xff1a;不知不觉2023年秋招已经结束&#xff0c;作者本人侥幸于秋招中斩获数十份大差不差的OFFER&#xff0c;包含&#xff1a;Top级的AIGC&#xff0c;工控龙头&#xff0c;国产MCU原厂&#xff0c;医疗器械&#xff0c;新能源车企等。总而言之&#xff0c;秋招总体情…

NR重写console.log 增加时间格式

如题&#xff0c;默认console.log输出的日志是13位的时间戳&#xff0c;然后不方便查查看与对比代码运行点的耗时&#xff0c;我们可以简单的重写 console.log方法&#xff0c;增加自定义时间戳格式&#xff0c;如下是增加时间&#xff08;时&#xff0c;分&#xff0c;秒&…

苍穹外卖+git开源

搁置了很久重新开始学 为了学习方便&#xff0c;苍穹外卖的前后端代码已放至git开源。前端源代码请看给i他-->sky-take-out: 苍穹外卖 git学习-->Git基础使用-CSDN博客 后端接口员工管理和分类管理模块 添加员工&#xff0c;添加的表单账号、手机号、身份证都…