学习transformer模型-权重矩阵Wq,Wk,Wv的困扰和解析

news2024/11/24 12:44:04

背景:

学习transformer模型,计算multiHead attention的时候,权重矩阵Wq,Wk,Wv给我造成了很大的困扰:

1,为啥要需要W*?

2,这个W*是从哪里来的?

搜索了各种信息,消化理解如下:

1,W*权重矩阵就是训练的目的,就是要找到合适的W*(weights)。

2,W* 是函数nn.Linear初始化的,默认为随机数。经过不断地训练,更新,最终获得比较好的结果

训练W*过程举例:

在PyTorch中,训练一个包含nn.Linear层的神经网络涉及几个关键步骤。以下是一个基本的训练流程:

1. 定义模型结构

首先,你需要定义你的神经网络模型,这包括使用nn.Linear来创建全连接层。

import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size) # 第一个全连接层
self.relu = nn.ReLU() # 激活函数
self.fc2 = nn.Linear(hidden_size, output_size) # 第二个全连接层(输出层)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 实例化模型
input_size = 784 # 假设输入是28x28的图像,展平后为784维
hidden_size = 128 # 隐藏层的大小
output_size = 10 # 假设有10个分类
model = MyModel(input_size, hidden_size, output_size)

2. 定义损失函数和优化器

接下来,你需要选择一个合适的损失函数和优化器。损失函数用于衡量模型预测与真实标签之间的差异,而优化器则用于根据损失函数的梯度更新模型的权重。

criterion = nn.CrossEntropyLoss() # 多分类问题常用的损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器,学习率设为0.001

3. 准备数据集

你需要准备训练数据集和验证数据集(如果有的话)。这些数据集应该被转换为PyTorch张量,并且通常会被划分为小批量以便进行迭代训练。

# 假设你已经有了训练数据和标签
train_data = ...
train_labels = ...
# 转换为张量
train_data = torch.tensor(train_data, dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)

4. 训练循环

现在你可以开始训练循环了。在每个epoch中,你会遍历整个数据集(或其一个子集),进行前向传播、计算损失、反向传播和参数更新。

num_epochs = 10 # 训练轮数
for epoch in range(num_epochs):
# 将梯度清零,否则梯度会累积
optimizer.zero_grad()
# 前向传播
outputs = model(train_data)
# 计算损失
loss = criterion(outputs, train_labels)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 打印统计信息(可选)
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

5. 验证和测试

在训练过程中或训练结束后,你可能还希望验证模型的性能。这通常通过在验证集或测试集上运行模型并计算相关指标(如准确率)来完成。

# 假设你也有一个验证集
val_data = ...
val_labels = ...
# 转换为张量
val_data = torch.tensor(val_data, dtype=torch.float32)
val_labels = torch.tensor(val_labels, dtype=torch.long)
# 不需要计算梯度
with torch.no_grad():
val_outputs = model(val_data)
val_loss = criterion(val_outputs, val_labels)
_, predicted = torch.max(val_outputs, 1)
correct = (predicted == val_labels).sum().item()
accuracy = correct / val_labels.size(0)
print(f'Validation Loss: {val_loss.item():.4f}, Accuracy: {accuracy:.4f}')

这就是训练包含nn.Linear层的神经网络的基本流程。在实际应用中,你可能还需要添加其他组件,如数据加载器、学习率调度器、模型保存

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557975.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python-VBA编程500例-027(入门级)

验证字符串能否转换(Verify Whether A String Can Be Converted)在多个实际应用场景中扮演着重要角色。常见的应用场景有: 1、数据清洗与预处理:在数据处理和分析过程中,原始数据可能包含格式错误、多余字符或不符合规范的内容。验证字符串…

『Apisix系列』破局传统架构:探索新一代微服务体系下的API管理新范式与最佳实践

文章目录 『Apisix基石篇』『Apisix入门篇』『Apisix进阶篇』『Apisix安全篇』 『Apisix基石篇』 🚀 手把手教你从零部署APISIX高性能API网关 利用Docker-compose快速部署Apache APISIX及其依赖组件,实现高效的API网关搭建通过编写RPM安装脚本来自动化安…

love 2d win 下超简单安装方式,学习Lua 中文编程 刚需!!

一、下载love 2d 参考:【Love2d从青铜到王者】第一篇:Love2d入门以及安装教程 或直接下载: 64位,现在一般电脑都可以用。 64-bit zipped 32位,很复古的电脑都可以用。 32-bit zipped 二、解压 下载好了之后,解压到…

54 npm run serve 和 npm run build 输出的关联和差异

前言 通常来说 我们开发的时候一般会用到的命令是 “npm run serve”, “npm run build” 前者会编译当前项目, 然后将编译之后的结果以 node 的形式启动一个服务, 暴露相关业务资源, 因此 我们可以通过 该服务访问到当前项目 后者是编译当前项目, 然后做一下最小化代码的优…

机器学习—— PU-Learning算法

机器学习—— PU-Learning算法 本篇博客将介绍PU-Learning算法的基本概念、基本流程、基本方法,并简单探讨Two-step PU Learning算法和无偏PU Learning算法的具体流程。最后,将通过Python代码实现一个简单的PU-Learning示例,以便更好地理解这…

动态规划——回文串问题

目录 练习1:回文子串 练习2:最长回文子串 练习3:回文串分割IV 练习4:分割回文串 练习5:最长回文子序列 练习6:让字符串成为回文串的最小插入次数 本篇文章主要学习使用动态规划来解决回文串相关问题&…

【C#】知识点速通

前言: 笔者是跟着哔站课程(Trigger)学习unity才去学习的C#,并且C语言功底尚存,所以只是简单地跟着课程将unity所用的C#语言的关键部分进行了了解,然后在后期unity学习过程中加以深度学习。如需完善的C#知识…

JDBC远程连接mysql报错:NotBefore: Sat Mar 30 16:37:41 UTC 2024

虚拟机docker已经部署了mysql,用navicat可以直接远程连接,datagrip却不能,如图: 需要在最后加上 ?useSSLfalse , 如:jdbc:mysql://192.168.30.128:3306?useSSLfalse navicat不用加的原因是没有使用jdbc连接&#x…

实验二 pandas库绘图以及数据清洗

1.1pandas验证操作 1、验证以下代码,并将结果附截图 import pandas as pd A[1,3,6,4,9,10,15] weight[67,66,83,68,79,88] sex[女,男,男,女,男, 男] S1pd.Series(A)#构建S1序列 print(S1) S2pd.Series(weight)#构建S2序列 print(S2) S3pd.Series(sex)#构建S3序列 p…

第3章.引导ChatGPT精准角色扮演:高效输出专业内容

角色提示技术 角色提示技术(role prompting technique),是通过模型扮演特定角色来产出文本的一种方法。用户为模型设定一个明确的角色,它就能更精准地生成符合特定上下文或听众需求的内容。 比如,想生成客户服务的回复…

STM32的DMA

DMA(Direct memory access)直接存储器存取,用来提供在外设和存储器之间或者存储 器和存储器之间的高速数据传输,无须CPU干预,数据可以通过DMA快速地移动,这就节 省了CPU的资源来做其他操作。 STM32有两个DMA控制器共12个通道(DMA1有7个通道…

(八)目标跟踪中参数估计(似然、贝叶斯估计)理论知识

目录 前言 一、统计学基础知识 (一)随机变量 (二)全概率公式 (三)高斯分布及其性质 二、似然是什么? (一)概率和似然 (二)极大似然估计 …

Linux内核之Binder驱动container_of进阶用法(三十四)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

EasyRecovery2024汉化精简版,无需注册

EasyRecovery2024是世界著名数据恢复公司 Ontrack 的技术杰作,它是一个威力非常强大的硬盘数据恢复软件。能够帮你恢复丢失的数据以及重建文件系统。 EasyRecovery不会向你的原始驱动器写入任何东东,它主要是在内存中重建文件分区表使数据能够安全地传输…

ctfshow web入门 XXE

XXE基础知识 XXE(XML External Entity)攻击是一种针对XML处理漏洞的网络安全攻击手段。攻击者利用应用程序在解析XML输入时的漏洞,构造恶意的XML数据,进而实现各种恶意目的。 所以要学习xxe就需要了解xml xml相关: …

Chrome浏览器 安装Vue插件vue-devtools

前言 vue-devtools 是一个为 Vue.js 开发者设计的 Chrome 插件。它可以让你更轻松地审查和调试 Vue 应用程序。与普通的浏览器控制台工具不同,Vue.js devtools 专为 Vue 的响应性数据和组件结构量身定做。 1. 功能介绍 组件树浏览:这个功能可以让你查…

使用Python实现ID3决策树中特征选择的先后顺序,字节跳动面试真题

def empty1(pri_data): hair [] #[‘长’, ‘短’, ‘短’, ‘长’, ‘短’, ‘短’, ‘长’, ‘长’] voice [] #[‘粗’, ‘粗’, ‘粗’, ‘细’, ‘细’, ‘粗’, ‘粗’, ‘粗’] sex [] #[‘男’, ‘男’, ‘男’, ‘女’, ‘女’, ‘女’, ‘女’, ‘女’] for o…

OpenHarmony error: signature verification failed due to not trusted app source

问题:error: signature verification failed due to not trusted app source 今天在做OpenHarmony App开发,之前一直用的设备A在测试开效果,今天换成了设备B,通过DevEco Studio安装应用程序的时候,就出现错误&#xf…

爱上数据结构:栈和队列的概念及使用

​ ​ 🔥个人主页:guoguoqiang. 🔥专栏:数据结构 ​ 一、栈 1.栈的基本概念 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端 称为栈顶,…

不同Python版本和wxPython版本用pyinstaller打包文件大小对比

1、确定wxPython和Python版本的对应关系 在这里可以找到Python支持的所有wxPython版本:https://pypi.tuna.tsinghua.edu.cn/simple/wxpython/ 由于Python从3.6版本开始支持f字符串、从3.9版本开始不支持Windows7操作系统,所以我仅筛选3.6-3.8之间的版本…