【Foundation】(三)transformers之Model

news2025/1/15 20:48:01

文章目录

  • 1、介绍
    • 1.1、 模型类型
    • 1.2、Model Head
  • 2、模型加载
  • 3、模型调用
    • 3.1、不带Model Head的模型调用
    • 3.2、带Model Head的模型调用
  • 4、模型微调实战
    • 4.1、导包
    • 4.2、加载数据
    • 4.3、创建数据集
    • 4.4、划分数据集
    • 4.5、创建加载器
    • 4.6、创建模型以及优化器
    • 4.7、模型训练
    • 4.8、模型评估
    • 4.9、模型预测

本篇博客内容以及后续内容均来自b站up主你可是处女座啊

1、介绍

1.1、 模型类型

  • 编码器模型:自编码器模型,使用encoder,拥有双向注意力机制,即计算每一个词的特征时都看到完整的上下文
  • 解码器模型:自回归模型,使用decoder,拥有单向注意力机制,即计算每一个词的特征时智能看到上文,无法看到下文
  • 编码解码器模型:序列到序列的模型,使用encoder-decoder encoder使用双向注意力,decoder使用单向注意力
    在这里插入图片描述

1.2、Model Head

在这里插入图片描述

2、模型加载

from transformers import  AutoConfig,AutoModel,AutoTokenizer
#在线加载
model = AutoModel.from_pretrained('hfl/rbt3')

#模型下载
#!git clone “https://huggingface.co/hfl/rbt3”
!git lfs clone “https://huggingface.co/hfl/rbt3” --include=“*.bin”

#离线加载
model = AutoModel.from_pretrained('rbt3')
模型参数
model.config

3、模型调用

sen = '弱小的我也有大梦想'
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3',output_attentions=True)
inputs = tokenizer(sen,return_tensors='pt')
inputs

3.1、不带Model Head的模型调用

model = AutoModel.from_pretrained('hfl/rbt3',output_attentions=True)
output = model(**inputs)
output
output.last_hidden_state.size()

3.2、带Model Head的模型调用

from transformers import AutoModelForSequenceClassification,BertForSequenceClassification
clz_model = AutoModelForSequenceClassification.from_pretrained('hfl/rbt3',num_labels=10)
clz_model(**inputs)

clz_model.config.id2label
clz_model.config.num_labels

4、模型微调实战

4.1、导包

#文本分类实战
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForSequenceClassification,Trainer,TrainingArguments

4.2、加载数据

#法一
data = pd.read_csv('./datasets/ChnSentiCorp_htl_all.csv')
data.head()
data = data.dropna()
data
#法二
dataset = load_dataset('csv',data_files='datasets/ChnSentiCorp_htl_all.csv',split='train')
dataset = dataset.filter(lambda x : x['review'] is not None)
dataset

4.3、创建数据集

from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.data =pd.read_csv('./datasets/ChnSentiCorp_htl_all.csv')
        self.data = self.data.dropna()

    def __getitem__(self, index):
        return self.data.iloc[index]['review'],self.data.iloc[index]['label']
    def __len__(self):
        return len(self.data)
dataset = MyDataset()
for i in range(5):
    print(dataset[i])

4.4、划分数据集

from torch.utils.data import random_split
#法一
trainset,validset = random_split(dataset,lengths=[0.9,0.1])
len(trainset),len(validset)
#法二
dataset = dataset.train_test_split(test_size=0.1)
dataset

4.5、创建加载器

from transformers import AutoTokenizer
import torch
tokenizer =AutoTokenizer.from_pretrained('hfl/rbt3')
def collate_func(batch):
    texts,labels = [],[]
    for item in batch:
        texts.append(item[0])
        labels.append(item[1])
    inputs = tokenizer(texts,max_length=128,padding='max_length',truncation=True,
                       return_tensors='pt')
    inputs['labels']= torch.tensor(labels)
    return inputs
from torch.utils.data import DataLoader
trainloader = DataLoader(trainset,batch_size=32,shuffle=True,collate_fn=collate_func)
validloader = DataLoader(trainset,batch_size=64,shuffle=False,collate_fn=collate_func)
import torch
tokenizer =AutoTokenizer.from_pretrained('rbt3')


def process_function(examples):
    tokenized_examples = tokenizer(examples['review'],max_length=128,truncation=True)
    tokenized_examples['labels'] = examples["label"]
    return tokenized_examples


tokenized_datasets = dataset.map(process_function,batched=True,remove_columns=dataset['train'].column_names)
tokenized_datasets
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
trainset ,validset = tokenized_datasets['train'],tokenized_datasets['test']

trainloader = DataLoader(trainset,batch_size=32,shuffle=True,collate_fn=DataCollatorWithPadding(tokenizer))
validloader = DataLoader(validset,batch_size=32,shuffle=True,collate_fn=DataCollatorWithPadding(tokenizer))

4.6、创建模型以及优化器

from torch.optim import Adam
from transformers import AutoModelForSequenceClassification
#法一
model = AutoModelForSequenceClassification.from_pretrained('hfl/rbt3')
if torch.cuda.is_available():
   model.to('cuda')
opt = Adam(model.parameters(),lr=2e-5)
#方法二 trainer
model = AutoModelForSequenceClassification.from_pretrained('rbt3')
########## 创建评估函数
import evaluate

acc_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
def eval_metric(eval_predict):
    predictions,labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

4.7、模型训练

import evaluate
clf_metrics = evaluate.combine(['accuracy', 'f1'])

def train(epoch=3,log_step=100):
    global_step = 0
    for ep in range(epoch):
        model.train()
        for batch in trainloader:
            if torch.cuda.is_available():
                batch = {k:v.to('cuda') for k,v in batch.items()}
            opt.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            opt.step()
            global_step += 1
            if global_step % log_step == 0:
                print(f'epoch:{ep},global_step:{global_step},loss:{loss.item()}')
        acc = evaluate()
        # print(f'ep:{ep},acc:{acc}')
        print(f'ep:{ep},{acc}')
def evaluate():
    model.eval()
    acc_num = 0
    with torch.inference_mode():
        for batch in validloader:
            if torch.cuda.is_available():
                batch = {k:v.to('cuda') for k,v in batch.items()}
            outputs = model(**batch)
            pred = torch.argmax(outputs.logits,dim=-1)
            clf_metrics.add_batch(predictions=pred.long(),references=batch['labels'].long())
            
    return clf_metrics.compute()
            # acc_num += (pred.long() == batch['labels'].long()).float().sum()
    # return acc_num / len(validset)
#创建Training Arguments

train_args = TrainingArguments(output_dir='./checkpoint',       #输出文件
                               per_device_eval_batch_size=8,    #验证时batch大小
                               per_device_train_batch_size=8,   #训练时batch大小
                               logging_steps=10,                #每10步打印日志
                               eval_strategy="epoch",     #评估策略  epoch、step
                               save_steps=100,                  #每100步保存一次模型
                               save_strategy='epoch',           #保存策略
                               save_total_limit=2,              #保存模型的数量
                               learning_rate=2e-5,              #学习率
                               weight_decay=1e-5,               #衰减率
                               metric_for_best_model="f1",      #最好模型评估标准
                               load_best_model_at_end = True,   #加载最优模型
                               )
train_args
#创建Trainer
from transformers import DataCollatorWithPadding
trainer = Trainer(
    model = model,
    args=train_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=eval_metric
    )
trainer.train()

4.8、模型评估

trainer.evaluate(tokenized_datasets["test"])

4.9、模型预测

trainer.predict(tokenized_datasets["test"])
sen = '我觉得这家酒店不错,饭很好吃'
inputs = tokenizer(sen,return_tensors='pt')
id2label = {0:'差评!',1:'好评!'}
model.eval()
with torch.inference_mode():
    inputs = tokenizer(sen,return_tensors='pt')
    inputs = {k:v.cuda() for k,v in inputs.items()}
    logits = model(**inputs).logits
    pred = torch.argmax(logits,dim=-1)
    print(f'输入:{sen}\n模型预测结果:{pred.item()}')
    print(pred)
from transformers import pipeline

model.config.id2label = id2label
pipe = pipeline('text-classification',model=model,tokenizer=tokenizer,device=0)
pipe(sen)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1985687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Figma 替代品 Excalidraw 安装和使用教程

如今远程办公盛行,一个好用的在线白板工具对于团队协作至关重要。然而,市面上的大多数白板应用要么功能单一,要么操作复杂,难以满足用户的多样化需求。尤其是在进行头脑风暴、流程设计或产品原型绘制时,我们常常会遇到…

linux入门到精通-第二十章-bufferevent(开源高性能事件通知库)

目录 参考bufferevent简单介绍工作流程事件Api新建事件节点 bufferevent_socket_new设置事件节点回调bufferevent_setcb使事件势能bufferevent_enable发送数据bufferevent_write接收数据bufferevent_read evconnlistener的简介 参考 视频教程 libevent的基本使用 libevent–bu…

HslCommunicationDemo各品牌Plc通信测试软件工具

目录 1、HslCommunicationDemo程序包 2、ModbusTCP举例说明 (0)概述 (1)线圈写操作 (2)寄存器写操作 3、C#工程中DLL库文件使用 (1)创建Winform程序工程 (2)写寄存器 1、HslC…

【Linux】匿名管道|命名管道|pipe|mkfifo|管道原理|通信分类|管道的特征和情况

目录 ​编辑 进程间通信 为什么要有进程间通信 进程通信的目的 进程间通信分类 如何理解通信 管道 匿名管道 管道原理 半双工 通信两问 pipe 演示 管道情况 管道的特征 命名管道 mkfifo指令 mkfifo接口 命名管道提供的是流式服务 匿名管道与命名管道的…

day08 1.进程间通信

work1.c #include <myhead.h> //要发送的消息类型 struct msgbuf {long mtype;char mtext[1024]; };#define SIZE sizeof(struct msgbuf)-sizeof(long)int main(int argc, const char *argv[]) {pid_t pid fork();if(pid -1){perror("fork error");return -…

Webpack入门基础知识及案例

webpack相信大家都已经不陌生了&#xff0c;应用程序的静态模块打包工具。前面我们总结了vue&#xff0c;react入门基础知识&#xff0c;也分别做了vue3的实战小案例&#xff0c;react的实战案例&#xff0c;那么我们如何使用webpack对项目进行模块化打包呢&#xff1f; 话不多…

RPA与智慧政务的关系

自1992年国务院明确提出构建全国行政机关办公决策系统&#xff0c;我国政府信息化建设已走过三十余年历程&#xff0c;并取得了阶段性成果&#xff0c;随着社会需求的变化以及信息技术和数字化工具的不断完善&#xff0c;人们对政府的信息化建设也提出了新的要求&#xff0c;推…

【C#语音文字互转】C#语音转文字(方法一)

Whisper.NET开源项目&#xff1a;https://github.com/sandrohanea/whisper.net/tree/main 一. 环境准备 在VS中安装 Whisper.net&#xff0c;在NuGet包管理器控制台中运行以下命令&#xff1a; Install-Package Whisper.net Install-Package Whisper.net.Runtime其中运行时包…

uniapp 实现自定义缩略滚动条

<template><view class"container-scroll"><!-- 文字导航 --><scroll-view class"scroll-view-text" scroll-x"true" v-if"type 1"><navigator:url"item.url"class"scroll-view-item"…

LE-50821F/FA激光扫描传感器|360°避障雷达之功能与连接使用说明

LE-50821F/FA激光扫描传感器|360避障雷达广泛应用于工业自动化、移动机器人应用场景中的环境感知、高精度定位&#xff08;如建图、扫描、避障、防护等&#xff09; LE-50xxxF系列升级扫描频率最高可达600KHz​​​​。 本文重点介绍LE-50821F/FA激光扫描传感器|360避障雷达之…

【C++】二维数组 数组名

二维数组名用途 1、查看所占内存空间 2、查看二维数组首地址 针对第一种用途&#xff0c;还可以计算数组有多少行、多少列、多少元素 针对第二种用途&#xff0c;数组元素、行数、列数都是连续的&#xff0c;且相差地址是有规律的 下面是一个实例 #include<iostream&g…

FreeRTOS基础入门——FreeRTOS的系统配置(三)

个人名片&#xff1a; ​ &#x1f393;作者简介&#xff1a;嵌入式领域优质创作者&#x1f310;个人主页&#xff1a;妄北y &#x1f4de;个人QQ&#xff1a;2061314755 &#x1f48c;个人邮箱&#xff1a;[mailto:2061314755qq.com] &#x1f4f1;个人微信&#xff1a;Vir202…

基于大模型的Agent

2023年&#xff0c;对于所有的人工智能领域只有一个共同的主题——大模型。大模型的受关注程度与发展速度可谓前所未有。其中&#xff0c;基于大模型的Agent又是最近几个月大模型领域的热点。这不开始研究没有几个月&#xff0c;综述文章都出来了&#xff0c;你说快不快&#x…

FashionAI比赛-服饰属性标签识别比赛赛后总结(来自 Top14 Team)

关联比赛: FashionAI全球挑战赛—服饰属性标签识别 推荐大家看本篇博客之前&#xff0c;看一下数据集制作的方法&#xff0c;如何做一个实用的图像数据集 PS&#xff1a;我是参加完比赛之后才看的&#xff0c;看完之后&#xff0c;万马奔腾.....&#xff0c;因为发现比赛中还…

62 函数参数——传递参数时的序列解包

与可变长度的参数相反&#xff0c;这里的序列解包是指实参&#xff0c;同样也有 * 和 ** 两种形式。 ① 调用含有多个位置参数的函数时&#xff0c;可以使用 Python 列表、元组、集合、字典以及其他可迭代对象作为实参&#xff0c;并在实参名称前加一个星号&#xff0c;Python …

element-ui/plus使用el-date-picker周 选择器返回时间范围处理案例

element-ui/plus使用el-date-picker周 选择器返回时间范围处理案例 如图所示 <el-date-pickerchange"changeTime":picker-options"{ firstDayOfWeek: 1 }"v-model"value1"type"week"format"YYYY年 第ww周"placeholder&…

C++初学者指南-5.标准库(第二部分)--数值运算算法

C初学者指南-5.标准库(第二部分)–数值运算算法 文章目录 C初学者指南-5.标准库(第二部分)--数值运算算法iota (注意不是itoa函数)Reductions reduce transform_reduce遗留操作&#xff08;无法并行执行&#xff09;accumulate (≈ reduce) C98inner_product (≈ transform_r…

sanger序列拼接--一次错误示范

文章目录 目的实现步骤 目的 NGS得到了很多的reads&#xff0c;其中有一些paired reads我想根据overlap 搭建起来&#xff0c;因为我对序列的ID做了删减&#xff0c;所以再pandaseq那里跑不通。 总结来说&#xff0c;目的很简单&#xff0c;就是把 有重叠区域的 reads 搭起来…

【学习笔记】A2X通信的协议(二)- A2X配置参数

目录 5. A2X配置参数 5.1 一般说明 5.2 A2X配置参数的配置和优先级 5.2.1 一般说明 5.2.2 A2X配置参数的优先级 5.2.3 通过PC5进行的A2X通信的配置参数 5.2.4 广播远程ID&#xff08;BRID&#xff09;的配置参数 5.2.5 直接检测和避免&#xff08;DDAA&#xff09;的配…

解决 Beyond Compare 30天过期问题

解决 Beyond Compare 30天过期的步骤如下&#xff1a; 1、使用快捷键WinR打开运行窗口&#xff0c;输入regedit并回车&#xff0c;打开注册表编辑器。 2、在注册表编辑器中&#xff0c;找到Beyond Compare的注册表位置&#xff0c;路径通常是HKEY_CURRENT_USER\Software\Scoot…