【深度学习】语言模型与注意力机制以及Bert实战指引之一

news2025/2/4 6:43:20

文章目录

  • 统计语言模型和神经网络语言模型
  • 注意力机制和Bert
  • 实战Bert
    • 配置环境和模型转换
    • 格式准备
  • 模型构建
    • 网络设计
    • 模型配置
    • 代码实战


统计语言模型和神经网络语言模型

区别:统计语言模型的本质是基于词与词共现频次的统计,而神经网络语言模型则是给每个词赋予了向量空间的位置作为表征,从而计算它们在高维连续空间中的依赖关系。
相对来说,神经网络的表示以及非线性映射,更加适合对自然语言进行建模。


注意力机制和Bert

Attention is all you need 是关于注意力机制最经典的论文,它是机制,准确来说,不是一个算法,而是一个构建网络的思路。
其实 Attention 机制要做:找到最重要的关键内容。它对网络中的输入(或者中间层)的不同位置,给予了不同的注意力或者权重,然后再通过学习,网络就可以逐渐知道哪些是重点,哪些是可以舍弃的内容了。

BERT 的全称是 Bidirectional Encoder Representation from Transformers,即双向 Transformer 的 Encoder。作为一种基于 Attention 方法的模型,它最开始出现的时候可以说是抢尽了风头,在文本分类、自动对话、语义理解等十几项 NLP 任务上拿到了历史最好成绩。BERT 的理论框架主要是基于论文《Attention is all you need》中提出的 Transformer,而后者的原理则是刚才提到的 Attention。
在这里插入图片描述
结合上图我们要注意的是,BERT 采用了基于 MLM 的模型训练方式,即 Mask Language Model。因为 BERT 是 Transformer 的一部分,即 encoder 环节,所以没有 decoder 的部分(其实就是 GPT)。
在这里插入图片描述

用过 Word2Vec 的小伙伴应该比较清楚,在 Word2Vec 中,对于同一个词语,它的向量表示是固定的,这也就是为什么会有那个经典的“国王-男人+女人=皇后”的计算式了,但有个问题,“苹果”可能是水果,也可能是手机品牌。如果还是用同一个向量表示,就有偏差了,而BERT可以根据上下文的不同,对同一个token给出的词向量是动态变化的,很灵活。

实战Bert

配置环境和模型转换

安装hugging face的Pytorch版本的Transformers包

pip install Transformers

ref:https://github.com/huggingface/transformers

在这里插入图片描述
找到这里面很重要的两个文件:
convert_BERT_original_tf_checkpoint_to_PyTorch.py 和 modeling_BERT.py
第一个是将tf2的模型转成pytorch版的
第二个是给咱一个BERT的使用案例

选一个预训练模型Pre-trained Models
ref:https://github.com/google-research/bert

在这里插入图片描述
102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters

模型转换:

python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path D:\code\python_project\01mlforeveryone\multilingual_L-12_H-768_A-12\multilingual_L-12_H-768_A-12\bert_model.ckpt  --bert_config_file D:\code\python_project\01mlforeveryone\multilingual_L-12_H-768_A-12\multilingual_L-12_H-768_A-12\bert_config.json --pytorch_dump_path bert/pytorch_model.bin

转换成功在这里插入图片描述
confiig.json BERT模型的配置文件,记录了所有用于训练的参数设置
PyTorch_model.bin 模型文件
vocab.txt 词表文件 用于识别所支持语言的字符、字符串或者单词

格式准备

输入的数据不能是直接把词输入到模型,需要转换成三种向量:

  1. Token embeddings: 词向量, [CLS] 开头,用于文本分类等任务
  2. Segment embeddings: 将两句话进行区分,比如问答任务,问句和答句同时输入,这就需要能够区分两句话的操作。
  3. Position embeddings 单词的位置信息

模型构建

我们来搭建一个基于BERT的文本分类网络模型,包括:网络的设计、配置、以及数据准备。

网络设计

modeling_BERT.py 文件作者已经给我们提供了很多种类的NLP任务代码。
其中“BERTForSequenceClassification”, 这个分类网络我们可以直接使用,它是最基本的BERT文本分类的流程。

模型配置

config.json
id2label: 类别标签和类别名称的映射关系
label2id:类别名称和列表标签的映射关系
num_labels_cate:类别的数量

修改方式见这里
ref:https://blog.csdn.net/qsx123432/article/details/126159843
ref:https://blog.csdn.net/weixin_42223207/article/details/125024596

代码实战

import numpy as np
import torch
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel

model_name = "./bert/bert-base-chinese/"
MODEL_PATH = "D:/code/python_project/01mlforeveryone/bert/bert-base-chinese/"
#### 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
####  'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
####  'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
# a. 通过词典导入分词器
tokenizer = BertTokenizer.from_pretrained(model_name)
# b.导入配置文件
model_config = BertConfig.from_pretrained(model_name)
# 修改配置
model_config.output_hidden_states = True
model_config.output_attentions = True
# 通过配置和路径导入模型
bert_model = BertModel.from_pretrained(MODEL_PATH, config = model_config)

print(tokenizer.encode('吾儿莫慌'))   # [101, 1434, 1036, 5811, 2707, 102]

sen_code = tokenizer.encode_plus('这个故事没有终点', "正如星空没有彼岸")
print(sen_code)
"""
{'input_ids': [101, 100, 100, 100, 100, 100, 100, 100, 100, 102, 100, 100, 100, 100, 100, 100, 100, 100, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
"""
# 将input_ids转化回token:
print(tokenizer.convert_ids_to_tokens(sen_code['input_ids']))
"""
['[CLS]', '这', '个', '故', '事', '没', '有', '终', '点', '[SEP]', '正', '如', '星', '空', '没', '有', '彼', '岸', '[SEP]']
"""
# 将分词输入模型,得到编码:
# 对编码进行转换,以便输入Tensor
tokens_tensor = torch.tensor([sen_code['input_ids']])       # 添加batch维度并,转换为tensor,torch.Size([1, 19])
segments_tensors = torch.tensor(sen_code['token_type_ids']) # torch.Size([19])

bert_model.eval()

# 进行编码
with torch.no_grad():
    
    outputs = bert_model(tokens_tensor, token_type_ids = segments_tensors)
    encoded_layers = outputs   # outputs类型为tuple
    
    print(encoded_layers[0].shape, encoded_layers[1].shape, 
          encoded_layers[2][0].shape, encoded_layers[3][0].shape)
    # torch.Size([1, 19, 768]) torch.Size([1, 768])
    # torch.Size([1, 19, 768]) torch.Size([1, 12, 19, 19])
model_name = 'bert-base-chinese'    # 指定需下载的预训练模型参数

在下一 篇文章,我们将开始做更多的练习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1328585.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决IDEA编译/启动报错:Abnormal build process termination

报错信息 报错信息如下: Abnormal build process termination: "D:\Software\Java\jdk\bin\java" -Xmx3048m -Djava.awt.headlesstrue -Djava.endorsed.dirs\"\" -Djdt.compiler.useSingleThreadtrue -Dpreload.project.path………………很纳…

C/C++ 基础函数

memcpy:C/C语言中的一个用于内存复制的函数,声明在 string.h 中(C是 cstring) void *memcpy(void *destin, void *source, unsigned n);作用是:以source指向的地址为起点,将连续的n个字节数据,…

【Grafana】Grafana匿名访问以及与LDAP连接

上一篇文章利用Docker快速部署了Grafana用来展示Zabbix得监控数据,但还需要给用户去创建账号允许他们登录后才能看展示得数据,那有什么办法让非管理员更方便得去访问Grafana呢?下面介绍两个比较方便实现的: 在开始设置前&#xff…

Echarts 仪表盘实现平均值和实时值

const gaugeData [{value: 20,name: 互动发起率实时值,title: {offsetCenter: [-25%, 10%]},detail: {offsetCenter: [-25%, 18%]}},{value: 40,name: 互动发起平均值,title: {offsetCenter: [25%, 10%]},detail: {offsetCenter: [25%, 18%]}},// {// value: 60,// name: …

Megatron模型并行研究

Megatron模型并行研究 1. 技术调研 a. Megatron-LM Megatron-LM针对的是特别大的语言模型,使用的是模型并行的训练方式。但和普通的模型并行不同,他采用的其实是张量并行的形式,具体来说就是将一个层切开放到不同的GPU上,属于层…

前端---后端 跨域?

一、跨域 ? 跨域(Cross-Origin Resource Sharing,CORS)是浏览器的一项安全功能,它用于限制一个域名下的文档如何从另一个不同的域名、端口或协议请求资源。跨域资源共享(Cross-Origin Resource Sharing&am…

Protobuf 编码规则及c++使用详解

Protobuf 编码规则及c使用详解 Protobuf 介绍 Protocol Buffers (a.k.a., protobuf) are Google’s language-neutral, platform-neutral, extensible mechanism for serializing structured data Protocol Buffers(简称为protobuf)是谷歌的语言无关、…

邻接矩阵表示 深度遍历 广度遍历

邻接矩阵表示法是一种图的表示方法,其中每个顶点都有一个唯一的索引,而每条边则由两个顶点之间的连接确定。深度优先遍历(DFS)和广度优先遍历(BFS)是两种常用的图遍历算法。 1. 深度优先遍历(D…

【Python】matplotlib画图_饼状图

柱状图主要使用pie()函数,基本格式如下: plt.pie(x,explodeNone,labelsNone,colorsNone,autopctsNone,pctdistance0.6,shadowFalse,labeldistance1.1,staatangleNone,radiusNone,counterclockTrue,wedgepropsNone,textpropsNone,center(0,0),frameFalse…

【大数据存储与处理】第一次作业

hbase 启动步骤 1、启动 hadoop,master 虚拟机,切换 root 用户,输入终端命令:start-all.sh 2、启动 zookeeper,分别在 master、slave1、slave2 虚拟机终端命令执行:zkServer.sh start 3、启动 hbase&#x…

MySQL 分表真的能提高查询效率?

背景 首先我们以InnoDB引擎,BTree 3层为例。我们需要先了解几个知识点:页的概念、InnoDB数据的读取方式、什么是树搜索?、一次查询花费的I/O次数,跨页查询。 页的概念 索引树的页(page)是指存储索引数据…

Flutter本地化(国际化)之App名称

文章目录 Android国际化IOS国际化 Flutter开发的App,如果名称想要跟随着系统的语言自动改变,则必须同时配置Android和IOS原生。 Android国际化 打开android\app\src\main\res\values 创建strings.xml 在values上右键,选择New>Values Res…

【lesson21】MySQL复合查询(2)子查询

文章目录 子查询测试要用到的表测试要用到的数据单行子查询案例 多行子查询案例 多列子查询案例 在from子句中使用子查询案例 合并查询union案例union all案例 子查询 子查询是指嵌入在其他sql语句中的select语句,也叫嵌套查询 测试要用到的表 测试要用到的数据 单…

TOPCON拓普康BM-7A亮度色度计

用途: 其应用范围非常之广:各种电视/手机/电脑/复印机等的液晶显示屏LCD的亮度、色度、色温、对比度等项目测定;液晶领域内各部件(LED、CCFL、EL背光源,液晶模组,滤光片)的亮度、色度、配光特性…

Qt通用属性工具:随心定义,随时可见(二)

一、话接上篇 本片咱们话接上篇《Qt通用属性工具:随心定义,随时可见(一)》,讲讲自定义的对象属性如何绑定通用属性编辑工具。 二、破杯二两酒 1、一颗小花生 同样,我们先准备一个比较简单的demo&#x…

案例系列:营销模型_客户细分_无监督聚类

案例系列:营销模型_客户细分_无监督聚类 import numpy as np # 线性代数库 import pandas as pd # 数据处理库,CSV文件的输入输出(例如pd.read_csv)/kaggle/input/customer-personality-analysis/marketing_campaign.csv在这个项…

老师的责任和义务

作为一名老师,我们的责任和义务是重大的。在教育领域,我们扮演着至关重要的角色,肩负着培养下一代人才的重任。下面,我将以知乎的口吻,从几个方面谈谈老师的责任和义务。 确保学生获得高质量的教育。这包括制定合理的教…

企业级低代码平台:助力IT部门,释放业务创新力

随着低代码技术的升级,越来越多的企业开始采用低代码平台,如恒逸集团利用低代码平台快速搭建了综合业务管理平台,时间比传统开发缩短近一倍。云表低代码提供的数据、流程、权限、图表等引擎工具,完美适配企业数字化需求。根据Gart…

HarmonyOS应用事件打点开发指导

简介 传统的日志系统里汇聚了整个设备上所有程序运行的过程流水日志,难以识别其中的关键信息。因此,应用开发者需要一种数据打点机制,用来评估如访问数、日活、用户操作习惯以及影响用户使用的关键因素等关键信息。 HiAppEvent 是在系统层面…

《每天一分钟学习C语言·六》

1、 1字节(Byte)8位,1KB1024字节,1M1024KB,1G1024MB 2、 char ch A; printf(“ch %d\n”, ch);ch为65 这里是ASCII码转换 3、 scanf("%d", &i); //一般scanf直接加输入控制符 scanf("m%d&qu…