BERT 词向量理解及训练更新

news2024/10/6 16:20:24

1、BERT 词向量理解

在预训练阶段中,词向量是在不断更新的,而在fine-tuning阶段中,词向量是固定不变的。在fine-tuning阶段中,我们使用预训练好的模型参数来对新的数据进行训练。

BERT模型在预训练阶段中,会学习词表中所有词的词向量。在学习过程中,词表中每个词的词向量是通过输入的语料来学习的。

在训练过程中,词表中每个词的词向量都是随机初始化的,然后通过训练数据和反向传播算法来不断更新。反向传播算法会根据当前的词向量和训练数据的误差来调整词向量的值,使得模型在语料中学到的语言知识能够更好地概括文本。

在预训练阶段结束之后,这些词向量就成为了预训练权重。在 fine-tuning 阶段中,使用这些预训练好的词向量来对新的数据进行训练。

2、词向量训练更新

词向量权重矩阵为什么能训练更新?

理解就是输入字x,1个神经元对应了多个神经元,权重(即是这个x的词向量)就是1对多的连接层上的权重,相当于是个线性函数的连接层参数

#比如这里输入层有1个神经元,输出层有3个神经元,因此W是一个1*3的矩阵,b是一个3维的向量

y = Wx + b

假设输入层有1个神经元x=2,W是1*3的矩阵 [[1, 2, 3]],W就是x的词向量,b是3维的向量 [1, 1, 1]。

那么 y = Wx + b

y = [[1, 2, 3]] * 2 + [1, 1, 1] = [2, 4, 6] + [1, 1, 1] = [3, 5, 7]

这就是输入值为2时,输出层的结果。

1)torch全连接层表示词向量,3维度词向量代码更新训练说明
import torch
import torch.nn as nn

# 定义模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(in_features=1, out_features=3)

    def forward(self, x):
        return self.linear(x)

# 实例化模型
model = LinearModel()

# 获取权重
print(model.linear.weight)
print(model.linear.bias)

# 设置输入
x = torch.tensor([[1.0]])

# 模型预测
print(model(x))

# 设置损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

# 训练循环
for epoch in range(1000):
    # 清空梯度
    optimizer.zero_grad()
    # 预测
    y_pred = model(x)
    # 计算损失
    loss = criterion(y_pred, torch.tensor([[2.0, 4.0, 6.0]]))
    # 反向传播
    loss.backward()
    # 更新权重
    optimizer.step()
    if (epoch+1) % 100 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 训练后查看权重
print(model.linear.weight)
print(model.linear.bias)

线性连接层上参数训练前后变化
在这里插入图片描述

2)pytorch 词向量Embedding封装层表示,词向量更新训练代码

Embedding层理解,参考https://blog.csdn.net/weixin_42357472/article/details/120886559

import torch
import torch.nn as nn
import torch.optim as optim

# 定义词汇表大小和词向量维度
vocab_size = 10000
embedding_dim = 300

# 创建Embedding层
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

初始化的词向量矩阵向量,每行表示一个字(或词)的向量
在这里插入图片描述

# 定义输入数据
inputs = torch.LongTensor([1, 2, 3, 4])

# 定义标签数据
labels = torch.LongTensor([1, 2, 3, 4])


# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.3)


# 迭代训练
for epoch in range(100):
    # 计算模型输出
    outputs = embedding(inputs)
    
    # 计算损失
    loss = criterion(outputs, labels)
    
    # 清空梯度
    optimizer.zero_grad()
    
    # 反向传播
    loss.backward()
    
    # 更新权重
    optimizer.step()


embedding.weight

训练后的词向量矩阵向量变化,因为只有1,2,3,4行测试数据,所以只更新的是这几行的向量
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/164630.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

win10开启自带的手机投屏功能方式

本篇文章主要讲解win10开启自带的手机投屏方式。 日期:2023年1月15日 作者:任聪聪 开启后效果 点击连接 打开连接或通过手机其他网络进行连接。 连接步骤: 步骤一、打开手机端的wifi网络设置,点击高级设置或其他网络设置&…

论文的正确打开方式—如何细读一篇论文分享

前段时间听了一个关于读论文的公开课,课上的老师讲的非常好,听完之后确实发现从以前看论文的没头没脑到现在的有了一些思绪的变化,所以特此整理了一下分享给大家,希望对大家有用。 在我们初次接触论文的时候,经常性的遇…

《后端技术面试 38 讲》学习笔记 Day 12

《后端技术面试 38 讲》学习笔记 Day 12 31 | 大数据架构:大数据技术架构的思想和原理是什么? 原文摘抄 大数据技术其实是分布式技术在数据处理领域的创新性应用,本质和我们此前讲到的分布式技术思路一脉相承:用更多的计算机组成…

smart-doc的使用

smart-doc的使用 目录 1. 什么是smart-doc 2. smart-doc的功能特性 3. smart-doc自定义注释tag 4. 通过引入依赖生成文档 5. 通过集成smart-doc的maven插件生成文档 6. 生成Postman json文件与导入Postman测试 1. 什么是smart-doc smart-doc是一款同时支持JAVA REST API和…

MySQL监控(二): Prometheus入门

1.官网 OpenTelemetry - CNCF Prometheus官方文档 安装包下载页 Prometheus安装官方文档指引 2.安装mysqld_exporter (1)下载 mysqld_exporter下载 (2)配置文件 my.cnf [client] hostxx.xx.xx.xx port31090 userroot passwordroot(3)启动 启动命令: nohup …

关于常见排序的一些细节的理解

最近复习了一下十种基本的排序算法,但是发现有很多的细节理解不到位,不是忘了而是根本没理解。就比如为啥有的排序是不稳定排序,而有的排序的时间复杂度高等等问题。一、不稳定排序的稳定性分析和复杂度常见排序算法中有4种排序是不稳定的。快…

详解最近公共祖先(LCA)

看本博客前建议先看一下ST算法解决BMQ问题详解一,LCA概念最近公共祖先(Lowest Common Ancestors, LCA)指有根树中距离两个节点最近的公共祖先。祖先指从当前节点到树根路径上的所有节点。u和v的公共祖先指一个节点既是u的祖先,又是v的祖先。u和v的最近公…

php网上书城|基于PHP实现网上书店商城藉项目

作者主页:编程指南针 作者简介:Java领域优质创作者、CSDN博客专家 、掘金特邀作者、多年架构师设计经验、腾讯课堂常驻讲师 主要内容:Java项目、毕业设计、简历模板、学习资料、面试题库、技术互助 收藏点赞不迷路 关注作者有好处 文末获取源…

3分钟秒懂,最简单通俗易懂的spring bean 生命周期介绍与源码分析,附上demo完整源码

文章写作背景 最近突然身边很多小伙伴问我有没有spring bean生命周期的通俗移动的介绍 起初不太理解为什么,后来才想明白,哦对了,年底了,快开始跳槽季了,这不就是java八股文面试 的题目嘛,不得不说&#xf…

【5G RRC】Master Information Block (NR-MIB)

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

手把手教你分析 Linux 启动流程

下载 Linux 内核网址: https://www.kernel.org/ 常用 Linux 内核源码为 4.14、4.19、4.9、5.10、5.15、6.1 等版本,其中 4.14 版本源码压缩包大概 90+M,解压后 700+M,合计 61350 个文件。如此众多的文件,用 source insight 或者 VSCode 查看都会比较卡,所以可以采用在线…

计算机网络第四章

1.网络层主要任务是把分组从源端传到目的端,为分组交换网上的不同主机提供通信服务,网络层传输单位是数据报三个功能:路由选择与分组转发(最佳路径)异构网络互联拥塞控制数据交换方式三种交换方式:电路交换…

一动不动是王八?动态内存有话说

文章目录前言动态内存函数介绍mallocfreecallocrealloc柔性数组柔性数组特点柔性数组的优点方便内存释放提高我们的访问速度总结前言 一动不动是王八,出自2014年的春晚,小时候经常喜欢说这句话,那在我们C语言中,我们知道&#xf…

年度征文|一个业余电脑玩家的30年(1992-2022)

《论语为政》:“五十而知天命”。岁月真的是一把刀,一晃已过不惑之年,还有几天就要进入知非之年。不论知非还是知天命,反正是花甲将至而从心所欲了。年少时因某种不合机缘,错与IT界擦肩而过,每每想起就扼腕…

gradel学习+IDEA配置

Gradle的下载 Gradle下载地址如下 https://gradle.org/releases/ 我自己的下载的7.4.2 可以选择下载完整的压缩包,将压缩包解压到自己指定的目录中即可。 Gradle安装 1、配置系统变量 GRADLE_HOME 2、配置环境变量 %GRADLE_HOME%是获取变量名称为GRADLE_HOME的…

项目看板开发经验分享(一)——光伏绿色能源看板

今天新开一个系列,专门介绍近期工作中开发的几个比较酷炫的看板的开发思路与经验分享。第一节我们就来介绍下这个光伏绿色能源看板,整体浏览如下: 那就直接进入正题吧—— 0、可复用组件panel 在讲解各个模块之前,我们先来完成一…

Mybatis 框架下 SQL 注入攻击的 3 种方式

SQL注入漏洞作为WEB安全的最常见的漏洞之一,在java中随着预编译与各种ORM框架的使用,注入问题也越来越少。 新手代码审计者往往对Java Web应用的多个框架组合而心生畏惧,不知如何下手,希望通过Mybatis框架使用不当导致的SQL注入问…

Node.js学习笔记

Node.js学习笔记 浏览器的内核包括两部分核心:DOM渲染引擎、JavaScript解析引擎。脱离浏览器环境也可以运行JavaScript,只要有JavaScript引擎就可以。 Node.js是一个基于Chrome V8引擎的JavaScript运行环境。Node.js内置了Chrome的V8 引擎,…

SpringBoot项目部署

系列文章目录 Spring Boot[概述、功能、快速入门]_心态还需努力呀的博客-CSDN博客 Spring Boot读取配置文件内容的三种方式_心态还需努力呀的博客-CSDN博客 Spring Boot整合Junit_心态还需努力呀的博客-CSDN博客 Spring Boot自动配置--如何切换内置Web服务器_心态还需努力呀…

Open3D SOR滤波(Python版本)

文章目录一、简介二、实现代码三、实现效果参考资料一、简介 SOR滤波过程相对简单,其原理是通过查询点与邻域点集之间的距离统计判断来进行过滤离群点。假设一个点的邻近点集符合正太分布,因此我们可以通过计算出该点到它所有临近点的平均距离meanD和标准…