huggingface的transformers训练bert

news2024/11/16 11:31:54

目录

理论

实践


理论

https://arxiv.org/abs/1810.04805

BERT(Bidirectional Encoder Representations from Transformers)是一种自然语言处理(NLP)模型,由Google在2018年提出。它是基于Transformer模型的预训练方法,通过在大规模的无标注文本上进行预训练,学习到了丰富的语言表示。

BERT的主要特点是双向性预训练-微调框架。在传统的语言模型中,只使用了单向的上下文信息,而BERT利用了双向Transformer编码器来同时考虑上下文的信息,使得模型能够更好地理解句子中的语义和关系。BERT采用了Transformer的多层编码器结构,其中包含了自注意力机制(self-attention mechanism),能够有效地捕捉句子中不同位置的依赖关系。

单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token。

BERT模型通过两个阶段的训练来获得语言表示。首先,它在大规模无标注的文本上进行预训练,使用两个任务:掩码语言建模(Masked Language Modeling,MLM)和下一句预测(Next Sentence Prediction,NSP)。

MLM任务中,随机掩盖输入句子的一些词汇,模型需要预测这些被掩盖的词汇。MLM任务的目的是让模型通过上下文来推断被掩盖的词汇,从而学习到丰富的语言表示。在预训练阶段,BERT模型会使用大规模的无标注文本进行训练,其中包括了来自维基百科、新闻文章、书籍等的文本数据。模型在这些大规模数据上进行预训练,通过尝试预测被掩盖词汇的方法来学习词汇的上下文关系和语义。在MLM任务中,模型的输入句子经过编码器(Transformer)进行编码,然后通过一个全连接层(输出层)来预测被掩盖的词汇。对于被掩盖的位置,模型会生成一个概率分布,以表示每个可能的词汇是被掩盖位置的预测。通常情况下,模型会根据预训练过程中的目标函数(如交叉熵损失)来优化预测结果。通过进行MLM任务的预训练,BERT模型能够学习到词汇的上下文信息和语义表示,从而在下游任务中具有更好的表现。在微调阶段,模型会使用有标签的数据进行进一步的训练,以适应特定任务的要求,并通过微调来提升模型在特定任务上的性能。对比gpt,中间的词只能和前面的词做attention而不能和后面的词做attention,所以没法做到上下文综合理解。

在NSP任务中,模型接收两个句子作为输入,要判断这两个句子是否是原文中的连续句子。

在预训练完成后,BERT模型可以用于各种下游任务的微调,如文本分类、命名实体识别、问答等。在微调阶段,模型会在特定任务的标注数据上进行进一步的训练,以适应具体任务的要求。只需要添加一个额外的输出层进行fine-tune,就可以在各种各样的下游任务中取得state-of-the-art的表现。在这过程中并不需要对BERT进行任务特定的结构修改。

RoBERTa(Robustly Optimized BERT Approach)是由Facebook AI于2019年提出的一种语言模型,它是在BERT模型的基础上进行改进和优化的。RoBERTa的目标是通过更大规模的数据和更长的训练时间来获得更强大的语言表示能力。相比于BERT,RoBERTa采用了一系列的训练技巧和策略,如动态掩码、更长的训练序列、更大的批量大小等,以提升模型的性能。RoBERTa在多项自然语言处理任务上取得了显著的性能提升,并成为了当前领域内的重要基准模型之一。

实践

https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling

安装:

git clone https://github.com/huggingface/transformers
cd transformers
pip install .
pip install -r requirements.txt
python run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm

RobertaForMaskedLM = RobertaModel + RobertaLMHead

RobertaModel = RobertaEmbeddings + RobertaEncoder + RobertaPooler

RobertaEmbeddings = nn.Embedding(word,position,token_type) + nn.LayerNorm + nn.Dropout

RobertaEncoder = nn.ModuleList([RobertaLayer(config))

RobertaLayer = RobertaAttention + RobertaIntermediate + RobertaOutput

RobertaAttention = RobertaSelfAttention + RobertaSelfOutput

基本上就是x--》q,k,v-->q*k-->mask-->softmax-->*v

RobertaIntermediate = Fc + activate

RobertaOutput = Linear + dropout + layernorm

RobertaPooler = Linear + 激活函数Tanh

RobertaLMHead = Linear + gelu + layernorm +linear

总结:

RobertaForMaskedLM = RobertaModel + RobertaLMHead

        RobertaModel = RobertaEmbeddings + RobertaEncoder + RobertaPooler

            RobertaEmbeddings = nn.Embedding(word,position,token_type) + nn.LayerNorm + nn.Dropout

            RobertaEncoder = nn.ModuleList([RobertaLayer(config))

                    RobertaLayer = RobertaAttention + RobertaIntermediate + RobertaOutput * 12

                              RobertaAttention = RobertaSelfAttention + RobertaSelfOutput

                              RobertaIntermediate = Fc + activate

              ​​​​​​​                RobertaOutput = Linear + dropout + layernorm

          ​​​​​​​   RobertaPooler = Linear + 激活函数Tanh

       RobertaLMHead = Linear + gelu + layernorm +linear

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1538624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣HOT100 - 128. 最长连续序列

解题思路: 注意: 1.Set不能直接排序,必须要转换成ArrayList或者LinkedList后用Collections.sort()方法进行排序。 (Queue也不能直接排序,排序方法同Set) 2.连续的序列不能只找第一个,因为不…

dbscan算法实现鸢尾花聚类(python实现)

DBscan算法原理 : dbscan算法-CSDN博客 法一(调库) : 直接调库 : import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.cluster import DBSCAN from sklearn.decomposition import PCA from sklearn.discriminant_analysis …

【数据结构刷题专题】——二分查找

二分查找 二分查找模板题&#xff1a;704. 二分查找 二分查找前提&#xff1a; 有序数组数组中无重复元素 左闭右闭&#xff1a; class Solution { public:int search(vector<int>& nums, int target) {int left 0;int right nums.size() - 1;while (left <…

重新配置node.js,npm,环境变量

起因是检查最近收到的一些朋友分享给我的各种资料&#xff0c;什么前端&#xff0c;后端&#xff0c;java,go,python等语言&#xff0c;想着将一个模拟QQ音乐的一个源代码进行跑通&#xff0c;看看有什么特别之处。如下图 出现了node环境路径问题&#xff0c;参考链接 https:/…

回收站的数据删了可以找回来吗?方法已备好

在数字化时代&#xff0c;数据的安全性与恢复问题逐渐受到大家的关注。回收站&#xff0c;作为电脑中存储已删除文件的地方&#xff0c;常常被视为数据恢复的“救命稻草”。然而&#xff0c;当回收站中的数据也被删除时&#xff0c;许多人可能会感到无助和困惑。本文旨在探讨回…

nuclei使用方法

nuclei使用方法 查看帮助 nuclei -h 列出所有模板 nuclei -tl 查找某种cms的相关漏洞模板&#xff0c;wordpress为例 nuclei -tl -tc "contains(name,wordpress)"便会列出内容里含有wordpress关键字的漏洞检测模板 使用与某cms相关的所有漏洞模板进行扫描&#…

每日一题 --- 209. 长度最小的子数组[力扣][Go]

长度最小子数组 题目&#xff1a; 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其总和大于等于 target 的长度最小的 连续 子数组 [numsl, numsl1, ..., numsr-1, numsr] &#xff0c;并返回其长度**。**如果不存在符合条件的子数组&#xff0c…

web学习笔记(四十三)ajax

目录 1.相关基础概念 1.1客户端与服务器 1.2URL地址 1.3 客户端和服务器端通信的过程 1.4 一个URL地址放入浏览器&#xff0c;到页面渲染发生了什么事情 1.5 数据 1.6资源的请求方式 2.Ajax 2.1什么是Ajax 2.2 jQuery 中的Ajax 2.2.1 $.get()的语法 2.2.2$.post()…

Linux:http协议初步认识

文章目录 OSI七层模型http协议域名路径信息请求和响应 编写一个httpserver OSI七层模型 在结束了前面对于序列化反序列化等内容的学习后&#xff0c;重新回到对于OSI模型的部分 如上所示的是对于OSI接口的示意图&#xff0c;在这当中可以看到会话层的概念&#xff0c;会话层的…

简介:KMeans聚类算法

在机器学习中&#xff0c;无监督学习一直是我们追求的方向&#xff0c;而其中的聚类算法更是发现隐藏数据结构与知识的有效手段。聚类是一种包括数据点分组的机器学习技术。给定一组数据点&#xff0c;我们可以用聚类算法将每个数据点分到特定的组中。 理论上&#xff0c;属于同…

SQL Server 2008R2 日志文件大小设置及查询

SQL Server 2008R2 建立数据库存在日志无限增长问题&#xff0c;造成磁盘内存不足。本文解决这个问题&#xff0c;如下&#xff1a; 1.设置日志文件的最大大小 USE master; GO ALTER DATABASE [D_total] MODIFY FILE (NAME D_total_log, -- 日志文件的逻辑名称MAXSIZE 200…

LeetCode Python - 69. x 的平方根

目录 题目描述解法运行结果 题目描述 给你一个非负整数 x &#xff0c;计算并返回 x 的 算术平方根 。 由于返回类型是整数&#xff0c;结果只保留 整数部分 &#xff0c;小数部分将被 舍去 。 注意&#xff1a;不允许使用任何内置指数函数和算符&#xff0c;例如 pow(x, 0.…

稀碎从零算法笔记Day21-LeetCode:单词规律

题型&#xff1a;哈希表、字符串 链接&#xff1a;290. 单词规律 - 力扣&#xff08;LeetCode&#xff09; 来源&#xff1a;LeetCode 题目描述 给定一种规律 pattern 和一个字符串 s &#xff0c;判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配&#xff0c;例如&am…

先进电机技术 —— 何为轴电压?

一、特定场景举例 长线驱动指的是在电动机与变频器之间存在较长的连接电缆&#xff0c;尤其是在大型工业应用中&#xff0c;电机可能远离变频器几十米甚至上百米。这种情况下&#xff0c;变频器通过长线向电动机传输功率时&#xff0c;可能会加剧电机轴电压和轴电流的产生&…

Flutter-仿携程首页类型切换

效果 唠叨 闲来无事&#xff0c;不小心下载了携程app&#xff0c;还幻想可以去旅游一番&#xff0c;奈何自己运气不好&#xff0c;自从高考时第一次吹空调导致自己拉肚子考试&#xff0c;物理&#xff0c;数学考了一半就交卷&#xff0c;英语2B铅笔除了问题&#xff0c;导致原…

ensp静态路由综合实验(一)

实验拓扑&#xff1a; 实验目的&#xff1a; 1、R6为ISP&#xff0c;接口IP地址均为公有地址&#xff0c;该设备只能配置IP地址&#xff0c;之后不能再对其进行任何配置&#xff1b; 2、R1-R5为局域网&#xff0c;私有IP地址192.168.1.0/24&#xff0c;请合理分配&#xff1b;…

【图解物联网】第7章 物联网与可穿戴设备

7.1 可穿戴设备的基础 顾名思义&#xff0c;可穿戴设备就是指穿戴在身上的设备&#xff0c;因此&#xff0c;比起单独使用前面说的那些设备&#xff0c;可穿戴设备能够令服务更加贴近人们的生活。如果你想率先实现物联网服务&#xff0c;那么就可以选择使用可穿戴设备。 …

java业务需求——爆金币

假设我们要模拟金铲铲中塔姆的爆金币需求&#xff0c;我们该如何实现该需求呢&#xff1f; 所以假设下面具体场景&#xff1a; 1.在每一回合的15s中&#xff0c;该棋子不断被攻击。 2.该棋子被攻击时有十分之三的概率的会爆出一个金币&#xff0c; 3.每被攻击10次必爆一个金币…

HAL STM32G4 +TIM1 3路PWM互补输出+VOFA波形演示

HAL STM32G4 TIM1 3路PWM互补输出VOFA波形演示 ✨最近学习研究无刷电机驱动&#xff0c;虽然之前有使用过&#xff0c;但是在STM32上还没实现过。本文内容参考欧拉电子例程&#xff0c;从PWM驱动开始学习。 欧拉电子相关视频讲解&#xff1a; STM32G4 FOC开发实战—高级定时器发…

【光标精灵】让您享受鼠标皮肤多样化快捷更换

鼠标作为我们日常使用频率最高的“小伙伴”&#xff0c;扮演着至关重要的角色。尤其是在女生群体中&#xff0c;对于打造一个个性化、可爱的电脑桌面和软件界面的需求日益增长。然而&#xff0c;尽管电脑默认提供了一些可更换的光标图案&#xff0c;但仍显得有些单调和呆板。想…