bert-base-chinese模型使用教程

news2024/11/24 8:43:29
向量编码和向量相似度展示
import torch
from transformers import BertTokenizer, BertModel
import numpy as np

model_name = "C:/Users/Administrator.DESKTOP-TPJL4TC/.cache/modelscope/hub/tiansz/bert-base-chinese"

sentences = ['春眠不觉晓', '大梦谁先觉', '浓睡不消残酒', '东临碣石以观沧海']

tokenizer = BertTokenizer.from_pretrained(model_name)
# print(type(tokenizer)) # <class 'transformers.models.bert.tokenization_bert.BertTokenizer'>

model = BertModel.from_pretrained(model_name)
# print(type(model)) # <class 'transformers.models.bert.modeling_bert.BertModel'>

def test_encode():
    input_ids = tokenizer.encode('春眠不觉晓', return_tensors='pt') # shape (1, 7)
    output = model(input_ids)
    print(output.last_hidden_state.shape)  # shape (1, 7, 768)
    v = torch.mean(output.last_hidden_state, dim=1)  # shape (1, 768)
    print(v.shape)  # shape (1, 768)
    print(output.pooler_output.shape)  # shape (1, 768)

根据 transformers\modeling_outputs.py:196,即 BaseModelOutputWithPoolingAndCrossAttentions 的注释:

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

1为batch_size,我们出入的就是一个字符串所以batch_size为1。

7为sequence_length,BERT模型会为单句的输入前面加特殊字符[CLS]和后面加特殊字符[SEP],因此为7个字符。

768为hidden_size,即每个字符被编码成768个数字组成的向量。

根据文档的说法,pooler_output向量一般不是很好的句子语义摘要,因此这里采用了将 last_hidden_state 进行池化的方法。

torch.mean(dim)函数

这是常用的池化pooling方法,降低向量的维度,便于运算。

dim不指定任何参数就是所有元素的算术平均值

dim指定为0时,求得是列的平均值。

dim指定为1时,求得是行的平均值;

也就是经过tarch.mean()之后,这句话变成了一个 1X768 的向量,即使用这一个向量来代表着句话的语义。

接下来就可以计算文本相似度得分了,目标为上面给出的四个句子。

def test_similarity():
    with torch.no_grad():
        vs = [sentence_embedding(sentence).numpy() for sentence in sentences]
        nvs = [v / np.linalg.norm(v) for v in vs]  # normalize each vector
        m = np.array(nvs).squeeze(1)  # shape (4, 768)
        print(np.around(m @ m.T, decimals=2))  # pairwise cosine similarity


def sentence_embedding(sentence):
    input_ids = tokenizer.encode(sentence, return_tensors='pt')
    output = model(input_ids)
    return torch.mean(output.last_hidden_state, dim=1)

@符号是矩阵相乘

输出结果 4X4

[[1.   0.75 0.83 0.57]
 [0.75 1.   0.72 0.51]
 [0.83 0.72 1.   0.58]
 [0.57 0.51 0.58 1.  ]]
春眠不觉晓大梦谁先觉浓睡不消残酒东临碣石以观沧海
春眠不觉晓10.750.830.57
大梦谁先觉0.7510.720.51
浓睡不消残酒0.830.7210.58
东临碣石以观沧海0.570.510.581

with torch.no_grad() 的作用是将模型状态置为推断(inference),即在计算过程中不进行梯度计算和反向传播操作。

关于 last_hidden_state 与 pooler_output

看bert中文文本分类任务,发现训练输出结果向量仅使用了第一个token的向量,而一句话中的第一个位置是特殊字符[CLS],那么该如何理解呢?

BERT在第一句前会加一个[CLS]标志,CLS 就是 classification 分类之意,即 CLS 的作用就是它能代表整句话的意思,使用它就可以对整句话进行分类。用于下游的分类任务等。

为什么选它呢,因为与文本中已有的其它词相比,这个无明显语义信息的符号会更“公平”地融合文本中各个词的语义信息,从而更好的表示整句话的语义。具体来说,self-attention是用文本中的其它词来增强目标词的语义表示,但是目标词本身的语义还是会占主要部分的,因此,经过BERT的12层,每次词的embedding融合了所有词的信息,可以去更好的表示自己的语义。而[CLS]位本身没有语义,经过12层,得到的是attention后所有词的加权平均,相比其他正常词,可以更好的表征句子语义。

当然,也可以通过对最后一层所有词的embedding做pooling池化处理去表征句子语义,最常用的池化方法就是平均法mean。

从上面的打印信息我们可以知道 pooler_output 的结果也是 1X768,与 torch.mean() 的维度一样,可以理解为它是经过模型池化后的结果。pooler_output是通过应用一个线性层和一个激活函数(通常是Tanh)到last_hidden_state的第一个token(即[CLS]标记)的隐藏状态来生成的。这个输出通常用于分类任务,因为它编码了整个序列的信息。

BERT下游NLP任务的微调

https://zhuanlan.zhihu.com/p/149779660

在BERT原论文的设计中,下游NLP任务被分成如下几类:

  1. 句子对分类型任务,如GLUE的MNLI、QQP、MRPC等
  2. 单句分类型任务,如GLUE的STS-2、CoLA等
  3. 问答/阅读理解型任务,需要根据问题,在备选回答文本中标识出答案部分,如SQuAD
  4. 单句标注任务,给每个token打上标签,如POS词性标注、NER命名实例识别任务

一般来说,以上4类任务已经基本上可以涵盖大部分下游NLP任务了。

我们在上节指定的sst-2任务实际上就是一类“单句分类”型NLP任务,利用run_glue.py脚本,我们可以几乎不用写代码就可以完成微调。GLUE基准测试集一共包括9种不同类型的数据集,如下表小结。

在这里插入图片描述

动手写BERT系列 https://www.bilibili.com/video/av258262103?vd_source=0c75dc193ee55511d0515b3a8c375bd0&spm_id_from=333.788.videopod.sections

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2237071.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt/C++ 海康SDK开发示例Demo

*** 工业相机在机器视觉中起到关键作用&#xff0c;本文基于海康 SDK 详细解读了设备连接与控制的各个步骤。内容涵盖设备枚举、句柄创建、图像采集回调以及设备异常处理&#xff0c;帮助开发者快速理解如何通过代码控制相机&#xff0c;实时采集并处理图像数据。*** 1. 搜索并…

RabbitMQ的应用

七种工作模式介绍 1.Simple(简单模式) P&#xff1a;生产者&#xff0c;也就是要发送信息的程序 C&#xff1a;消费者&#xff0c;消息的接收者 Queue&#xff1a;消息队列。图中黄色背景部分&#xff0c;类似一个邮箱&#xff0c;可以缓存发送信息&#xff1b;生产者向其中…

K8S网络插件故障处理

1网络插件故障 1此故障问题处理方法 查询ip是否正常是否是主节点IP地址如果不是需要更改 更改方式 1 修改calico.yaml文件的相应参数 # Cluster type to identify the deployment type - name: IP_AUTODETECTION_METHOD #增加内容value: "interfaceens*" 或者 value…

【论文速看】DL最新进展20241109-图像超分、物理信息神经网络、扩散模型

目录 【图像超分】【物理信息神经网络】【扩散模型】 【图像超分】 [2024 红外图像超分] Infrared Image Super-Resolution via Lightweight Information Split Network 论文链接&#xff1a;https://arxiv.org/pdf/2405.10561v2 代码链接&#xff1a;无 单图像超分辨率&…

Python学习从0到1 day26 第三阶段 Spark ① 数据输入

要学会 剥落旧痂 然后 循此新生 —— 24.11.8 一、Spark是什么 定义&#xff1a; Apache Spark 是用于大规模数据处理的统一分析引擎 简单来说&#xff0c;Spark是一款分布式的计算框架&#xff0c;用于调度成百上千的服务器集群&#xff0c;计算TB、PB乃至EB级别的海量数据…

[Python学习日记-63] 继承与派生

[Python学习日记-63] 继承与派生 简介 继承 派生 简介 上一篇文章我们学习了类如何使用&#xff0c;以及相关特性&#xff0c;也做了相关的练习&#xff0c;在练习当中发现类与类之间有时也会存在重复代码&#xff0c;其实在类中我们还有一个继承和派生的概念没有说&#xf…

基于 Encoder-only 架构的大语言模型

基于 Encoder-only 架构的大语言模型 Encoder-only 架构 Encoder-only 架构凭借着其独特的双向编码模型在自然语言处理任务中表现出色&#xff0c;尤其是在各类需要深入理解输入文本的任务中。 核心特点&#xff1a;双向编码模型&#xff0c;能够捕捉全面的上下文信息。 En…

Python学习------第四天

Python的判断语句 一、布尔类型和比较运算符 二、 if语句的基本格式 if语句注意空格缩进&#xff01;&#xff01;&#xff01; if else python判断语句的嵌套用法&#xff1a;

uniapp实现H5和微信小程序获取当前位置(腾讯地图)

之前的一个老项目&#xff0c;使用 uniapp 的 uni.getLocation 发现H5端定位不准确&#xff0c;比如余杭区会定位到临平区&#xff0c;根据官方文档初步判断是项目的uniapp的版本太低。 我选择的方式不是区更新uniapp的版本&#xff0c;是直接使用高德地图的api获取定位。 1.首…

测试网空投进行中 — 全面了解 DePIN 赛道潜力项目 ICN Protocol 及其不可错过的早期红利

随着云计算技术的飞速发展&#xff0c;越来越多的企业和个人对云服务的需求变得多样化且复杂化。然而&#xff0c;传统的中心化云服务平台&#xff08;如AWS、微软Azure等&#xff09;往往存在着高成本、数据隐私保护不足以及灵活性差等问题。 为了解决这些挑战&#xff0c;Imp…

IntelliJ IDEA 使用心得与常用快捷键

刚开始学习写Java的时候&#xff0c;用的eclipse&#xff0c;正式工作后&#xff0c;主要用的myeclipse&#xff0c;去年初在前辈的推荐下&#xff0c;在2折的时候买了正版的 IntelliJ IDEA 和 Pycharm&#xff0c;12.0版终生使用&#xff0c;一年更新。 使用前早就久闻其名&am…

【rust】rust基础代码案例

文章目录 代码篇HelloWorld斐波那契数列计算表达式&#xff08;加减乘除&#xff09;web接口 优化篇target/目录占用一个g&#xff0c;仅仅一个actix的helloWorld demo升级rust版本&#xff0c; 通过rustupcargo换源windows下放弃吧&#xff0c;需要额外安装1g的toolchain并且要…

施工企业为什么要用工程项目管理软件?工程项目管理软件的用处是什么?

施工企业一定会遇到哪些问题&#xff1f;工人怠工、材料浪费、数据造假、工期拖延、质量问题、安全隐患等。这些问题正在悄然侵蚀建施工业的经济效益。每一个环节的失控都可能导致巨大的经济损失&#xff0c;还可能损害企业的声誉。面对日益复杂的工程管理环境&#xff0c;如何…

【C++】详解RAII思想与智能指针

&#x1f308; 个人主页&#xff1a;谁在夜里看海. &#x1f525; 个人专栏&#xff1a;《C系列》《Linux系列》 ⛰️ 丢掉幻想&#xff0c;准备斗争 目录 引言 内存泄漏 内存泄漏的危害 内存泄漏的处理 一、RAII思想 二、智能指针 1.auto_ptr 实现原理 模拟实现 弊端…

所谓的情商高,其实就是会说话!

所谓的情商高&#xff0c;其实就是会说话&#xff01; 1.当遇到不知道的事情时&#xff0c;不要直截了当地说“不知道”。而应委婉地表达为“我想听听你的看法”。 如此既能避免尴尬&#xff0c;又能展现出对对方见解的尊重和期待。 2.不要简单地说“我迟到了”&#xff0c;…

ALB搭建

ALB: 多级分发、消除单点故障提升应用系统的可用性&#xff08;健康检查&#xff09;。 海量微服务间的高效API通信。 自带DDoS防护&#xff0c;集成Web应用防火墙 配置&#xff1a; 1.创建ECS实例 2.搭建应用 此处安装的LNMP 3.创建应用型负载均衡ALB实例 需要创建服务关联角…

【spark面试】spark的shuffle过程

概述 所有的shuffle的过程本质上就是一个task将内存中的数据写入磁盘&#xff0c;然后另一个task将磁盘中的数据读入内存的过程。 对于mapreduce来说&#xff0c;我们将内存中的数据写入磁盘成为maptask&#xff0c;将磁盘中的数据读入内存称为reducetask。 而对于spark来说&…

Android 实现一个系统级的悬浮秒表

前言 由于项目需要将手机录屏和时间日志对应起来&#xff0c;一般的手机录屏只能看到分钟&#xff0c;但是APP的日志输出通常都是秒级别的&#xff0c;于是决定自己手撸一个悬浮秒表&#xff08;有拖拽效果&#xff09;。 效果如下 具体实现 大致的实现思路&#xff1a; 创…

【科普小白】LLM大语言模型的基本原理

一、要了解LLM大模型的基本原理就要先来了解一下自然语言处理&#xff08;NLP&#xff09;。 NLP 是 AI 的一个子领域&#xff0c;专注于使计算机能够处理、解释和生成人类语言&#xff0c;主要任务包括&#xff1a;文本分类、自动翻译、问题回答、生成文本等。到底是NLP促生了…

Go语言开发商城管理后台-GoFly框架商城插件已发布 需要Go开发商城的朋友可以来看看哦!

温馨提示&#xff1a;我们分享的文章是给需要的人&#xff0c;不需要的人请绕过&#xff0c;文明浏览&#xff0c;误恶语伤人&#xff01; 前言 虽然现在做商城的需求不多&#xff0c;但有很多项目中带有商城功能&#xff0c;如社区医院系统有上服务套餐、理疗产品需求、宠物…