【Transformers基础入门篇4】基础组件之Model

news2024/12/28 22:25:31

文章目录

  • 一、Model简介
    • 1.1 Transformer
    • 1.2 注意力机制
    • 1.3 模型类型
  • 二、Model Head
    • 2.1 什么是 Model Head
    • 2.2 Transformers中的Model Head
  • 三、Model基本使用方法
    • 3.0 模型下载-浏览器下载
    • 3.1 模型加载与保存
    • 3.2 配置加载参数
    • 3.3 加载config文件
    • 3.2 模型调用
      • 3.2.1 带ModelHead的调用
      • 3.2.2 不带ModelHead的调用
      • 3.2.3带ModelHead的调用
  • 四、模型微调代码实例


本文为 https://space.bilibili.com/21060026/channel/collectiondetail?sid=1357748的视频学习笔记

项目地址为:https://github.com/zyds/transformers-code


一、Model简介

1.1 Transformer

既然这个包的名字叫Transformers,那么大部分整个模型都是基于Transformer架构。

  • 原始的Tranformer为编码器(Encoder)、解码器(Decoder)模型
  • Encoder部分接收输入并构建其完整的特征表示,Decoder部分使用Encoder的编码结果以及其他的输入生成目标序列
  • 无论是编码器还是解码器,均由多个TransformerBlock堆叠而成
  • TransformerBloc由注意力机制(Attention)和FFN组成。

1.2 注意力机制

  • 注意力机制的使用是Transformer的一个核心特征,在计算当前次的特征表示时,可以通过注意力机制有选择性的告诉模型使用哪些上下文
    在这里插入图片描述

1.3 模型类型

  • 编码器模型:自编码模型,使用Encoder,拥有双向的注意力机制,即计算每一个词的特征都看到完整的上下文
  • 解码器模型:自回归模型,使用Decoder,拥有单向的注意力机制,即计算每一个词的特征时,都只能看到上文,无法看到下文
  • 编码器解码器模型:序列到序列模型,使用Encoder+Decoder,Encoder部分拥有双向的注意力,Decoder部分使用单向注意力
    在这里插入图片描述

二、Model Head

2.1 什么是 Model Head

  • Model Head是连接在模型后的层,通常为1个或多个全连接层
  • Model Head将模型的编码的表示结果进行映射,以解决不同类型的任务
    在这里插入图片描述

2.2 Transformers中的Model Head

  • *Model: 模型本身,只返回编码结果
  • *ForCausalLM: 纯的解码器模型
  • *ForMaskedLM:像Bert那种,在句子中随机Mask一些token,然后来预测token是什么
  • *ForSeq2SeqLM: 序列到序列的模型
  • *ForMultipleChoice:多项选择任务
  • *ForQuestionAnswering
  • *ForSequenceClassification
  • *ForTokenClassification
  • …其他

下图是Bert刚提出来时,给出的不同任务的解决方案

  • 句子对分类:取CLS 分类
  • 单一句子分类 取CLS 分类
  • 问答: 将Question和Paragraph同时输入,然后对Paragraph做两个全连接,记住起始和结束位置。
  • NER任务(命名实体识别):对每一个token进行分类
    在这里插入图片描述

三、Model基本使用方法

3.0 模型下载-浏览器下载

https://huggingface.co/hfl/rbt3/tree/main
自行下,然后放到hfl/rbt3文件夹
在这里插入图片描述
也可以通过代码下载,在huggingface模型上选
在这里插入图片描述

!git clone "https://huggingface.co/hfl/rbt3"  # 会下载所有模型,包括tf、flax模型
!git lfs clone "https://huggingface.co/hfl/rbt3" --include="*.bin" # 仅下载pytorch模型

3.1 模型加载与保存

在HuggingFace上选一个小模型 hfl/rbt3

from transformers import AutoConfig, AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("../../models/hfl/rbt3")

3.2 配置加载参数

model.config

3.3 加载config文件

# 加载config文件
config = AutoConfig.from_pretrained("../../models/hfl/rbt3")
config

输出BertConfig,那么我们就可以进入BertConfig来看具体的参数
在这里插入图片描述from transformers import BertConfig
查看BertConfig

class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import BertConfig, BertModel

    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
    >>> configuration = BertConfig()

    >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
    >>> model = BertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

3.2 模型调用

3.2.1 带ModelHead的调用

sen = "弱小的我也有大梦想!"
tokenizer = AutoTokenizer.from_pretrained("../../models/hfl/rbt3")
inputs = tokenizer(sen, return_tensors="pt")
output = model(**inputs)

3.2.2 不带ModelHead的调用

model = AutoModel.from_pretrained("../../models/hfl/rbt3", output_attentions=True)
output = model(**inputs)
output.last_hidden_state.size() # 输出序列的输入编码的维度
len(inputs["input_ids"][0]) # 查看输入ID的长度

输出的12就是字的长度
在这里插入图片描述

3.2.3带ModelHead的调用

from transformers import AutoModelForSequenceClassification, BertForSequenceClassification
clz_model = AutoModelForSequenceClassification.from_pretrained("../../models/hfl/rbt3")
clz_model(**inputs)

输出可以看到任务是一个文本分类任务,输出loss,logits .
这时候logits还只有一个二分类模型,我们可以调整成10类
在这里插入图片描述

clz_model = AutoModelForSequenceClassification.from_pretrained("../../models/hfl/rbt3", num_labels=10)
clz_model(**inputs)

在这里插入图片描述
那么怎么看num_labels参数呢?
定位到 BertForSequenceClassification类实现里看到 config里有num_labels

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config) 
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

四、模型微调代码实例

  • 任务类型: 文本分类

  • 使用模型: hfl/rbt3

  • 数据集地址 https://github.com/SophonPlus/ChineseNlpCorpus

  • Step1 导入相关包

from transformers import AutoTokenizer, AutoModelForSequenceClassification,BertForSequenceClassification
  • Step2 加载数据
import pandas as pd
data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data = data.dropna() # 清除数据损坏的部分
  • Step3 创建Dataset

class MyDataset(Dataset):

    def __init__(self) -> None:
        super().__init__()
        self.data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
        self.data = self.data.dropna()

    def __getitem__(self, index):
        return self.data.iloc[index]["review"], self.data.iloc[index]["label"]
    
    def __len__(self):
        return len(self.data)
dataset = MyDataset()
for i in range(5):
    print(dataset[i])

在这里插入图片描述

  • Step4 划分数据集
rom torch.utils.data import random_split
trainset, validset = random_split(dataset, lengths=[0.9, 0.1])
len(trainset), len(validset)
  • Step5 创建Dataloader

这里读取数据,然后写处理函数传到Dataloder中collate_fn,可以实现批量处理

import torch

tokenizer = AutoTokenizer.from_pretrained("../../models/hfl/rbt3")

def collate_func(batch):
    texts, labels = [], []
    for item in batch:
        texts.append(item[0])
        labels.append(item[1])
    # truncation=True 过长做截断
    inputs = tokenizer(texts, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
    inputs["labels"] = torch.tensor(labels)
    return inputs
from torch.utils.data import DataLoader

trainloader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_func)
validloader = DataLoader(validset, batch_size=64, shuffle=False, collate_fn=collate_func)

# 查看loader
next(enumerate(validloader))[1]
  • step6 创建模型以及优化器
from torch.optim import Adam

model = AutoModelForSequenceClassification.from_pretrained("../../models/hfl/rbt3")

if torch.cuda.is_available():
    model = model.cuda()
optimizer = Adam(model.parameters(), lr=2e-5)
  • step7 训练与验证
def evaluate():
    model.eval()
    acc_num = 0
    with torch.inference_mode():
        for batch in validloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            output = model(**batch)
            pred = torch.argmax(output.logits, dim=-1)
            acc_num += (pred.long() == batch["labels"].long()).float().sum()
    return acc_num / len(validset)

def train(epoch=3, log_step=100):
    global_step = 0
    for ep in range(epoch):
        model.train()
        for batch in trainloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            optimizer.zero_grad()
            output = model(**batch)
            output.loss.backward()
            optimizer.step()
            if global_step % log_step == 0:
                print(f"ep: {ep}, global_step: {global_step}, loss: {output.loss.item()}")
            global_step += 1
        acc = evaluate()
        print(f"ep: {ep}, acc: {acc}")
  • step9 模型训练与预测
train()
# sen = "我觉得这家酒店不错,饭很好吃!"
sen = "我觉得这家酒店太差了,饭很难吃!"
id2_label = {0: "差评!", 1: "好评!"}
model.eval()
with torch.inference_mode():
    inputs = tokenizer(sen, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}
    logits = model(**inputs).logits
    pred = torch.argmax(logits, dim=-1)
    print(f"输入:{sen}\n模型预测结果:{id2_label.get(pred.item())}")

尝试pipeline

from transformers import pipeline

model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
pipe(sen)

输出如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2158261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【PAM】Linux登录认证限制

PAM(Pluggable Authentication Modules,可插拔认证模块)是一种灵活的认证框架,用于在 Linux 和其他类 Unix 系统上管理用户的身份验证。PAM 允许系统管理员通过配置不同的认证模块来定制应用程序和服务的认证方式,而不…

软件设计师:01计算机组成与结构

文章目录 一、校验码1.奇偶校验码2.海明码3.循环冗余检验码 二、原码反码补码移码三、浮点数表示法1.浮点数相加时 四、寻址方式五、CPU1.访问速度2.cpu的组成 六、RISC和CISC&#xff08;<font color red>只用记住不同就可以&#xff09;七、冗余技术1.结构冗余2.信息冗…

HyperWorks的实体几何创建与六面体网格剖分

创建和编辑实体几何 在 HyperMesh 有限元前处理环境中&#xff0c;有许多操作是针对“实体几何”的&#xff0c;例如创建六面体网格。在创建实体网格的工作中&#xff0c;我们既可以使用闭合曲面创建实体网格&#xff0c;也可以使用完整的实体几何创建实体网格。与闭合曲面相比…

【rabbitmq-server】安装使用介绍

在 1050a 系统下安装 rabbitmq-server 服务以及基本配置;【注】:改方案用于A版统信服务器操作系统 文章目录 功能概述功能介绍一、安装软件包二、启动服务三、验证四、基本配置功能概述 RabbitMQ 是AMQP的实现,高性能的企业消息的新标准。RabbitMQ服务器是一个强大和可扩展…

截取递增数-第15届蓝桥省赛Scratch中级组真题第6题

[导读]&#xff1a;超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成&#xff0c;后续会不定期解读蓝桥杯真题&#xff0c;这是Scratch蓝桥杯真题解析第191讲。 如果想持续关注Scratch蓝桥真题解读&#xff0c;可以点击《Scratch蓝桥杯历年真题》并订阅合集&#xff0c;…

【c数据结构】OJ练习篇 帮你更深层次理解链表!(相交链表、相交链表、环形链表、环形链表之寻找环形入口点、判断链表是否是回文结构、 随机链表的复制)

目录 一. 相交链表 二. 环形链表 三. 环形链表之寻找环形入口点 四. 判断链表是否是回文结构 五. 随机链表的复制 一. 相交链表 最简单粗暴的思路&#xff0c;遍历两个链表&#xff0c;分别寻找是否有相同的对应的结点。 我们对两个链表的每个对应的节点进行判断比较&…

力扣 209.长度最小的子数组

一、长度最小的子数组 二、解题思路 采用滑动窗口的思路&#xff0c;详细见代码。 三、代码 class Solution {public int minSubArrayLen(int target, int[] nums) {int n nums.length, left 0, right 0, sum 0;int ans n 1; for (right 0; right < n; right ) { …

数通。。。

通信&#xff1a;需要介质才能通信电话离信号塔&#xff08;基站&#xff09;越远&#xff0c;信号越弱。信号在基站之间传递。你离路由器越远&#xff0c;信号越差。一个意思 比如想传一张图片&#xff0c;这张图片就是数据载荷 网关&#xff0c;分割两个网络。路由器可以是网…

Chat2VIS: Generating Data Visualizations via Natural Language

Chat2VIS:通过使用ChatGPT, Codex和GPT-3大型语言模型的自然语言生成数据可视化 梅西大学数学与计算科学学院&#xff0c;新西兰奥克兰 IEEE Access 1 Abstract 数据可视化领域一直致力于设计直接从自然语言文本生成可视化的解决方案。自然语言接口 (NLI) 的研究为这些技术的…

巴黎嫩事件对数据信息安全的影响及必要措施

2024年9月17日&#xff0c;黎巴嫩首都贝鲁特发生了多起小型无线电通信设备爆炸事件&#xff0c;导致伊朗驻黎巴嫩大使受轻伤。这一事件不仅引发了对安全的广泛关注&#xff0c;也对数据信息安全提出了新的挑战。 王工 18913263502 对数据信息安全的影响&#xff1a; 数据泄露风…

MySQL慢查询优化指南

​ 博客主页: 南来_北往 系列专栏&#xff1a;Spring Boot实战 前言 当遇到慢查询问题时&#xff0c;不仅影响服务效率&#xff0c;还可能成为系统瓶颈。作为一位软件工程师&#xff0c;掌握MySQL慢查询优化技巧至关重要。今天&#xff0c;我们就来一场“数据库加速之旅…

Thinkphp(TP)

1.远程命令执行 /index.php?sindex/think\app/invokefunction&functioncall_user_func_array&vars[0]system&vars[1][]whoami 2.远程代码执行 /index.php?sindex/think\app/invokefunction&functioncall_user_func_array&vars[0]phpinfo&vars[1][]…

Java面向对象——内部类(成员内部类、静态内部类、局部内部类、匿名内部类,完整详解附有代码+案例)

文章目录 内部类17.1概述17.2成员内部类17.2.1 获取成员内部类对象17.2.2 成员内部类内存图 17.3静态内部类17.4局部内部类17.5匿名内部类17.5.1概述 内部类 17.1概述 写在一个类里面的类叫内部类,即 在一个类的里面再定义一个类。 如&#xff0c;A类的里面的定义B类&#x…

微信支付商户号注册流程

目录 一、官方指引二、申请规则三、申请流程1.提交资料2.签约协议3.绑定场景 四、微信支付商户登录入口 一、官方指引 https://kf.qq.com/faq/210423UrIRB7210423by6fQn.html 二、申请规则 1、微信支付商家仅面向企业、个体工商户、政府及事业单位、民办非企业、社会团体、基…

java sdk下载,解决下载了java但是编译不了

直接搜Java得到的网站使用不了的 应该只是个功能包或者版本太低用不了 得去oracle公司搜java这个产品去下载

Java语言程序设计基础篇_编程练习题**18.34 (游戏:八皇后问题)

目录 题目&#xff1a;**18.34 (游戏:八皇后问题) 代码示例 代码解析 输出结果 使用文件 题目&#xff1a;**18.34 (游戏:八皇后问题) 八皇后问题是要找到一个解决方案&#xff0c;将一个皇后棋子放到棋盘上的每行中&#xff0c;并且两个皇后棋子之间不能相互攻击。编写个…

Llama 3.1 技术研究报告-2

3.3 基础设施、扩展性和效率 我们描述了⽀持Llama 3 405B⼤规模预训练的硬件和基础设施&#xff0c;并讨论了⼏项优化措施&#xff0c;这些措施提⾼了训练效率。 3.3.1 训练基础设施 Llama 1和2模型在Meta的AI研究超级集群&#xff08;Lee和Sengupta&#xff0c;2022&#x…

模型融合创新性Max!5种模型融合方法刷新SOTA!发顶会必看!

近年来&#xff0c;关于模型融合的研究逐渐火热&#xff0c;出现了很多效果出众的成果。模型融合&#xff08;Model Merging&#xff09;技术&#xff0c;即利用现有模型的参数、架构和特性&#xff0c;巧妙结合成一个新的、功能更强大的模型&#xff0c;这不仅减少了从头训练大…

计算机毕业设计新闻资讯知识施肥技术网站推荐评论搜索猜你喜欢留言/springboot/javaWEB/J2EE/MYSQL数据库/vue前后分离小程序

摘要‌ 随着互联网的快速发展&#xff0c;新闻网站成为人们获取新闻资讯的重要途径。本文旨在介绍一款新闻网站毕业设计的开发与实现过程&#xff0c;该系统集新闻发布、用户互动、个性化推荐等功能于一体&#xff0c;采用Spring Boot、Vue等前后端分离技术&#xff0c;旨在提…

风力发电场集中监控解决方案

0引言 风力发电装机容量近年来快速增长。截至7月底&#xff0c;全国发电装机容量达27.4亿千瓦&#xff0c;同比增长11.5%。其中&#xff0c;太阳能和风力发电装机容量分别为4.9亿千瓦和3.9亿千瓦&#xff0c;同比增长42.9%和14.3%。风力发电场分陆上和海上风电&#xff0c;常位…